Commit 75740263 authored by Jared Casper

Merge branch 'lmcafee/copygrad-fix-v2' into 'main'

Lmcafee/copygrad fix v2

See merge request ADLR/megatron-lm!299
parents 2387ce01 f597f02e
@@ -154,6 +154,11 @@ def parse_args(extra_args_provider=None, defaults={},
         assert args.DDP_impl == 'local'
         args.use_contiguous_buffers_in_ddp = True
 
+    # If we use a contiguous buffer to hold main grads, we need to have
+    # local DDP.
+    if args.use_contiguous_buffers_in_ddp:
+        assert args.DDP_impl == 'local'
+
     if args.dataloader_type is None:
         args.dataloader_type = 'single'
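The check above ties --use-contiguous-buffers-in-ddp to local DDP, the implementation that keeps main grads in one flat buffer. For reference, a minimal sketch (plain PyTorch, illustrative names, not Megatron's LocalDDP) of the buffer-and-views layout the flag refers to:

# A minimal sketch, assuming plain PyTorch on CPU, of the buffer/view layout
# the flag refers to; the variable names and layout here are illustrative,
# not Megatron's LocalDDP implementation.
import torch

params = [torch.nn.Parameter(torch.randn(4, 3)),
          torch.nn.Parameter(torch.randn(5))]

# One contiguous buffer sized to hold every parameter's gradient.
total_numel = sum(p.numel() for p in params)
grad_buffer = torch.zeros(total_numel)

# Each parameter's main_grad is a persistent view into the shared buffer.
offset = 0
for p in params:
    p.main_grad = grad_buffer[offset:offset + p.numel()].view_as(p)
    offset += p.numel()

# Accumulating into main_grad writes straight into the flat buffer, which is
# what a single all-reduce can later operate on.
params[0].main_grad.add_(1.0)
assert grad_buffer[:12].eq(1.0).all()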
@@ -100,10 +100,12 @@ def get_megatron_optimizer(model):
                                                  args.clip_grad,
                                                  args.log_num_zeros_in_grad,
                                                  params_have_main_grad,
+                                                 args.use_contiguous_buffers_in_ddp,
                                                  args.bf16,
                                                  grad_scaler)
 
     # FP32.
     return FP32Optimizer(optimizer, args.clip_grad,
                          args.log_num_zeros_in_grad,
-                         params_have_main_grad)
+                         params_have_main_grad,
+                         args.use_contiguous_buffers_in_ddp)
@@ -68,7 +68,9 @@ class MegatronOptimizer(ABC):
     def __init__(self, optimizer, clip_grad,
                  log_num_zeros_in_grad,
-                 params_have_main_grad):
+                 params_have_main_grad,
+                 use_contiguous_buffers_in_ddp):
 
         """Input optimizer is the base optimizer for example Adam."""
         self.optimizer = optimizer
         assert self.optimizer, 'no optimizer is provided.'
@@ -76,7 +78,11 @@ class MegatronOptimizer(ABC):
         self.clip_grad = clip_grad
         self.log_num_zeros_in_grad = log_num_zeros_in_grad
         self.params_have_main_grad = params_have_main_grad
+        self.use_contiguous_buffers_in_ddp = use_contiguous_buffers_in_ddp
+        if self.use_contiguous_buffers_in_ddp:
+            assert self.params_have_main_grad, \
+                "use of contiguous buffer requires that params have main grad"
 
     def get_parameters(self):
         params = []
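The new assertion only encodes a flag dependency, restated below as a standalone check; check_optimizer_flags is a hypothetical helper, not part of this change:

# Standalone restatement of the invariant, assuming nothing beyond the two
# flags; check_optimizer_flags is a hypothetical helper, not part of the diff.
def check_optimizer_flags(params_have_main_grad, use_contiguous_buffers_in_ddp):
    if use_contiguous_buffers_in_ddp:
        assert params_have_main_grad, \
            "use of contiguous buffer requires that params have main grad"

check_optimizer_flags(True, True)     # local DDP with contiguous buffers
check_optimizer_flags(True, False)    # local DDP without the buffer
try:
    check_optimizer_flags(False, True)
except AssertionError as err:
    print('rejected:', err)           # contiguous buffer without main grads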
@@ -187,11 +193,12 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
     """
 
     def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
-                 params_have_main_grad, bf16, grad_scaler):
+                 params_have_main_grad, use_contiguous_buffers_in_ddp,
+                 bf16, grad_scaler):
 
         super(Float16OptimizerWithFloat16Params, self).__init__(
             optimizer, clip_grad, log_num_zeros_in_grad,
-            params_have_main_grad)
+            params_have_main_grad, use_contiguous_buffers_in_ddp)
 
         self.bf16 = bf16
         self.grad_scaler = grad_scaler
@@ -310,12 +317,26 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
                 else:
                     if model_param.grad is not None:
                         main_param.grad = model_param.grad.float()
+
+                # Safe to deallocate model's grad/main_grad after copying.
+                # (If using contiguous buffers, main_grad's memory should
+                # persist and therefore should not be deallocated.)
                 model_param.grad = None
+                if self.params_have_main_grad and \
+                   not self.use_contiguous_buffers_in_ddp:
+                    model_param.main_grad = None
 
         # For fp32 grads, we need to reset the grads to main grad.
         if self.params_have_main_grad:
             for model_group in self.fp32_from_fp32_groups:
                 for model_param in model_group:
                     model_param.grad = model_param.main_grad
+
+                    # Safe to de-reference model's main_grad after copying.
+                    # (If using contiguous buffers, main_grad's memory should
+                    # persist and therefore should not be deallocated.)
+                    if not self.use_contiguous_buffers_in_ddp:
+                        model_param.main_grad = None
 
     def _unscale_main_grads_and_check_for_nan(self):
         main_grads = []
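The comments above carry the core of the fix: after the fp16 grads are copied into the fp32 main params, main_grad may be dropped only when it is a private tensor; with contiguous buffers it is a view into the shared buffer and must stay attached for the next accumulation. A self-contained sketch of that rule (plain PyTorch; the helper and tensor setup are illustrative, only the flag and attribute names mirror the diff):

# A self-contained sketch, assuming plain PyTorch; copy_model_grad_to_main_grad
# and the tensor setup are illustrative, only the flag and attribute names
# mirror the diff.
import torch

def copy_model_grad_to_main_grad(model_param, main_param,
                                 params_have_main_grad,
                                 use_contiguous_buffers_in_ddp):
    if params_have_main_grad:
        # Gradient accumulation lives in model_param.main_grad; keep an
        # fp32 copy for the main (optimizer) parameter.
        main_param.grad = model_param.main_grad.float()
    elif model_param.grad is not None:
        main_param.grad = model_param.grad.float()

    # The model-side .grad is never needed after the copy.
    model_param.grad = None

    # Only drop main_grad when it is a private tensor. With contiguous
    # buffers it is a view into the shared flat buffer and must stay
    # attached so the next backward pass still accumulates into it.
    if params_have_main_grad and not use_contiguous_buffers_in_ddp:
        model_param.main_grad = None

model_param = torch.nn.Parameter(torch.randn(3, dtype=torch.float16))
main_param = model_param.detach().float().requires_grad_()
model_param.main_grad = torch.ones(3, dtype=torch.float16)

copy_model_grad_to_main_grad(model_param, main_param,
                             params_have_main_grad=True,
                             use_contiguous_buffers_in_ddp=True)
assert main_param.grad.dtype == torch.float32
assert model_param.main_grad is not None   # view kept for the next iteration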
@@ -469,11 +490,12 @@ class FP32Optimizer(MegatronOptimizer):
 
     def __init__(self, optimizer, clip_grad,
                  log_num_zeros_in_grad,
-                 params_have_main_grad):
+                 params_have_main_grad,
+                 use_contiguous_buffers_in_ddp):
 
         super(FP32Optimizer, self).__init__(
             optimizer, clip_grad, log_num_zeros_in_grad,
-            params_have_main_grad)
+            params_have_main_grad, use_contiguous_buffers_in_ddp)
 
         self._scale = torch.cuda.FloatTensor([1.0])
@@ -500,6 +522,12 @@ class FP32Optimizer(MegatronOptimizer):
             for param in param_group['params']:
                 param.grad = param.main_grad
 
+                # Safe to de-reference model's main_grad after copying.
+                # (If using contiguous buffers, main_grad's memory should
+                # persist and therefore should not be deallocated.)
+                if not self.use_contiguous_buffers_in_ddp:
+                    param.main_grad = None
+
         # Clip gradients.
         grad_norm = None
         if self.clip_grad > 0.0:
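The FP32 path applies the same rule: .grad is aliased to main_grad before the base optimizer steps, and main_grad is de-referenced only when it is not backed by the contiguous buffer. A small sketch with an illustrative single-parameter setup in place of Megatron's DDP wrapper:

# A small sketch of the aliasing, assuming plain PyTorch; the single-parameter
# setup below stands in for Megatron's DDP-owned buffer and is illustrative.
import torch

param = torch.nn.Parameter(torch.zeros(4))
flat_buffer = torch.full((4,), 0.5)              # stand-in for the DDP buffer
param.main_grad = flat_buffer.view_as(param)     # persistent view

use_contiguous_buffers_in_ddp = True

# The step above, reduced to one parameter: point .grad at main_grad and only
# de-reference main_grad when it is not buffer-backed.
param.grad = param.main_grad
if not use_contiguous_buffers_in_ddp:
    param.main_grad = None

# The base optimizer now steps directly on the buffer-backed gradient.
torch.optim.SGD([param], lr=1.0).step()
assert torch.allclose(param.data, torch.full((4,), -0.5))
assert param.main_grad is not None               # still attached for reuse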