Commit d07d29df authored by Jared Casper

Merge branch 'lmcafee/distrib-opt-nodupe' into 'main'

Distributed optimizer

See merge request ADLR/megatron-lm!408
parents 55ba1f7b 74ba3797
# Distributed Optimizer
The motivation for the distributed optimizer is to save memory by distributing the optimizer state evenly across data parallel ranks, rather than replicating it on every data parallel rank as is currently done. Of the techniques described in https://arxiv.org/abs/1910.02054, this branch specifically implements the following:
- [yes] distribute all 'non-overlapping' optimizer state (i.e., model params already in fp32 are NOT distributed)
- [no] distribute model gradients
- [no] distribute model parameters
Theoretical memory savings vary depending on the combination of the model's param dtype and grad dtype. In the current implementation, the theoretical number of bytes per parameter is (where 'd' is the data parallel size):
| | Non-distributed optim | Distributed optim |
| ------ | ------ | ------ |
| float16 param, float16 grads | 20 | 4 + 16/d |
| float16 param, fp32 grads | 18 | 6 + 12/d |
| fp32 param, fp32 grads | 16 | 8 + 8/d |
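As a quick sanity check of the table, the distributed-optimizer column is just the non-shardable bytes plus the shardable bytes divided by the data parallel size. The helper below is illustrative only (it is not code from this merge request):

```python
# Illustrative only: bytes per parameter for the distributed optimizer,
# given the non-shardable and shardable byte counts from the table above.
def bytes_per_param(non_shardable_bytes, shardable_bytes, data_parallel_size):
    return non_shardable_bytes + shardable_bytes / data_parallel_size

# fp16 params + fp16 grads, d = 8: 4 + 16/8 = 6.0 bytes/param (vs. 20 replicated).
print(bytes_per_param(4, 16, 8))
```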
The implementation of the distributed optimizer is centered on using the contiguous grad buffer to communicate grads & params between the model state and the optimizer state. At any given moment, the grad buffer holds one of the following:
1. all model grads
2. a 1/d size _copy_ of the main grads (before copying to the optimizer state)
3. a 1/d size _copy_ of the main params (after copying from the optimizer state)
4. all model params
5. zeros (or None), between iterations
The grad buffer is used to perform the reduce-scatter and all-gather operations that pass grads & params between the model state and the optimizer state. With this implementation, no dynamic buffers are allocated.
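As a rough sketch of the sharding arithmetic (the function below is illustrative and not the MR's actual API), the buffer is padded to a multiple of the data parallel size and rank `r` owns the `r`-th contiguous 1/d slice:

```python
import math

# Illustrative sketch of the grad buffer sharding; not the MR's actual API.
def shard_range(buffer_numel, data_parallel_size, rank):
    numel_padded = data_parallel_size * int(
        math.ceil(buffer_numel / data_parallel_size))
    shard_size = numel_padded // data_parallel_size
    return rank * shard_size, (rank + 1) * shard_size

# 16-element buffer, 4 DP ranks: rank 2 owns elements [8, 12).
print(shard_range(16, 4, 2))
```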
The figures below illustrate the grad buffer's sharding scheme, and the key steps of the distributed optimizer's param update:
## Data flow
![Data flow](images/distrib_optimizer/data_flow.png)
## Sharding scheme
![Sharding scheme](images/distrib_optimizer/sharding_scheme.png)
## Key steps
_(note: using the illustrations above, and assuming fp16 grads; a minimal PyTorch sketch of these steps follows the list)_
- Backward pass finishes (grad buffer holds 16 fp16 grad elements)
- Call reduce-scatter on each DP rank
- Each DP rank now has 4 elements within the grad buffer that are fully reduced (remaining 12 elements are garbage)
- Each DP rank copies its relevant 4 fp16 grad elements from the grad buffer into 4 fp32 main grad elements (separate buffer, owned by the optimizer); i.e.
- DP rank 0 copies elements [0:4]
- DP rank 1 copies elements [4:8]
- DP rank 2 copies elements [8:12]
- DP rank 3 copies elements [12:16]
- Optimizer.step()
- Each DP rank copies its 4 fp32 main (/optimizer) param elements into the corresponding 4 fp16 elements in the grad buffer
- Call all-gather on each DP rank
- Grad buffer now contains all 16, fully updated, fp16 model param elements
- Copy updated model params from grad buffer into their respective param tensors
- (At this point, grad buffer is ready to be zero'd for the next iteration)
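The sketch below walks through the same steps in plain PyTorch, assuming fp16 grads, a 16-element padded buffer, and 4 DP ranks. All names (`grad_buffer`, `main_params`, `optimizer`, `dp_group`) are illustrative; the actual implementation operates in place on the contiguous buffer instead of allocating temporaries.

```python
import torch
import torch.distributed as dist

def sketch_step(grad_buffer, main_params, optimizer, dp_group):
    """Illustrative sketch of the key steps; not the MR's actual code."""
    d = dist.get_world_size(group=dp_group)
    rank = dist.get_rank(group=dp_group)
    shard = grad_buffer.numel() // d
    slices = list(grad_buffer.view(d, shard).unbind(0))

    # Reduce-scatter: this rank receives its fully reduced 1/d slice of grads.
    reduced = torch.empty_like(slices[rank])
    dist.reduce_scatter(reduced, slices, op=dist.ReduceOp.SUM, group=dp_group)

    # Copy the reduced fp16 slice into the fp32 main grads owned by the optimizer.
    main_params.grad = reduced.float()

    # Update only the local 1/d shard of the fp32 main params.
    optimizer.step()

    # Copy the updated fp32 main params back into the local fp16 slice ...
    slices[rank].copy_(main_params.detach())

    # ... then all-gather so the buffer again holds all updated fp16 params.
    gathered = [torch.empty_like(s) for s in slices]
    dist.all_gather(gathered, slices[rank], group=dp_group)
    grad_buffer.view(d, shard).copy_(torch.stack(gathered))
```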
@@ -187,6 +187,12 @@ def parse_args(extra_args_provider=None, defaults={},
                   'gradient accumulation. Setting gradient_accumulation_fusion '
                   'to False', flush=True)

+    # If we use the distributed optimizer, we need to have local DDP
+    # and we should make sure use-contiguous-buffers-in-local-ddp is on.
+    if args.use_distributed_optimizer:
+        assert args.DDP_impl == 'local'
+        assert args.use_contiguous_buffers_in_local_ddp
+
     # For torch DDP, we do not use contiguous buffer
     if args.DDP_impl == 'torch':
         args.use_contiguous_buffers_in_local_ddp = False
@@ -765,6 +771,9 @@ def _add_distributed_args(parser):
                        'is placed on its own pipeline stage, without any '
                        'transformer layers. (For T5, this flag currently only '
                        'affects the encoder embedding.)')
+    group.add_argument('--use-distributed-optimizer', action='store_true',
+                       help='Use distributed optimizer.')
     return parser
@@ -81,24 +81,34 @@ def ensure_directory_exists(filename):
         os.makedirs(dirname)


-def get_checkpoint_name(checkpoints_path, iteration,
-                        release=False):
+def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
+                         release=False):
     """A unified checkpoint name."""
     if release:
         directory = 'release'
     else:
         directory = 'iter_{:07d}'.format(iteration)
-    # Use both the tensor and pipeline MP rank.
+    # Use both the tensor and pipeline MP rank. If using the distributed
+    # optimizer, then the optimizer's path must additionally include the
+    # data parallel rank.
     if mpu.get_pipeline_model_parallel_world_size() == 1:
-        return os.path.join(checkpoints_path, directory,
-                            'mp_rank_{:02d}'.format(
-                                mpu.get_tensor_model_parallel_rank()),
-                            'model_optim_rng.pt')
-    return os.path.join(checkpoints_path, directory,
-                        'mp_rank_{:02d}_{:03d}'.format(
-                            mpu.get_tensor_model_parallel_rank(),
-                            mpu.get_pipeline_model_parallel_rank()),
-                        'model_optim_rng.pt')
+        common_path = os.path.join(checkpoints_path, directory,
+                                   'mp_rank_{:02d}'.format(
+                                       mpu.get_tensor_model_parallel_rank()))
+    else:
+        common_path = os.path.join(checkpoints_path, directory,
+                                   'mp_rank_{:02d}_{:03d}'.format(
+                                       mpu.get_tensor_model_parallel_rank(),
+                                       mpu.get_pipeline_model_parallel_rank()))
+
+    if use_distributed_optimizer:
+        model_name = os.path.join(common_path, "model_rng.pt")
+        optim_name = os.path.join(
+            common_path + "_%03d" % mpu.get_data_parallel_rank(),
+            "optim.pt")
+    else:
+        model_name = optim_name = os.path.join(common_path, "model_optim_rng.pt")
+    return model_name, optim_name


 def get_checkpoint_tracker_filename(checkpoints_path):
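For illustration, with hypothetical values (`--save checkpoints`, tensor-parallel rank 0, no pipeline parallelism, data-parallel rank 3, iteration 100), the pair returned by `get_checkpoint_names` above would look roughly like:

```python
# Hypothetical example of the names produced above when using the
# distributed optimizer; not actual output from a run.
model_name = "checkpoints/iter_0000100/mp_rank_00/model_rng.pt"
optim_name = "checkpoints/iter_0000100/mp_rank_00_003/optim.pt"
```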
@@ -177,38 +187,64 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
     print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
         iteration, args.save))

-    # collect rng state across data parallel ranks
+    # Collect rng state across data parallel ranks.
     rng_state = get_rng_state()

-    if not torch.distributed.is_initialized() or mpu.get_data_parallel_rank() == 0:
+    # Checkpoint file names.
+    model_checkpoint_name, optim_checkpoint_name = \
+        get_checkpoint_names(args.save, iteration, args.use_distributed_optimizer)
+
+    # Collect args, model, RNG.
+    model_state_dict = {}
+    if not torch.distributed.is_initialized() \
+       or mpu.get_data_parallel_rank() == 0:

         # Arguments, iteration, and model.
-        state_dict = {}
-        state_dict['args'] = args
-        state_dict['checkpoint_version'] = 3.0
-        state_dict['iteration'] = iteration
+        model_state_dict['args'] = args
+        model_state_dict['checkpoint_version'] = 3.0
+        model_state_dict['iteration'] = iteration
         if len(model) == 1:
-            state_dict['model'] = model[0].state_dict_for_save_checkpoint()
+            model_state_dict['model'] = model[0].state_dict_for_save_checkpoint()
         else:
             for i in range(len(model)):
                 mpu.set_virtual_pipeline_model_parallel_rank(i)
-                state_dict['model%d' % i] = model[i].state_dict_for_save_checkpoint()
+                model_state_dict['model%d' % i] = \
+                    model[i].state_dict_for_save_checkpoint()
+
+        # RNG states.
+        if not args.no_save_rng:
+            model_state_dict["rng_state"] = rng_state
+
+    # Collect optimizer state. (Optimizer is saved separately from the model, due
+    # to the conflicting data pattern when using the distributed optimizer.)
+    optim_state_dict = {}
+    if not args.no_save_optim \
+       and (not torch.distributed.is_initialized()
+            or mpu.get_data_parallel_rank() == 0
+            or args.use_distributed_optimizer):

         # Optimizer stuff.
-        if not args.no_save_optim:
-            if optimizer is not None:
-                state_dict['optimizer'] = optimizer.state_dict()
-            if opt_param_scheduler is not None:
-                state_dict['opt_param_scheduler'] = opt_param_scheduler.state_dict()
-
-        # RNG states.
-        if not args.no_save_rng:
-            state_dict["rng_state"] = rng_state
+        if optimizer is not None:
+            optim_state_dict['optimizer'] = optimizer.state_dict()
+        if opt_param_scheduler is not None:
+            optim_state_dict['opt_param_scheduler'] = \
+                opt_param_scheduler.state_dict()

     # Save.
-    checkpoint_name = get_checkpoint_name(args.save, iteration)
-    ensure_directory_exists(checkpoint_name)
-    torch.save(state_dict, checkpoint_name)
+    if args.use_distributed_optimizer:
+        # Save model separate from optimizer.
+        if model_state_dict:
+            ensure_directory_exists(model_checkpoint_name)
+            torch.save(model_state_dict, model_checkpoint_name)
+        if optim_state_dict:
+            ensure_directory_exists(optim_checkpoint_name)
+            torch.save(optim_state_dict, optim_checkpoint_name)
+    else:
+        # Save model and optimizer together.
+        state_dict = {**model_state_dict, **optim_state_dict}
+        if state_dict: # only saves if populated (i.e., inherits conditions above)
+            ensure_directory_exists(model_checkpoint_name)
+            torch.save(state_dict, model_checkpoint_name)

     # Wait so everyone is done (necessary)
     if torch.distributed.is_initialized():
@@ -322,12 +358,19 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
     iteration, release = read_metadata(tracker_filename)

     # Checkpoint.
-    checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
+    model_checkpoint_name, optim_checkpoint_name = \
+        get_checkpoint_names(load_dir, iteration,
+                             args.use_distributed_optimizer,
+                             release)
     print_rank_0(f' loading checkpoint from {args.load} at iteration {iteration}')

     # Load the checkpoint.
     try:
-        state_dict = torch.load(checkpoint_name, map_location='cpu')
+        model_state_dict = torch.load(model_checkpoint_name, map_location='cpu')
+        if args.use_distributed_optimizer:
+            optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu')
+        else:
+            optim_state_dict = model_state_dict
     except ModuleNotFoundError:
         from megatron.fp16_deprecated import loss_scaler
         # For backward compatibility.
@@ -336,7 +379,8 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
             'megatron.fp16_deprecated.loss_scaler']
         sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
             'megatron.fp16_deprecated.loss_scaler']
-        state_dict = torch.load(checkpoint_name, map_location='cpu')
+        model_state_dict = torch.load(model_checkpoint_name, map_location='cpu')
+        optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu')
         sys.modules.pop('fp16.loss_scaler', None)
         sys.modules.pop('megatron.fp16.loss_scaler', None)
     except BaseException as e:
@@ -344,18 +388,18 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
         print_rank_0(e)
         sys.exit()

-    # set checkpoint version
-    set_checkpoint_version(state_dict.get('checkpoint_version', 0))
+    # Set checkpoint version.
+    set_checkpoint_version(model_state_dict.get('checkpoint_version', 0))

     # Set iteration.
     if args.finetune or release:
         iteration = 0
     else:
         try:
-            iteration = state_dict['iteration']
+            iteration = model_state_dict['iteration']
         except KeyError:
             try:  # Backward compatible with older checkpoints
-                iteration = state_dict['total_iters']
+                iteration = model_state_dict['total_iters']
             except KeyError:
                 print_rank_0('A metadata file exists but unable to load '
                              'iteration from checkpoint {}, exiting'.format(
@@ -365,8 +409,8 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
     # Check arguments.
     assert args.consumed_train_samples == 0
     assert args.consumed_valid_samples == 0
-    if 'args' in state_dict:
-        checkpoint_args = state_dict['args']
+    if 'args' in model_state_dict:
+        checkpoint_args = model_state_dict['args']
         check_checkpoint_args(checkpoint_args)
         args.consumed_train_samples = getattr(checkpoint_args,
                                               'consumed_train_samples', 0)
@@ -378,11 +422,11 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
     # Model.
     if len(model) == 1:
-        model[0].load_state_dict(state_dict['model'], strict=strict)
+        model[0].load_state_dict(model_state_dict['model'], strict=strict)
     else:
         for i in range(len(model)):
             mpu.set_virtual_pipeline_model_parallel_rank(i)
-            model[i].load_state_dict(state_dict['model%d' % i], strict=strict)
+            model[i].load_state_dict(model_state_dict['model%d' % i], strict=strict)

     # Fix up query/key/value matrix ordering if needed
     checkpoint_version = get_checkpoint_version()
@@ -393,12 +437,12 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
     if not release and not args.finetune and not args.no_load_optim:
         try:
             if optimizer is not None:
-                optimizer.load_state_dict(state_dict['optimizer'])
+                optimizer.load_state_dict(optim_state_dict['optimizer'])
             if opt_param_scheduler is not None:
-                if 'lr_scheduler' in state_dict: # backward compatbility
-                    opt_param_scheduler.load_state_dict(state_dict['lr_scheduler'])
+                if 'lr_scheduler' in optim_state_dict: # backward compatbility
+                    opt_param_scheduler.load_state_dict(optim_state_dict['lr_scheduler'])
                 else:
-                    opt_param_scheduler.load_state_dict(state_dict['opt_param_scheduler'])
+                    opt_param_scheduler.load_state_dict(optim_state_dict['opt_param_scheduler'])
         except KeyError:
             print_rank_0('Unable to load optimizer from checkpoint {}. '
                          'Specify --no-load-optim or --finetune to prevent '
@@ -409,13 +453,13 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
     # rng states.
     if not release and not args.finetune and not args.no_load_rng:
         try:
-            if 'rng_state' in state_dict:
+            if 'rng_state' in model_state_dict:
                 # access rng_state for data parallel rank
                 if args.data_parallel_random_init:
-                    rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()]
+                    rng_state = model_state_dict['rng_state'][mpu.get_data_parallel_rank()]
                 else:
-                    rng_state = state_dict['rng_state'][0]
+                    rng_state = model_state_dict['rng_state'][0]
                 random.setstate(rng_state['random_rng_state'])
                 np.random.set_state(rng_state['np_rng_state'])
                 torch.set_rng_state(rng_state['torch_rng_state'])
@@ -426,15 +470,15 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
                 mpu.get_cuda_rng_tracker().set_states(
                     rng_state['rng_tracker_states'])
             else:  # backward compatability
-                random.setstate(state_dict['random_rng_state'])
-                np.random.set_state(state_dict['np_rng_state'])
-                torch.set_rng_state(state_dict['torch_rng_state'])
-                torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
+                random.setstate(model_state_dict['random_rng_state'])
+                np.random.set_state(model_state_dict['np_rng_state'])
+                torch.set_rng_state(model_state_dict['torch_rng_state'])
+                torch.cuda.set_rng_state(model_state_dict['cuda_rng_state'])
                 # Check for empty states array
-                if not state_dict['rng_tracker_states']:
+                if not model_state_dict['rng_tracker_states']:
                     raise KeyError
                 mpu.get_cuda_rng_tracker().set_states(
-                    state_dict['rng_tracker_states'])
+                    model_state_dict['rng_tracker_states'])
         except KeyError:
             print_rank_0('Unable to load rng state from checkpoint {}. '
                          'Specify --no-load-rng or --finetune to prevent '
@@ -469,12 +513,14 @@ def load_biencoder_checkpoint(model, only_query_model=False,
     with open(tracker_filename, 'r') as f:
         iteration = int(f.read().strip())

-    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
+    checkpoint_name, _ = get_checkpoint_names(load_path, iteration,
+                                              args.use_distributed_optimizer,
+                                              False)
     if mpu.get_data_parallel_rank() == 0:
         print('global rank {} is loading checkpoint {}'.format(
             torch.distributed.get_rank(), checkpoint_name))

-    state_dict = torch.load(checkpoint_name, map_location='cpu')
+    state_dict = torch.load(model_checkpoint_name, map_location='cpu')
     ret_state_dict = state_dict['model']

     if only_query_model:
@@ -15,6 +15,7 @@
 from abc import ABC
 from abc import abstractmethod
+import math

 import torch
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
@@ -24,18 +25,17 @@ from megatron import mpu
 from .module import MegatronModule


 class MemoryBuffer:

-    def __init__(self, numel, dtype):
+    def __init__(self, numel, numel_padded, dtype):
         self.numel = numel
+        self.numel_padded = numel_padded
         self.dtype = dtype
-        self.data = torch.zeros(self.numel,
+        self.data = torch.zeros(self.numel_padded,
                                 dtype=self.dtype,
                                 device=torch.cuda.current_device(),
                                 requires_grad=False)

     def zero(self):
         """Reset the buffer to zero."""
         self.data.zero_()
@@ -121,8 +121,11 @@ class DistributedDataParallel(DistributedDataParallelBase):
         # the case we use continuous buffers.
         # ===================================
         self._grad_buffers = None
+        self._grad_buffer_param_index_map = None
         if self.use_contiguous_buffers:
             self._grad_buffers = {}
+            self._grad_buffer_param_index_map = {}
+            data_parallel_world_size = mpu.get_data_parallel_world_size()

             # Simple function to define buffer type.
             def _get_buffer_type(param):
@@ -139,7 +142,18 @@ class DistributedDataParallel(DistributedDataParallelBase):

             # Allocate the buffer.
             for dtype, num_elements in type_num_elements.items():
-                self._grad_buffers[dtype] = MemoryBuffer(num_elements, dtype)
+
+                # If using distributed optimizer, pad memory buffer to be
+                # multiple of data_parallel_world_size. (This padding is done
+                # due to a constraint with the reduce_scatter op, which requires
+                # all tensors have equal size. See: optimizer.py.)
+                num_elements_padded = data_parallel_world_size * \
+                    int(math.ceil(num_elements / data_parallel_world_size))
+
+                # Allocate grad buffer.
+                self._grad_buffers[dtype] = MemoryBuffer(num_elements,
+                                                         num_elements_padded,
+                                                         dtype)

             # Assume the back prop order is reverse the params order,
             # store the start index for the gradients.
@@ -149,6 +163,12 @@ class DistributedDataParallel(DistributedDataParallelBase):
                     type_num_elements[dtype] -= param.data.nelement()
                     param.main_grad = self._grad_buffers[dtype].get(
                         param.data.shape, type_num_elements[dtype])
+                    if dtype not in self._grad_buffer_param_index_map:
+                        self._grad_buffer_param_index_map[dtype] = {}
+                    self._grad_buffer_param_index_map[dtype][param] = (
+                        type_num_elements[dtype],
+                        type_num_elements[dtype] + param.data.nelement(),
+                    )

         # Backward hook.
         # Accumalation function for the gradients. We need
@@ -164,6 +184,7 @@ class DistributedDataParallel(DistributedDataParallelBase):
                 grad_acc.register_hook(self._make_param_hook(param))
                 self.grad_accs.append(grad_acc)

     def _make_param_hook(self, param):
         """Create the all-reduce hook for backprop."""
         # Hook used for back-prop.
@@ -680,6 +680,17 @@ class ParallelTransformerLayer(MegatronModule):
                     mlp_bias.expand_as(residual),
                     residual,
                     self.hidden_dropout)

+            # Jit compiled function creates 'view' tensor. This tensor
+            # potentially gets saved in the MPU checkpoint function context,
+            # which rejects view tensors. While making a viewless tensor here
+            # won't result in memory savings (like the data loader, or
+            # p2p_communication), it serves to document the origin of this
+            # 'view' tensor.
+            output = mpu.make_viewless_tensor(inp = output,
+                                              requires_grad = output.requires_grad,
+                                              keep_graph = True)
+
         else:
             out = torch.nn.functional.dropout(mlp_output + mlp_bias,
                                               p=self.hidden_dropout,
@@ -18,6 +18,7 @@ from apex.optimizers import FusedSGD as SGD

 from megatron import get_args

+from .distrib_optimizer import DistributedOptimizer
 from .grad_scaler import ConstantGradScaler, DynamicGradScaler
 from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer

@@ -104,7 +105,11 @@ def get_megatron_optimizer(model,
     if args.DDP_impl == 'local':
         params_have_main_grad = True

-    if args.fp16 or args.bf16:
+    # Mixed precision optimizer.
+    # - Note: both the Float16Optimizer and the DistributedOptimizer inherit
+    #   from the MixedPrecisionOptimizer, which manages any optimizer where
+    #   the model params and main params are distinct.
+    if args.fp16 or args.bf16 or args.use_distributed_optimizer:

         # Grad scaler:
         #    if loss-scale is provided, instantiate the constant scaler.
@@ -113,9 +118,11 @@ def get_megatron_optimizer(model,
         #    otherwise we are running in bf16 with no loss-scale so
         #    leave it as None.
         grad_scaler = None

         # Constant loss scale.
         if args.loss_scale:
             grad_scaler = ConstantGradScaler(args.loss_scale)

         # Dynamic loss scale.
         else:
             if args.fp16:
@@ -128,16 +135,22 @@ def get_megatron_optimizer(model,
                     hysteresis=args.hysteresis)

         # Megatron optimizer.
-        return Float16OptimizerWithFloat16Params(optimizer,
+        opt_ty = DistributedOptimizer \
+            if args.use_distributed_optimizer else \
+            Float16OptimizerWithFloat16Params
+        return opt_ty(optimizer,
                       args.clip_grad,
                       args.log_num_zeros_in_grad,
                       params_have_main_grad,
                       args.use_contiguous_buffers_in_local_ddp,
+                      args.fp16,
                       args.bf16,
-                      grad_scaler)
+                      grad_scaler,
+                      model)

     # FP32.
     return FP32Optimizer(optimizer, args.clip_grad,
                          args.log_num_zeros_in_grad,
                          params_have_main_grad,
-                         args.use_contiguous_buffers_in_local_ddp)
+                         args.use_contiguous_buffers_in_local_ddp,
+                         model)
@@ -21,12 +21,13 @@ from torch._six import inf
 from apex.multi_tensor_apply import multi_tensor_applier
 import amp_C

-from megatron import mpu
 from megatron.model.module import param_is_not_shared
 from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate


-def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
+def clip_grad_norm_fp32(parameters, grads_for_norm,
+                        max_norm, norm_type=2,
+                        model_parallel_group=None):
     """Clips gradient norm of an iterable of parameters whose gradients
        are in fp32.

@@ -37,9 +38,13 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
     Arguments:
         parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
             single Tensor that will have gradients normalized
+        grads_for_norm (Iterable[Tensor]): an iterable of Tensors or a single
+            Tensor that will be used for calculating the grad norm.
         max_norm (float or int): max norm of the gradients
         norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
             infinity norm.
+        model_parallel_group (group): given the nature of the distributed
+            optimizer, this is passed as an argument.

     Returns:
         Total norm of the parameters (viewed as a single vector).
@@ -47,25 +52,15 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
     if isinstance(parameters, torch.Tensor):
         parameters = [parameters]
+    if isinstance(grads_for_norm, torch.Tensor):
+        grads_for_norm = [grads_for_norm]

-    # Filter parameters based on:
-    #   - grad should not be none
-    #   - parameter should not be shared
-    #   - should not be a replica due to tensor model parallelism
+    # Grads.
     grads = []
-    grads_for_norm = []
     for param in parameters:
-        grad_not_none = param.grad is not None
-        is_not_shared = param_is_not_shared(param)
-        is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
-        if grad_not_none:
-            grad = param.grad.detach()
-        if grad_not_none:
-            # Make sure the grads are in fp32
+        if param.grad is not None:
             assert param.grad.type() == 'torch.cuda.FloatTensor'
-            grads.append(grad)
-        if grad_not_none and is_not_shared and is_not_tp_duplicate:
-            grads_for_norm.append(grad)
+            grads.append(param.grad.detach())

     # Norm parameters.
     max_norm = float(max_norm)
@@ -79,7 +74,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
         # Take max across all model-parallel GPUs.
         torch.distributed.all_reduce(total_norm_cuda,
                                      op=torch.distributed.ReduceOp.MAX,
-                                     group=mpu.get_model_parallel_group())
+                                     group=model_parallel_group)
         total_norm = total_norm_cuda[0].item()

     else:
@@ -88,12 +83,15 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
             # Use apex's multi-tensor applier for efficiency reasons.
             # Multi-tensor applier takes a function and a list of list
             # and performs the operation on that list all in one kernel.
+            if grads_for_norm:
                grad_norm, _ = multi_tensor_applier(
                    amp_C.multi_tensor_l2norm,
                    dummy_overflow_buf,
                    [grads_for_norm],
                    False # no per-parameter norm
                )
+            else:
+                grad_norm = torch.cuda.FloatTensor([0])
             # Since we will be summing across data parallel groups,
             # we need the pow(norm-type).
             total_norm = grad_norm ** norm_type
@@ -106,7 +104,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
     # Sum across all model-parallel GPUs.
     torch.distributed.all_reduce(total_norm,
                                  op=torch.distributed.ReduceOp.SUM,
-                                 group=mpu.get_model_parallel_group())
+                                 group=model_parallel_group)
     total_norm = total_norm.item() ** (1.0 / norm_type)

     # Scale.
@@ -121,7 +119,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
     return total_norm


-def count_zeros_fp32(parameters):
+def count_zeros_fp32(parameters, model_parallel_group):

     if isinstance(parameters, torch.Tensor):
         parameters = [parameters]
@@ -130,7 +128,7 @@ def count_zeros_fp32(parameters):
     #   - grad should not be none
     #   - parameter should not be shared
     #   - should not be a replica due to tensor model parallelism
-    total_num_zeros = 0.0
+    total_num_zeros = torch.cuda.FloatTensor([0.0])
     for param in parameters:
         grad_not_none = param.grad is not None
         is_not_shared = param_is_not_shared(param)
@@ -143,7 +141,8 @@ def count_zeros_fp32(parameters):
     # Sum across all model-parallel GPUs.
     torch.distributed.all_reduce(total_num_zeros,
                                  op=torch.distributed.ReduceOp.SUM,
-                                 group=mpu.get_model_parallel_group())
+                                 group=model_parallel_group)
+
     total_num_zeros = total_num_zeros.item()

     return total_num_zeros
@@ -23,7 +23,6 @@ import time
 _TRAIN_START_TIME = time.time()
 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

 from megatron import get_args
 from megatron import get_signal_handler
@@ -365,12 +364,11 @@ def setup_model_and_optimizer(model_provider_func,
     args = get_args()

     model = get_model(model_provider_func, model_type)

     unwrapped_model = unwrap_model(model,
                                    (torchDDP, LocalDDP, Float16Module))
-    optimizer = get_megatron_optimizer(unwrapped_model, no_wd_decay_cond,
-                                       scale_lr_cond, lr_mult)
+    optimizer = get_megatron_optimizer(model, no_wd_decay_cond,
+                                       scale_lr_cond, lr_mult)

     opt_param_scheduler = get_optimizer_param_scheduler(optimizer)

     if args.load is not None:
@@ -413,97 +411,44 @@ def train_step(forward_step_func, data_iterator,
         partition.zero_grad_buffer()
     optimizer.zero_grad()

+    # Forward pass.
     forward_backward_func = get_forward_backward_func()
     losses_reduced = forward_backward_func(
         forward_step_func, data_iterator, model,
         optimizer, timers, forward_only=False)

-    # Empty unused memory
+    # Empty unused memory.
     if args.empty_unused_memory_level >= 1:
         torch.cuda.empty_cache()

-    # All-reduce layernorm parameters across model parallel nodes
-    # when sequence parallelism is used
-    if mpu.get_tensor_model_parallel_world_size() > 1 and \
-            args.sequence_parallel:
-        grads = []
-        for model_module in model:
-            unwrapped_model = unwrap_model(
-                model_module, (torchDDP, LocalDDP, Float16Module))
-            for param in unwrapped_model.parameters():
-                if getattr(param, 'sequence_parallel', False):
-                    grad = param.main_grad if args.DDP_impl == 'local' else param.grad
-                    grads.append(grad.data)
-        coalesced = _flatten_dense_tensors(grads)
-        torch.distributed.all_reduce(
-            coalesced, group=mpu.get_tensor_model_parallel_group())
-        for buf, synced in zip(grads, _unflatten_dense_tensors(
-                coalesced, grads)):
-            buf.copy_(synced)
-
-    # All-reduce if needed.
-    if args.DDP_impl == 'local':
-        timers('backward-params-all-reduce').start()
-        for model_module in model:
-            model_module.allreduce_gradients()
-        timers('backward-params-all-reduce').stop()
-
-    # All-reduce word_embeddings' grad across first and last stages to ensure
-    # that word_embeddings parameters stay in sync.
-    # This should only run for models that support pipelined model parallelism
-    # (BERT and GPT-2).
-    timers('backward-embedding-all-reduce').start()
-    if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
-            mpu.get_pipeline_model_parallel_world_size() > 1:
-        if mpu.is_pipeline_first_stage(ignore_virtual=True):
-            unwrapped_model = model[0]
-        elif mpu.is_pipeline_last_stage(ignore_virtual=True):
-            unwrapped_model = model[-1]
-        else:  # We do not support the interleaved schedule for T5 yet.
-            unwrapped_model = model[0]
-        unwrapped_model = unwrap_model(
-            unwrapped_model, (torchDDP, LocalDDP, Float16Module))
-
-        if unwrapped_model.share_word_embeddings:
-            word_embeddings_weight = unwrapped_model.word_embeddings_weight()
-            if args.DDP_impl == 'local':
-                grad = word_embeddings_weight.main_grad
-            else:
-                grad = word_embeddings_weight.grad
-            torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
-
-    # All-reduce position_embeddings grad across first (encoder) and split (decoder)
-    # stages to ensure that position embeddings parameters stay in sync.
-    # This should only run for T5 models with pipeline parallelism
-    if mpu.is_rank_in_position_embedding_group() and \
-            mpu.get_pipeline_model_parallel_world_size() > 1 and \
-            args.pipeline_model_parallel_split_rank is not None:
-        unwrapped_model = model[0]
-        unwrapped_model = unwrap_model(
-            unwrapped_model, (torchDDP, LocalDDP, Float16Module))
-        assert args.DDP_impl == 'local', \
-            'T5 model is only supported with local DDP mode'
-        grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
-        torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
-    timers('backward-embedding-all-reduce').stop()
+    # Reduce gradients.
+    timers('backward-reduce-model-grads').start()
+    optimizer.reduce_model_grads(args, timers)
+    timers('backward-reduce-model-grads').stop()

+    # Vision gradients.
     if args.vision_pretraining and args.vision_pretraining_type == "dino":
         unwrapped_model = unwrap_model(model[0],
                                        (torchDDP, LocalDDP, Float16Module))
         unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)

     # Update parameters.
     timers('optimizer').start()
-    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
+    update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers)
     timers('optimizer').stop()

+    # Gather params.
+    if update_successful:
+        timers('backward-gather-model-params').start()
+        optimizer.gather_model_params(args, timers)
+        timers('backward-gather-model-params').stop()

+    # Vision momentum.
     if args.vision_pretraining and args.vision_pretraining_type == "dino":
         unwrapped_model = unwrap_model(model[0],
                                        (torchDDP, LocalDDP, Float16Module))
         unwrapped_model.update_momentum(args.curr_iteration)

     # Update learning rate.
     if update_successful:
         increment = get_num_microbatches() * \
@@ -514,7 +459,7 @@ def train_step(forward_step_func, data_iterator,
     else:
         skipped_iter = 1

-    # Empty unused memory
+    # Empty unused memory.
     if args.empty_unused_memory_level >= 2:
         torch.cuda.empty_cache()
@@ -581,10 +526,15 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
         add_to_logging('backward-send-forward-recv')
         add_to_logging('backward-send-backward-recv')
         add_to_logging('backward-params-all-reduce')
+        add_to_logging('backward-layernorm-all-reduce')
         add_to_logging('backward-embedding-all-reduce')
+        add_to_logging('backward-reduce-model-grads')
+        add_to_logging('backward-gather-model-params')
         add_to_logging('optimizer-copy-to-main-grad')
         add_to_logging('optimizer-unscale-and-check-inf')
         add_to_logging('optimizer-clip-main-grad')
+        add_to_logging('optimizer-count-zeros')
+        add_to_logging('optimizer-inner-step')
         add_to_logging('optimizer-copy-main-to-model-params')
         add_to_logging('optimizer')
         add_to_logging('batch-generator')