Commit 06fc51ce authored by Jared Casper

Merge branch 'main' into checkpoint_util

parents ec561daa 0ed2f6ac
...
@@ -242,7 +242,7 @@ class ColumnParallelLinear(torch.nn.Module):
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias
-       gather_output: If true, call all-gether on output and make Y avaiable
+       gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        init_method: method to initialize weights. Note that bias is always set
...
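The gather_output flag described above controls whether the column-parallel partial products Y_i = X A_i are reassembled into the full Y. A minimal single-process sketch of that arithmetic (the 4-way split and tensor shapes below are illustrative, not taken from this commit):

import torch

torch.manual_seed(0)
X = torch.randn(8, 32)             # input activations
A = torch.randn(32, 64)            # full weight matrix: input_size x output_size

# Column-parallel split: each "rank" owns a slice of A's output columns.
A_shards = A.chunk(4, dim=1)
Y_shards = [X @ A_i for A_i in A_shards]   # gather_output=False: each GPU keeps its own Y_i = X A_i

# gather_output=True corresponds to an all-gather that reassembles the full output.
Y = torch.cat(Y_shards, dim=1)
assert torch.allclose(Y, X @ A, atol=1e-5)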
...
@@ -23,35 +23,68 @@ from .grad_scaler import ConstantGradScaler, DynamicGradScaler
 from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer

-def _get_params_for_weight_decay_optimization(modules):
-    """Divide params into with-weight-decay and without-weight-decay groups.
-    Layernorms and biases will have no weight decay but the rest will.
-    """
-
-    weight_decay_params = {'params': []}
-    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
-    for module in modules:
-        for module_ in module.modules():
-            if isinstance(module_, LayerNorm):
-                no_weight_decay_params['params'].extend(
-                    [p for p in list(module_._parameters.values())
-                     if p is not None])
-            else:
-                weight_decay_params['params'].extend(
-                    [p for n, p in list(module_._parameters.items())
-                     if p is not None and n != 'bias'])
-                no_weight_decay_params['params'].extend(
-                    [p for n, p in list(module_._parameters.items())
-                     if p is not None and n == 'bias'])
-
-    return weight_decay_params, no_weight_decay_params
+def get_param_groups(modules,
+                     no_weight_decay_cond,
+                     scale_lr_cond,
+                     lr_mult):
+    """Creates param groups based on the weight decay condition (regularized vs
+    non-regularized) and the learning rate scale condition (args.lr vs
+    lr_mult * args.lr). scale_lr_cond is used during finetuning, where the head
+    of the network requires a scaled version of the base learning rate.
+    """
+    wd_no_scale_lr = []
+    wd_scale_lr = []
+    no_wd_no_scale_lr = []
+    no_wd_scale_lr = []
+    for module in modules:
+        for name, param in module.named_parameters():
+            if not param.requires_grad:
+                continue
+
+            if no_weight_decay_cond is not None:
+                no_wd = no_weight_decay_cond(name, param)
+            else:
+                # do not regularize biases nor Norm parameters
+                no_wd = name.endswith(".bias") or len(param.shape) == 1
+
+            if scale_lr_cond is not None:
+                scale_lr = scale_lr_cond(name, param)
+            else:
+                scale_lr = False
+
+            if not no_wd and not scale_lr:
+                wd_no_scale_lr.append(param)
+            elif not no_wd and scale_lr:
+                wd_scale_lr.append(param)
+            elif no_wd and not scale_lr:
+                no_wd_no_scale_lr.append(param)
+            else:
+                no_wd_scale_lr.append(param)
+
+    param_groups = []
+    if len(wd_no_scale_lr):
+        param_groups.append({'params': wd_no_scale_lr, 'wd_mult': 1.0, 'lr_mult': 1.0})
+    if len(wd_scale_lr):
+        param_groups.append({'params': wd_scale_lr, 'wd_mult': 1.0, 'lr_mult': lr_mult})
+    if len(no_wd_no_scale_lr):
+        param_groups.append({'params': no_wd_no_scale_lr, 'wd_mult': 0.0, 'lr_mult': 1.0})
+    if len(no_wd_scale_lr):
+        param_groups.append({'params': no_wd_scale_lr, 'wd_mult': 0.0, 'lr_mult': lr_mult})
+
+    return param_groups

-def get_megatron_optimizer(model):
+def get_megatron_optimizer(model,
+                           no_weight_decay_cond=None,
+                           scale_lr_cond=None,
+                           lr_mult=1.0):
     args = get_args()

     # Base optimizer.
-    param_groups = _get_params_for_weight_decay_optimization(model)
+    param_groups = get_param_groups(model,
+                                    no_weight_decay_cond,
+                                    scale_lr_cond,
+                                    lr_mult)
     if args.optimizer == 'adam':
         optimizer = Adam(param_groups,
                          lr=args.lr,
...
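To show how the new param groups might be consumed, here is a hedged usage sketch, not part of this commit: it assumes get_param_groups is importable as shown, stands in a toy torch.nn.Sequential for the real model chunks, and invents a scale_lr_cond predicate for a finetuning head. The wd_mult/lr_mult entries are folded into per-group lr/weight_decay by hand here; inside Megatron the learning-rate scheduler is presumably responsible for that step.

import torch
from megatron.optimizer import get_param_groups   # assumed import path for the function above

# Toy stand-in for the list of model chunks Megatron would pass in.
model = torch.nn.Sequential(
    torch.nn.Linear(16, 16),
    torch.nn.LayerNorm(16),
    torch.nn.Linear(16, 4),        # pretend this last layer is the finetuning "head"
)

def scale_lr_cond(name, param):
    # Hypothetical predicate: the head's parameters train at lr_mult * base lr.
    return name.startswith("2.")

# no_weight_decay_cond=None falls back to the default rule above
# (biases and 1-D Norm parameters get no weight decay).
param_groups = get_param_groups([model], None, scale_lr_cond, lr_mult=0.1)

# Fold the multipliers into concrete per-group hyperparameters.
base_lr, base_wd = 1.0e-4, 0.01
for group in param_groups:
    group["lr"] = base_lr * group["lr_mult"]
    group["weight_decay"] = base_wd * group["wd_mult"]

optimizer = torch.optim.Adam(param_groups)
for group in optimizer.param_groups:
    print(len(group["params"]), group["lr"], group["weight_decay"])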
...
@@ -142,10 +142,16 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     if recv_prev:
         tensor_recv_prev = mpu.gather_split_1d_tensor(
             tensor_recv_prev).view(tensor_shape).requires_grad_()
+        tensor_recv_prev = mpu.make_viewless_tensor(tensor_recv_prev,
+                                                    requires_grad = True,
+                                                    keep_graph = False)
     if recv_next:
         tensor_recv_next = mpu.gather_split_1d_tensor(
             tensor_recv_next).view(tensor_shape).requires_grad_()
+        tensor_recv_next = mpu.make_viewless_tensor(tensor_recv_next,
+                                                    requires_grad = True,
+                                                    keep_graph = False)
     return tensor_recv_prev, tensor_recv_next
...
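The added mpu.make_viewless_tensor calls take the tensor produced by the gather_split_1d_tensor(...).view(...) chain, which is a view of the flat communication buffer, and replace it with an equivalent tensor that is not a view, presumably so the flat buffer is not pinned through ._base. A rough sketch of that idea; the helper below is illustrative, not Megatron's actual mpu implementation, and the keep_graph option is not modeled:

import torch

def make_viewless_sketch(inp, requires_grad):
    # Illustrative only: allocate a fresh (non-view) tensor, then re-point its
    # .data at inp's storage. The result shares memory with inp but reports
    # ._base is None, so it no longer references the larger buffer it was
    # viewed from through autograd's view bookkeeping.
    out = torch.empty((1,), dtype=inp.dtype, device=inp.device,
                      requires_grad=requires_grad)
    out.data = inp.data
    return out

flat_buffer = torch.zeros(8)            # stand-in for the gathered 1-D comm buffer
received = flat_buffer[:4].view(2, 2)   # .view() makes `received` a view of flat_buffer
assert received._base is flat_buffer

viewless = make_viewless_sketch(received, requires_grad=True)
assert viewless._base is None
assert viewless.shape == (2, 2)
assert viewless.data_ptr() == received.data_ptr()   # same underlying memory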