Commit 3715b0bc authored by Jared Casper

Merge branch 'mem_opt_for_contig_buffer' into 'main'

Memory optimizations for contiguous buffers

See merge request ADLR/megatron-lm!310
Parents: 68797d90, 6a0ef5b1
@@ -148,16 +148,15 @@ def parse_args(extra_args_provider=None, defaults={},
         print('using {} for parameters ...'.format(args.params_dtype),
               flush=True)

-    # If we do accumulation and all-reduces in fp32, we need to have
-    # local DDP and we should set the use-contiguous-buffers-in-ddp.
+    # If we do accumulation and all-reduces in fp32, we need to have local DDP
+    # and we should make sure use-contiguous-buffers-in-local-ddp is not off.
     if args.accumulate_allreduce_grads_in_fp32:
         assert args.DDP_impl == 'local'
-        args.use_contiguous_buffers_in_ddp = True
+        assert args.use_contiguous_buffers_in_local_ddp

-    # If we use a contiguous buffer to hold main grads, we need to have
-    # local DDP.
-    if args.use_contiguous_buffers_in_ddp:
-        assert args.DDP_impl == 'local'
+    # For torch DDP, we do not use contiguous buffer
+    if args.DDP_impl == 'torch':
+        args.use_contiguous_buffers_in_local_ddp = False

     if args.dataloader_type is None:
         args.dataloader_type = 'single'
@@ -584,9 +583,10 @@ def _add_distributed_args(parser):
                        choices=['local', 'torch'],
                        help='which DistributedDataParallel implementation '
                        'to use.')
-    group.add_argument('--use-contiguous-buffers-in-ddp', action='store_true',
-                       help='If set, use contiguous buffer in DDP. Note that '
-                       'this option only works woth local DDP.' )
+    group.add_argument('--no-contiguous-buffers-in-local-ddp',
+                       action='store_false', help='If set, dont use '
+                       'contiguous buffer in local DDP.',
+                       dest='use_contiguous_buffers_in_local_ddp')
     group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
                        help='Use scatter/gather to optimize communication of tensors in pipeline',
                        dest='scatter_gather_tensors_in_pipeline')
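The renamed flag above relies on argparse's store_false/dest pattern: the destination attribute defaults to True, so contiguous buffers stay enabled for local DDP unless the flag is passed. Below is a minimal standalone sketch of that pattern; the parser and group setup are illustrative, not Megatron's actual argument code.

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group('distributed')
# 'store_false' plus an explicit 'dest' means the attribute defaults to True
# and is flipped to False only when the flag is given on the command line.
group.add_argument('--no-contiguous-buffers-in-local-ddp',
                   action='store_false',
                   dest='use_contiguous_buffers_in_local_ddp',
                   help='If set, do not use contiguous buffer in local DDP.')

args = parser.parse_args([])
assert args.use_contiguous_buffers_in_local_ddp is True
args = parser.parse_args(['--no-contiguous-buffers-in-local-ddp'])
assert args.use_contiguous_buffers_in_local_ddp is False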
@@ -100,7 +100,7 @@ def get_megatron_optimizer(model):
                            args.clip_grad,
                            args.log_num_zeros_in_grad,
                            params_have_main_grad,
-                           args.use_contiguous_buffers_in_ddp,
+                           args.use_contiguous_buffers_in_local_ddp,
                            args.bf16,
                            grad_scaler)
@@ -108,4 +108,4 @@ def get_megatron_optimizer(model):
     return FP32Optimizer(optimizer, args.clip_grad,
                          args.log_num_zeros_in_grad,
                          params_have_main_grad,
-                         args.use_contiguous_buffers_in_ddp)
+                         args.use_contiguous_buffers_in_local_ddp)
@@ -69,7 +69,7 @@ class MegatronOptimizer(ABC):
     def __init__(self, optimizer, clip_grad,
                  log_num_zeros_in_grad,
                  params_have_main_grad,
-                 use_contiguous_buffers_in_ddp):
+                 use_contiguous_buffers_in_local_ddp):

         """Input optimizer is the base optimizer for example Adam."""
         self.optimizer = optimizer
@@ -78,9 +78,9 @@ class MegatronOptimizer(ABC):
         self.clip_grad = clip_grad
         self.log_num_zeros_in_grad = log_num_zeros_in_grad
         self.params_have_main_grad = params_have_main_grad
-        self.use_contiguous_buffers_in_ddp = use_contiguous_buffers_in_ddp
+        self.use_contiguous_buffers_in_local_ddp = use_contiguous_buffers_in_local_ddp

-        if self.use_contiguous_buffers_in_ddp:
+        if self.use_contiguous_buffers_in_local_ddp:
             assert self.params_have_main_grad, \
                 "use of contiguous buffer requires that params have main grad"
@@ -193,12 +193,12 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
     """

     def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
-                 params_have_main_grad, use_contiguous_buffers_in_ddp,
+                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                  bf16, grad_scaler):

         super(Float16OptimizerWithFloat16Params, self).__init__(
             optimizer, clip_grad, log_num_zeros_in_grad,
-            params_have_main_grad, use_contiguous_buffers_in_ddp)
+            params_have_main_grad, use_contiguous_buffers_in_local_ddp)

         self.bf16 = bf16
         self.grad_scaler = grad_scaler
@@ -323,7 +323,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
             # persist and therefore should not be deallocated.)
             model_param.grad = None
             if self.params_have_main_grad and \
-               not self.use_contiguous_buffers_in_ddp:
+               not self.use_contiguous_buffers_in_local_ddp:
                 model_param.main_grad = None

         # For fp32 grads, we need to reset the grads to main grad.
@@ -335,7 +335,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
                 # Safe to de-reference model's main_grad after copying.
                 # (If using contiguous buffers, main_grad's memory should
                 # persist and therefore should not be deallocated.)
-                if not self.use_contiguous_buffers_in_ddp:
+                if not self.use_contiguous_buffers_in_local_ddp:
                     model_param.main_grad = None

     def _unscale_main_grads_and_check_for_nan(self):
@@ -491,11 +491,11 @@ class FP32Optimizer(MegatronOptimizer):
     def __init__(self, optimizer, clip_grad,
                  log_num_zeros_in_grad,
                  params_have_main_grad,
-                 use_contiguous_buffers_in_ddp):
+                 use_contiguous_buffers_in_local_ddp):

         super(FP32Optimizer, self).__init__(
             optimizer, clip_grad, log_num_zeros_in_grad,
-            params_have_main_grad, use_contiguous_buffers_in_ddp)
+            params_have_main_grad, use_contiguous_buffers_in_local_ddp)

         self._scale = torch.cuda.FloatTensor([1.0])
@@ -525,7 +525,7 @@ class FP32Optimizer(MegatronOptimizer):
                 # Safe to de-reference model's main_grad after copying.
                 # (If using contiguous buffers, main_grad's memory should
                 # persist and therefore should not be deallocated.)
-                if not self.use_contiguous_buffers_in_ddp:
+                if not self.use_contiguous_buffers_in_local_ddp:
                     param.main_grad = None

         # Clip gradients.
@@ -253,7 +253,7 @@ def get_model(model_provider_func):
     if args.DDP_impl == 'local':
         model = [LocalDDP(model_module,
                           args.accumulate_allreduce_grads_in_fp32,
-                          args.use_contiguous_buffers_in_ddp)
+                          args.use_contiguous_buffers_in_local_ddp)
                  for model_module in model]
         return model
@@ -351,10 +351,9 @@ def train_step(forward_step_func, data_iterator,
     timers = get_timers()

     # Set grad to zero.
-    if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_ddp:
+    if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_local_ddp:
         for partition in model:
             partition.zero_grad_buffer()
-    else:
-        optimizer.zero_grad()
+    optimizer.zero_grad()

     forward_backward_func = get_forward_backward_func()
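For context on the train_step change above: with local DDP and contiguous buffers, each parameter's main_grad is expected to be a view into one persistent gradient buffer, which is why the optimizer code avoids setting main_grad to None in that case and why zeroing goes through partition.zero_grad_buffer(). The sketch below only illustrates that idea; it is not Megatron's actual LocalDDP implementation, and the ContiguousGradBuffer name and its attributes are hypothetical.

import torch

class ContiguousGradBuffer:
    def __init__(self, params, dtype=torch.float):
        numel = sum(p.numel() for p in params)
        # One flat buffer holds the gradient accumulators for all params.
        self.buffer = torch.zeros(numel, dtype=dtype)
        offset = 0
        for p in params:
            # Each param's main_grad is a view into the shared buffer, so its
            # memory persists across iterations and must not be set to None.
            p.main_grad = self.buffer[offset:offset + p.numel()].view_as(p)
            offset += p.numel()

    def zero_grad_buffer(self):
        # Zeroing the flat buffer zeroes every param's main_grad at once.
        self.buffer.zero_()

params = [torch.nn.Parameter(torch.randn(4, 4)) for _ in range(3)]
grad_buf = ContiguousGradBuffer(params)
params[0].main_grad += 1.0       # accumulate a gradient contribution
grad_buf.zero_grad_buffer()      # start of the next training step
assert params[0].main_grad.abs().sum() == 0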