__init__.py 1.59 KB
Newer Older
zhe chen's avatar
zhe chen committed
1
2
3
4
5
6
# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

zhe chen's avatar
zhe chen committed
7
8
import torch

zhe chen's avatar
zhe chen committed
9
# -*- coding: utf-8 -*-  # NOTE(review): ineffective here — PEP 263 coding cookies are only honored on line 1 or 2 of the file; harmless, but should be moved or removed
zhe chen's avatar
zhe chen committed
10
11
from .custom_layer_decay_optimizer_constructor import \
    CustomLayerDecayOptimizerConstructor
zhe chen's avatar
zhe chen committed
12
from .efficient_ffn import EfficientFFN
zhe chen's avatar
zhe chen committed
13

zhe chen's avatar
zhe chen committed
14
# Public API of this package — the names exported by ``from <pkg> import *``.
__all__ = ['CustomLayerDecayOptimizerConstructor', 'EfficientFFN']
zhe chen's avatar
zhe chen committed
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45

if torch.__version__.startswith('1.11'):

    from mmcv.runner.hooks import HOOKS, Hook
    from mmcv.runner.optimizer.builder import OPTIMIZERS
    from mmdet.utils.util_distribution import ddp_factory  # noqa: F401,F403
    from torch.distributed.optim import ZeroRedundancyOptimizer

    class ZeroAdamW(ZeroRedundancyOptimizer):
        """AdamW wrapped in ``ZeroRedundancyOptimizer`` so optimizer state
        is sharded across distributed ranks instead of replicated.

        ``params`` is expected to be a list of param-group dicts (each with
        a ``'params'`` key), as produced by mmcv optimizer constructors —
        TODO confirm against the caller.
        """

        def __init__(self, params, optimizer_class=torch.optim.AdamW, **kwargs):
            # ZeroRedundancyOptimizer takes a flat parameter list at
            # construction, so seed it with the first group's params ...
            super().__init__(params[0]['params'],
                             optimizer_class=optimizer_class,
                             parameters_as_bucket_view=True,
                             **kwargs)
            # ... then register every remaining group individually.
            for group in params[1:]:
                self.add_param_group(group)

    # Make the wrapper constructible from config via mmcv's optimizer registry.
    OPTIMIZERS.register_module()(ZeroAdamW)


    @HOOKS.register_module()
    class ZeroHook(Hook):
        """Runner hook that consolidates the sharded optimizer state onto
        rank 0, so checkpoints saved from rank 0 contain the full state dict.

        Args:
            interval (int): consolidate every ``interval`` training iterations
                (epoch boundaries always consolidate).
        """

        def __init__(self, interval):
            # Iteration period between consolidations during training.
            self.interval = interval

        def after_epoch(self, runner):
            # Unconditionally gather state at the end of every epoch.
            runner.optimizer.consolidate_state_dict(to=0)

        def after_train_iter(self, runner):
            # Guard clause: only consolidate on every ``interval``-th iter.
            if not self.every_n_iters(runner, self.interval):
                return
            runner.optimizer.consolidate_state_dict(to=0)