# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

from apex.optimizers import FusedAdam as Adam
from apex.optimizers import FusedSGD as SGD

from megatron import get_args

from .distrib_optimizer import DistributedOptimizer
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer


def get_param_groups(modules,
                     no_weight_decay_cond,
                     scale_lr_cond,
                     lr_mult):
    """creates param groups based on weight decay condition (regularized vs non regularized)
       and learning rate scale condition (args.lr vs lr_mult * args.lr)
       scale_lr_cond is used during finetuning where head of the network requires a scaled
       version of the base learning rate. 
mohammad's avatar
mohammad committed
21
    """
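    # Partition the parameters into four buckets keyed on
    # (weight decay applied?, learning rate scaled?).
    # Illustrative example (parameter names are hypothetical, not taken from a
    # particular Megatron model): with the default conditions and lr_mult=1.0,
    # 'dense.weight' lands in the weight-decay / base-lr bucket, while
    # 'dense.bias' and 'layernorm.weight' land in the no-weight-decay /
    # base-lr bucket.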
    wd_no_scale_lr = []
    wd_scale_lr = []
    no_wd_no_scale_lr = []
    no_wd_scale_lr = []
    for module in modules:
        for name, param in module.named_parameters():
            if not param.requires_grad:
                continue

            if no_weight_decay_cond is not None:
                no_wd = no_weight_decay_cond(name, param)
            else:
                # do not regularize biases or Norm parameters
                no_wd = name.endswith(".bias") or len(param.shape) == 1

            if scale_lr_cond is not None:
                scale_lr = scale_lr_cond(name, param)
            else:
                scale_lr = False

            if not no_wd and not scale_lr:
                wd_no_scale_lr.append(param)
            elif not no_wd and scale_lr:
                wd_scale_lr.append(param)
            elif no_wd and not scale_lr:
                no_wd_no_scale_lr.append(param)
            else:
                no_wd_scale_lr.append(param)

    param_groups = []
    if len(wd_no_scale_lr):
        param_groups.append({'params': wd_no_scale_lr, 'wd_mult': 1.0, 'lr_mult': 1.0})
    if len(wd_scale_lr):
        param_groups.append({'params': wd_scale_lr, 'wd_mult': 1.0, 'lr_mult': lr_mult})
    if len(no_wd_no_scale_lr):
        param_groups.append({'params': no_wd_no_scale_lr, 'wd_mult': 0.0, 'lr_mult': 1.0})
    if len(no_wd_scale_lr):
        param_groups.append({'params': no_wd_scale_lr, 'wd_mult': 0.0, 'lr_mult': lr_mult})

    return param_groups

def get_megatron_optimizer(model,
                           no_weight_decay_cond=None,
                           scale_lr_cond=None,
                           lr_mult=1.0):
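    """Build the optimizer used for Megatron training.

    Creates parameter groups via get_param_groups, instantiates the base Adam
    or SGD optimizer selected by args.optimizer, and wraps it in a
    DistributedOptimizer, Float16OptimizerWithFloat16Params, or FP32Optimizer
    depending on args.use_distributed_optimizer, args.fp16, and args.bf16.
    """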
    args = get_args()

    # Base optimizer.
    param_groups = get_param_groups(model,
                                    no_weight_decay_cond,
                                    scale_lr_cond,
                                    lr_mult)

    if args.optimizer == 'adam':
        optimizer = Adam(param_groups,
                         lr=args.lr,
                         weight_decay=args.weight_decay,
                         betas=(args.adam_beta1, args.adam_beta2),
                         eps=args.adam_eps)
    elif args.optimizer == 'sgd':
        optimizer = SGD(param_groups,
                        lr=args.lr,
                        weight_decay=args.weight_decay,
                        momentum=args.sgd_momentum)
    else:
        raise Exception('{} optimizer is not supported.'.format(
            args.optimizer))

    # Determine whether the params have main-grad field.
    params_have_main_grad = False
    if args.DDP_impl == 'local':
        params_have_main_grad = True

    # Mixed precision optimizer.
    # - Note: both the Float16Optimizer and the DistributedOptimizer inherit
    #   from the MixedPrecisionOptimizer, which manages any optimizer where
    #   the model params and main params are distinct.
    if args.fp16 or args.bf16 or args.use_distributed_optimizer:

        # Grad scaler:
        #    if loss-scale is provided, instantiate the constant scaler.
        #    if we are using fp16 and loss-scale is not present, use a
        #       dynamic scaler.
        #    otherwise we are running in bf16 with no loss-scale so
        #       leave it as None.
        grad_scaler = None

        # Constant loss scale.
        if args.loss_scale:
            grad_scaler = ConstantGradScaler(args.loss_scale)

        # Dynamic loss scale.
        else:
            if args.fp16:
                grad_scaler = DynamicGradScaler(
                    initial_scale=args.initial_loss_scale,
                    min_scale=args.min_loss_scale,
                    growth_factor=2.0,
                    backoff_factor=0.5,
                    growth_interval=args.loss_scale_window,
                    hysteresis=args.hysteresis)

        # Megatron optimizer.
        opt_ty = DistributedOptimizer \
            if args.use_distributed_optimizer else \
            Float16OptimizerWithFloat16Params
        return opt_ty(optimizer,
                      args.clip_grad,
                      args.log_num_zeros_in_grad,
                      params_have_main_grad,
                      args.use_contiguous_buffers_in_local_ddp,
                      args.fp16,
                      args.bf16,
                      args.params_dtype,
                      grad_scaler,
                      model)

    # FP32.
    return FP32Optimizer(optimizer, args.clip_grad,
                         args.log_num_zeros_in_grad,
                         params_have_main_grad,
                         args.use_contiguous_buffers_in_local_ddp,
                         model)
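

# Usage sketch (an assumption about the calling context, not something defined
# in this module): get_megatron_optimizer is typically invoked from Megatron's
# training setup once get_args() has been populated and `model`, the list of
# (possibly DDP-wrapped) model modules, has been built, e.g.
#
#     optimizer = get_megatron_optimizer(model)
#
# or, when fine-tuning with a scaled learning rate on a hypothetical task head:
#
#     optimizer = get_megatron_optimizer(
#         model,
#         scale_lr_cond=lambda name, param: name.startswith('head'),
#         lr_mult=0.1)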