import torch
from mmcv.runner.hooks import HOOKS, Hook
from mmcv.runner.optimizer.builder import OPTIMIZERS
from torch.distributed.optim import ZeroRedundancyOptimizer

from .layer_decay_optimizer_constructor import \
    CustomLayerDecayOptimizerConstructor

__all__ = ['CustomLayerDecayOptimizerConstructor']


@OPTIMIZERS.register_module()
class ZeroAdamW(ZeroRedundancyOptimizer):
    """AdamW wrapped in ZeroRedundancyOptimizer so that optimizer state is
    sharded across distributed ranks."""

    def __init__(self, params, optimizer_class=torch.optim.AdamW, **kwargs):
        # `params` is the list of param-group dicts produced by the optimizer
        # constructor; ZeroRedundancyOptimizer accepts only one group at
        # construction time, so the remaining groups are added afterwards.
        super().__init__(params[0]['params'],
                         optimizer_class=optimizer_class,
                         parameters_as_bucket_view=True,
                         **kwargs)
        for i in range(1, len(params)):
            self.add_param_group(params[i])
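
# A minimal, hypothetical config sketch showing how the registered optimizer
# might be selected once this package is imported; the lr / weight_decay
# values are placeholders, not the repo's actual settings.
#
#   optimizer = dict(
#       type='ZeroAdamW',
#       lr=1e-4,
#       weight_decay=0.05,
#       constructor='CustomLayerDecayOptimizerConstructor')
#
# ZeroRedundancyOptimizer shards optimizer state across ranks, so this
# optimizer is only meaningful under torch.distributed training.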


@HOOKS.register_module()
class ZeroHook(Hook):
    """Periodically consolidate the sharded ZeroRedundancyOptimizer state on
    rank 0 so that checkpoints contain the full optimizer state dict."""

    def __init__(self, interval):
        self.interval = interval

    def after_epoch(self, runner):
        runner.optimizer.consolidate_state_dict(to=0)

    def after_train_iter(self, runner):
        if self.every_n_iters(runner, self.interval):
            runner.optimizer.consolidate_state_dict(to=0)
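
# A sketch of how the hook might be enabled from a config (hypothetical
# interval); consolidation should normally happen at least as often as
# checkpoints are written so the saved optimizer state is complete.
#
#   custom_hooks = [dict(type='ZeroHook', interval=1000)]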


@HOOKS.register_module()
class ToBFloat16Hook(Hook):
    """Cast the backbone to bfloat16 before training while keeping the
    decode/auxiliary heads in float32."""

    def before_run(self, runner):
        runner.model.module.backbone.to(torch.bfloat16)
        runner.model.module.decode_head.to(torch.float32)
        try:
            runner.model.module.auxiliary_head.to(torch.float32)
        except AttributeError:
            # The model may not have an auxiliary head.
            pass
        # nn.Module does not define `.dtype`; report a parameter's dtype instead.
        print('hook:', next(runner.model.module.backbone.parameters()).dtype)


@HOOKS.register_module()
class ToFloat16Hook(Hook):
    """Cast the backbone to float16 before training while keeping the neck
    and heads in float32."""

    def before_run(self, runner):
        runner.model.module.backbone.to(torch.float16)
        runner.model.module.decode_head.to(torch.float32)
        try:
            runner.model.module.auxiliary_head.to(torch.float32)
        except AttributeError:
            # The model may not have an auxiliary head.
            pass
        try:
            runner.model.module.neck.to(torch.float32)
        except AttributeError:
            # The model may not have a neck.
            pass
        # nn.Module does not define `.dtype`; report a parameter's dtype instead.
        print('hook:', next(runner.model.module.backbone.parameters()).dtype)
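
# Sketch of enabling one of the dtype hooks from a config; only one should be
# active for a given run, and the cast happens once in `before_run`.
#
#   custom_hooks = [dict(type='ToBFloat16Hook')]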