Commit d07d29df authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'lmcafee/distrib-opt-nodupe' into 'main'

Distributed optimizer

See merge request ADLR/megatron-lm!408
parents 55ba1f7b 74ba3797
# Distributed Optimizer
The motivation for the distributed optimizer is to save memory by distributing the optimizer state evenly across data parallel ranks, versus the current method of replicating the optimizer state across data parallel ranks. As described in https://arxiv.org/abs/1910.02054, this branch specifically implements the following:
- [yes] distribute all 'non-overlapping' optimizer state (i.e., model params already in fp32 are NOT distributed)
- [no] distribute model gradients
- [no] distribute model parameters
Theoretical memory savings vary depending on the combination of the model's param dtype and grad dtype. In the current implementation, the theoretical number of bytes per parameter is (where 'd' is the data parallel size):
| | Non-distributed optim | Distributed optim |
| ------ | ------ | ------ |
| float16 param, float16 grads | 20 | 4 + 16/d |
| float16 param, fp32 grads | 18 | 6 + 12/d |
| fp32 param, fp32 grads | 16 | 8 + 8/d |
The implementation of the distributed optimizer is centered on using the contiguous grad buffer for communicating grads & params between the model state and the optimizer state. The grad buffer at any given moment either holds:
1. all model grads
2. a 1/d size _copy_ of the main grads (before copying to the optimizer state)
3. a 1/d size _copy_ of the main params (after copying from the optimizer state)
4. all model params
5. zeros (or None), between iterations
The grad buffer is used for performing reduce-scatter and all-gather operations, for passing grads & params between the model state and optimizer state. With this implementation, no dynamic buffers are allocated.
The figures below illustrate the grad buffer's sharding scheme, and the key steps of the distributed optimizer's param update:
## Data flow
![Data flow](images/distrib_optimizer/data_flow.png)
## Sharding scheme
![Sharding scheme](images/distrib_optimizer/sharding_scheme.png)
## Key steps
_(note: using illustrations above, and assuming fp16 grads)_
- Backward pass finishes (grad buffer holds 16 fp16 grad elements)
- Call reduce-scatter on each DP rank
- Each DP rank now has 4 elements within the grad buffer that are fully reduced (remaining 12 elements are garbage)
- Each DP rank copies its relevant 4 fp16 grad elements from the grad buffer into 4 fp32 main grad elements (separate buffer, owned by the optimizer); i.e.
- DP rank 0 copies elements [0:4]
- DP rank 1 copies elements [4:8]
- DP rank 2 copies elements [8:12]
- DP rank 3 copies elements [12:16]
- Optimizer.step()
- Each DP rank copies its 4 fp32 main (/optimizer) param elements into the corresponding 4 fp16 elements in the grad buffer
- Call all-gather on each DP rank
- Grad buffer now contains all 16, fully updated, fp16 model param elements
- Copy updated model params from grad buffer into their respective param tensors
- (At this point, grad buffer is ready to be zero'd for the next iteration)
......@@ -187,6 +187,12 @@ def parse_args(extra_args_provider=None, defaults={},
'gradient accumulation. Setting gradient_accumulation_fusion '
'to False', flush=True)
# If we use the distributed optimizer, we need to have local DDP
# and we should make sure use-contiguous-buffers-in-local-ddp is on.
if args.use_distributed_optimizer:
assert args.DDP_impl == 'local'
assert args.use_contiguous_buffers_in_local_ddp
# For torch DDP, we do not use contiguous buffer
if args.DDP_impl == 'torch':
args.use_contiguous_buffers_in_local_ddp = False
......@@ -765,6 +771,9 @@ def _add_distributed_args(parser):
'is placed on its own pipeline stage, without any '
'transformer layers. (For T5, this flag currently only '
'affects the encoder embedding.)')
group.add_argument('--use-distributed-optimizer', action='store_true',
help='Use distributed optimizer.')
return parser
......
......@@ -81,24 +81,34 @@ def ensure_directory_exists(filename):
os.makedirs(dirname)
def get_checkpoint_name(checkpoints_path, iteration,
release=False):
def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
release=False):
"""A unified checkpoint name."""
if release:
directory = 'release'
else:
directory = 'iter_{:07d}'.format(iteration)
# Use both the tensor and pipeline MP rank.
# Use both the tensor and pipeline MP rank. If using the distributed
# optimizer, then the optimizer's path must additionally include the
# data parallel rank.
if mpu.get_pipeline_model_parallel_world_size() == 1:
return os.path.join(checkpoints_path, directory,
'mp_rank_{:02d}'.format(
mpu.get_tensor_model_parallel_rank()),
'model_optim_rng.pt')
return os.path.join(checkpoints_path, directory,
'mp_rank_{:02d}_{:03d}'.format(
mpu.get_tensor_model_parallel_rank(),
mpu.get_pipeline_model_parallel_rank()),
'model_optim_rng.pt')
common_path = os.path.join(checkpoints_path, directory,
'mp_rank_{:02d}'.format(
mpu.get_tensor_model_parallel_rank()))
else:
common_path = os.path.join(checkpoints_path, directory,
'mp_rank_{:02d}_{:03d}'.format(
mpu.get_tensor_model_parallel_rank(),
mpu.get_pipeline_model_parallel_rank()))
if use_distributed_optimizer:
model_name = os.path.join(common_path, "model_rng.pt")
optim_name = os.path.join(
common_path + "_%03d" % mpu.get_data_parallel_rank(),
"optim.pt")
else:
model_name = optim_name = os.path.join(common_path, "model_optim_rng.pt")
return model_name, optim_name
def get_checkpoint_tracker_filename(checkpoints_path):
......@@ -177,38 +187,64 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
iteration, args.save))
# collect rng state across data parallel ranks
# Collect rng state across data parallel ranks.
rng_state = get_rng_state()
if not torch.distributed.is_initialized() or mpu.get_data_parallel_rank() == 0:
# Checkpoint file names.
model_checkpoint_name, optim_checkpoint_name = \
get_checkpoint_names(args.save, iteration, args.use_distributed_optimizer)
# Collect args, model, RNG.
model_state_dict = {}
if not torch.distributed.is_initialized() \
or mpu.get_data_parallel_rank() == 0:
# Arguments, iteration, and model.
state_dict = {}
state_dict['args'] = args
state_dict['checkpoint_version'] = 3.0
state_dict['iteration'] = iteration
model_state_dict['args'] = args
model_state_dict['checkpoint_version'] = 3.0
model_state_dict['iteration'] = iteration
if len(model) == 1:
state_dict['model'] = model[0].state_dict_for_save_checkpoint()
model_state_dict['model'] = model[0].state_dict_for_save_checkpoint()
else:
for i in range(len(model)):
mpu.set_virtual_pipeline_model_parallel_rank(i)
state_dict['model%d' % i] = model[i].state_dict_for_save_checkpoint()
# Optimizer stuff.
if not args.no_save_optim:
if optimizer is not None:
state_dict['optimizer'] = optimizer.state_dict()
if opt_param_scheduler is not None:
state_dict['opt_param_scheduler'] = opt_param_scheduler.state_dict()
model_state_dict['model%d' % i] = \
model[i].state_dict_for_save_checkpoint()
# RNG states.
if not args.no_save_rng:
state_dict["rng_state"] = rng_state
model_state_dict["rng_state"] = rng_state
# Save.
checkpoint_name = get_checkpoint_name(args.save, iteration)
ensure_directory_exists(checkpoint_name)
torch.save(state_dict, checkpoint_name)
# Collect optimizer state. (Optimizer is saved separately from the model, due
# to the conflicting data pattern when using the distributed optimizer.)
optim_state_dict = {}
if not args.no_save_optim \
and (not torch.distributed.is_initialized()
or mpu.get_data_parallel_rank() == 0
or args.use_distributed_optimizer):
# Optimizer stuff.
if optimizer is not None:
optim_state_dict['optimizer'] = optimizer.state_dict()
if opt_param_scheduler is not None:
optim_state_dict['opt_param_scheduler'] = \
opt_param_scheduler.state_dict()
# Save.
if args.use_distributed_optimizer:
# Save model separate from optimizer.
if model_state_dict:
ensure_directory_exists(model_checkpoint_name)
torch.save(model_state_dict, model_checkpoint_name)
if optim_state_dict:
ensure_directory_exists(optim_checkpoint_name)
torch.save(optim_state_dict, optim_checkpoint_name)
else:
# Save model and optimizer together.
state_dict = {**model_state_dict, **optim_state_dict}
if state_dict: # only saves if populated (i.e., inherits conditions above)
ensure_directory_exists(model_checkpoint_name)
torch.save(state_dict, model_checkpoint_name)
# Wait so everyone is done (necessary)
if torch.distributed.is_initialized():
......@@ -322,12 +358,19 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
iteration, release = read_metadata(tracker_filename)
# Checkpoint.
checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
model_checkpoint_name, optim_checkpoint_name = \
get_checkpoint_names(load_dir, iteration,
args.use_distributed_optimizer,
release)
print_rank_0(f' loading checkpoint from {args.load} at iteration {iteration}')
# Load the checkpoint.
try:
state_dict = torch.load(checkpoint_name, map_location='cpu')
model_state_dict = torch.load(model_checkpoint_name, map_location='cpu')
if args.use_distributed_optimizer:
optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu')
else:
optim_state_dict = model_state_dict
except ModuleNotFoundError:
from megatron.fp16_deprecated import loss_scaler
# For backward compatibility.
......@@ -336,7 +379,8 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
'megatron.fp16_deprecated.loss_scaler']
sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
'megatron.fp16_deprecated.loss_scaler']
state_dict = torch.load(checkpoint_name, map_location='cpu')
model_state_dict = torch.load(model_checkpoint_name, map_location='cpu')
optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu')
sys.modules.pop('fp16.loss_scaler', None)
sys.modules.pop('megatron.fp16.loss_scaler', None)
except BaseException as e:
......@@ -344,18 +388,18 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
print_rank_0(e)
sys.exit()
# set checkpoint version
set_checkpoint_version(state_dict.get('checkpoint_version', 0))
# Set checkpoint version.
set_checkpoint_version(model_state_dict.get('checkpoint_version', 0))
# Set iteration.
if args.finetune or release:
iteration = 0
else:
try:
iteration = state_dict['iteration']
iteration = model_state_dict['iteration']
except KeyError:
try: # Backward compatible with older checkpoints
iteration = state_dict['total_iters']
iteration = model_state_dict['total_iters']
except KeyError:
print_rank_0('A metadata file exists but unable to load '
'iteration from checkpoint {}, exiting'.format(
......@@ -365,8 +409,8 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
# Check arguments.
assert args.consumed_train_samples == 0
assert args.consumed_valid_samples == 0
if 'args' in state_dict:
checkpoint_args = state_dict['args']
if 'args' in model_state_dict:
checkpoint_args = model_state_dict['args']
check_checkpoint_args(checkpoint_args)
args.consumed_train_samples = getattr(checkpoint_args,
'consumed_train_samples', 0)
......@@ -378,11 +422,11 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
# Model.
if len(model) == 1:
model[0].load_state_dict(state_dict['model'], strict=strict)
model[0].load_state_dict(model_state_dict['model'], strict=strict)
else:
for i in range(len(model)):
mpu.set_virtual_pipeline_model_parallel_rank(i)
model[i].load_state_dict(state_dict['model%d' % i], strict=strict)
model[i].load_state_dict(model_state_dict['model%d' % i], strict=strict)
# Fix up query/key/value matrix ordering if needed
checkpoint_version = get_checkpoint_version()
......@@ -393,12 +437,12 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
if not release and not args.finetune and not args.no_load_optim:
try:
if optimizer is not None:
optimizer.load_state_dict(state_dict['optimizer'])
optimizer.load_state_dict(optim_state_dict['optimizer'])
if opt_param_scheduler is not None:
if 'lr_scheduler' in state_dict: # backward compatbility
opt_param_scheduler.load_state_dict(state_dict['lr_scheduler'])
if 'lr_scheduler' in optim_state_dict: # backward compatbility
opt_param_scheduler.load_state_dict(optim_state_dict['lr_scheduler'])
else:
opt_param_scheduler.load_state_dict(state_dict['opt_param_scheduler'])
opt_param_scheduler.load_state_dict(optim_state_dict['opt_param_scheduler'])
except KeyError:
print_rank_0('Unable to load optimizer from checkpoint {}. '
'Specify --no-load-optim or --finetune to prevent '
......@@ -409,13 +453,13 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
# rng states.
if not release and not args.finetune and not args.no_load_rng:
try:
if 'rng_state' in state_dict:
if 'rng_state' in model_state_dict:
# access rng_state for data parallel rank
if args.data_parallel_random_init:
rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()]
rng_state = model_state_dict['rng_state'][mpu.get_data_parallel_rank()]
else:
rng_state = state_dict['rng_state'][0]
rng_state = model_state_dict['rng_state'][0]
random.setstate(rng_state['random_rng_state'])
np.random.set_state(rng_state['np_rng_state'])
torch.set_rng_state(rng_state['torch_rng_state'])
......@@ -426,15 +470,15 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
mpu.get_cuda_rng_tracker().set_states(
rng_state['rng_tracker_states'])
else: # backward compatability
random.setstate(state_dict['random_rng_state'])
np.random.set_state(state_dict['np_rng_state'])
torch.set_rng_state(state_dict['torch_rng_state'])
torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
random.setstate(model_state_dict['random_rng_state'])
np.random.set_state(model_state_dict['np_rng_state'])
torch.set_rng_state(model_state_dict['torch_rng_state'])
torch.cuda.set_rng_state(model_state_dict['cuda_rng_state'])
# Check for empty states array
if not state_dict['rng_tracker_states']:
if not model_state_dict['rng_tracker_states']:
raise KeyError
mpu.get_cuda_rng_tracker().set_states(
state_dict['rng_tracker_states'])
model_state_dict['rng_tracker_states'])
except KeyError:
print_rank_0('Unable to load rng state from checkpoint {}. '
'Specify --no-load-rng or --finetune to prevent '
......@@ -469,12 +513,14 @@ def load_biencoder_checkpoint(model, only_query_model=False,
with open(tracker_filename, 'r') as f:
iteration = int(f.read().strip())
checkpoint_name = get_checkpoint_name(load_path, iteration, False)
checkpoint_name, _ = get_checkpoint_names(load_path, iteration,
args.use_distributed_optimizer,
False)
if mpu.get_data_parallel_rank() == 0:
print('global rank {} is loading checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name))
state_dict = torch.load(checkpoint_name, map_location='cpu')
state_dict = torch.load(model_checkpoint_name, map_location='cpu')
ret_state_dict = state_dict['model']
if only_query_model:
......
......@@ -15,6 +15,7 @@
from abc import ABC
from abc import abstractmethod
import math
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
......@@ -24,18 +25,17 @@ from megatron import mpu
from .module import MegatronModule
class MemoryBuffer:
def __init__(self, numel, dtype):
def __init__(self, numel, numel_padded, dtype):
self.numel = numel
self.numel_padded = numel_padded
self.dtype = dtype
self.data = torch.zeros(self.numel,
self.data = torch.zeros(self.numel_padded,
dtype=self.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
def zero(self):
"""Reset the buffer to zero."""
self.data.zero_()
......@@ -121,8 +121,11 @@ class DistributedDataParallel(DistributedDataParallelBase):
# the case we use continuous buffers.
# ===================================
self._grad_buffers = None
self._grad_buffer_param_index_map = None
if self.use_contiguous_buffers:
self._grad_buffers = {}
self._grad_buffer_param_index_map = {}
data_parallel_world_size = mpu.get_data_parallel_world_size()
# Simple function to define buffer type.
def _get_buffer_type(param):
......@@ -139,7 +142,18 @@ class DistributedDataParallel(DistributedDataParallelBase):
# Allocate the buffer.
for dtype, num_elements in type_num_elements.items():
self._grad_buffers[dtype] = MemoryBuffer(num_elements, dtype)
# If using distributed optimizer, pad memory buffer to be
# multiple of data_parallel_world_size. (This padding is done
# due to a constraint with the reduce_scatter op, which requires
# all tensors have equal size. See: optimizer.py.)
num_elements_padded = data_parallel_world_size * \
int(math.ceil(num_elements / data_parallel_world_size))
# Allocate grad buffer.
self._grad_buffers[dtype] = MemoryBuffer(num_elements,
num_elements_padded,
dtype)
# Assume the back prop order is reverse the params order,
# store the start index for the gradients.
......@@ -149,6 +163,12 @@ class DistributedDataParallel(DistributedDataParallelBase):
type_num_elements[dtype] -= param.data.nelement()
param.main_grad = self._grad_buffers[dtype].get(
param.data.shape, type_num_elements[dtype])
if dtype not in self._grad_buffer_param_index_map:
self._grad_buffer_param_index_map[dtype] = {}
self._grad_buffer_param_index_map[dtype][param] = (
type_num_elements[dtype],
type_num_elements[dtype] + param.data.nelement(),
)
# Backward hook.
# Accumalation function for the gradients. We need
......@@ -164,6 +184,7 @@ class DistributedDataParallel(DistributedDataParallelBase):
grad_acc.register_hook(self._make_param_hook(param))
self.grad_accs.append(grad_acc)
def _make_param_hook(self, param):
"""Create the all-reduce hook for backprop."""
# Hook used for back-prop.
......
......@@ -680,6 +680,17 @@ class ParallelTransformerLayer(MegatronModule):
mlp_bias.expand_as(residual),
residual,
self.hidden_dropout)
# Jit compiled function creates 'view' tensor. This tensor
# potentially gets saved in the MPU checkpoint function context,
# which rejects view tensors. While making a viewless tensor here
# won't result in memory savings (like the data loader, or
# p2p_communication), it serves to document the origin of this
# 'view' tensor.
output = mpu.make_viewless_tensor(inp = output,
requires_grad = output.requires_grad,
keep_graph = True)
else:
out = torch.nn.functional.dropout(mlp_output + mlp_bias,
p=self.hidden_dropout,
......
......@@ -18,6 +18,7 @@ from apex.optimizers import FusedSGD as SGD
from megatron import get_args
from .distrib_optimizer import DistributedOptimizer
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
......@@ -104,7 +105,11 @@ def get_megatron_optimizer(model,
if args.DDP_impl == 'local':
params_have_main_grad = True
if args.fp16 or args.bf16:
# Mixed precision optimizer.
# - Note: both the Float16Optimizer and the DistributedOptimizer inherit
# from the MixedPrecisionOptimizer, which manages any optimizer where
# the model params and main params are distinct.
if args.fp16 or args.bf16 or args.use_distributed_optimizer:
# Grad scaler:
# if loss-scale is provided, instantiate the constant scaler.
......@@ -113,9 +118,11 @@ def get_megatron_optimizer(model,
# otherwise we are running in bf16 with no loss-scale so
# leave it as None.
grad_scaler = None
# Constant loss scale.
if args.loss_scale:
grad_scaler = ConstantGradScaler(args.loss_scale)
# Dynamic loss scale.
else:
if args.fp16:
......@@ -128,16 +135,22 @@ def get_megatron_optimizer(model,
hysteresis=args.hysteresis)
# Megatron optimizer.
return Float16OptimizerWithFloat16Params(optimizer,
args.clip_grad,
args.log_num_zeros_in_grad,
params_have_main_grad,
args.use_contiguous_buffers_in_local_ddp,
args.bf16,
grad_scaler)
opt_ty = DistributedOptimizer \
if args.use_distributed_optimizer else \
Float16OptimizerWithFloat16Params
return opt_ty(optimizer,
args.clip_grad,
args.log_num_zeros_in_grad,
params_have_main_grad,
args.use_contiguous_buffers_in_local_ddp,
args.fp16,
args.bf16,
grad_scaler,
model)
# FP32.
return FP32Optimizer(optimizer, args.clip_grad,
args.log_num_zeros_in_grad,
params_have_main_grad,
args.use_contiguous_buffers_in_local_ddp)
args.use_contiguous_buffers_in_local_ddp,
model)
......@@ -21,12 +21,13 @@ from torch._six import inf
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from megatron import mpu
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
def clip_grad_norm_fp32(parameters, grads_for_norm,
max_norm, norm_type=2,
model_parallel_group=None):
"""Clips gradient norm of an iterable of parameters whose gradients
are in fp32.
......@@ -37,9 +38,13 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
Arguments:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
grads_for_norm (Iterable[Tensor]): an iterable of Tensors or a single
Tensor that will be used for calculating the grad norm.
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
model_parallel_group (group): given the nature of the distributed
optimizer, this is passed as an argument.
Returns:
Total norm of the parameters (viewed as a single vector).
......@@ -47,25 +52,15 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
if isinstance(grads_for_norm, torch.Tensor):
grads_for_norm = [grads_for_norm]
# Filter parameters based on:
# - grad should not be none
# - parameter should not be shared
# - should not be a replica due to tensor model parallelism
# Grads.
grads = []
grads_for_norm = []
for param in parameters:
grad_not_none = param.grad is not None
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
if grad_not_none:
grad = param.grad.detach()
if grad_not_none:
# Make sure the grads are in fp32
if param.grad is not None:
assert param.grad.type() == 'torch.cuda.FloatTensor'
grads.append(grad)
if grad_not_none and is_not_shared and is_not_tp_duplicate:
grads_for_norm.append(grad)
grads.append(param.grad.detach())
# Norm parameters.
max_norm = float(max_norm)
......@@ -79,7 +74,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
# Take max across all model-parallel GPUs.
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.MAX,
group=mpu.get_model_parallel_group())
group=model_parallel_group)
total_norm = total_norm_cuda[0].item()
else:
......@@ -88,12 +83,15 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
# Use apex's multi-tensor applier for efficiency reasons.
# Multi-tensor applier takes a function and a list of list
# and performs the operation on that list all in one kernel.
grad_norm, _ = multi_tensor_applier(
amp_C.multi_tensor_l2norm,
dummy_overflow_buf,
[grads_for_norm],
False # no per-parameter norm
)
if grads_for_norm:
grad_norm, _ = multi_tensor_applier(
amp_C.multi_tensor_l2norm,
dummy_overflow_buf,
[grads_for_norm],
False # no per-parameter norm
)
else:
grad_norm = torch.cuda.FloatTensor([0])
# Since we will be summing across data parallel groups,
# we need the pow(norm-type).
total_norm = grad_norm ** norm_type
......@@ -106,7 +104,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
# Sum across all model-parallel GPUs.
torch.distributed.all_reduce(total_norm,
op=torch.distributed.ReduceOp.SUM,
group=mpu.get_model_parallel_group())
group=model_parallel_group)
total_norm = total_norm.item() ** (1.0 / norm_type)
# Scale.
......@@ -121,7 +119,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
return total_norm
def count_zeros_fp32(parameters):
def count_zeros_fp32(parameters, model_parallel_group):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
......@@ -130,7 +128,7 @@ def count_zeros_fp32(parameters):
# - grad should not be none
# - parameter should not be shared
# - should not be a replica due to tensor model parallelism
total_num_zeros = 0.0
total_num_zeros = torch.cuda.FloatTensor([0.0])
for param in parameters:
grad_not_none = param.grad is not None
is_not_shared = param_is_not_shared(param)
......@@ -143,7 +141,8 @@ def count_zeros_fp32(parameters):
# Sum across all model-parallel GPUs.
torch.distributed.all_reduce(total_num_zeros,
op=torch.distributed.ReduceOp.SUM,
group=mpu.get_model_parallel_group())
group=model_parallel_group)
total_num_zeros = total_num_zeros.item()
return total_num_zeros
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron distributed optimizer."""
import math
import torch
from megatron import get_args
from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper
class Range:
"""
A range represents a start and end points for indexing a shard
from a full tensor.
"""
def __init__(self, start, end):
self.start = start
self.end = end
self.size = end - start
def normalize(self, start = 0):
return Range(start, start + self.size)
def __str__(self):
return "%d,%d [%d]" % (self.start, self.end, self.size)
class DistributedOptimizer(MixedPrecisionOptimizer):
"""Distributed optimizer, for all data types (fp16, bf16, and fp32).
Arguments:
optimizer: base optimizer such as Adam or SGD
clip_grad: clip gradeints with this global L2 norm. Note
that clipping is ignored if clip_grad == 0
log_num_zeros_in_grad: return number of zeros in the gradients.
params_have_main_grad: flag indicating if parameters have
a `main_grad` field. If this is set, we are assuming
that the model parameters are store in the `main_grad`
field instead of the typical `grad` field. This happens
for the DDP cases where there is a continuous buffer
holding the gradients. For example for bfloat16, we want
to do gradient accumulation and all-reduces in float32
and as a result we store those gradients in the main_grad.
Note that main grad is not necessarily in float32.
use_contiguous_buffers_in_local_ddp: if true, the local DDP model
is using a contiguous buffer to hold the model grads.
fp16: if true, the model is running in fp16.
bf16: if true, the model is running in bfloat16.
grad_scaler: used for scaling gradients. Note that this can be
None. This case happens when `bf16 = True` and we don't
use any loss scale. Note that for `bf16 = True`, we can have
a constnat gradient scaler. Also for `bf16 = False`, we
always require a grad scaler.
models: list of models (i.e., the virtual pipelining models). This
is used by the distributed optimizer for mapping parameters.
"""
@classmethod
def build_model_gbuf_param_range_map(cls, model, dtype, gbuf_world_range):
"""
Build mapping from param reference to grad buffer shard ranges.
This method builds a mapping from parameter references to grad
buffer shard ranges, specific to each data-parallel (DP) rank's
set of 'owned' parameters. Each grad buffer (padded to be an even
multiple of DP-world-size) is conceptually divided into DP-world-size
contiguous regions, where each DP rank 'owns' a contiguous regions.
Ownership in this sense means DP rank is responsible for reducing
the relevant subset of grads, and updating the relevant subset of
params.
This conceptual partitioning of the grad buffer does NOT respect
parameter boundaries, and as such it is assumed that each created
range references a shard (or subset) of the full parameter. It is
easiest to think of each DP rank as operating (i.e., reducing,
gathering) purely on views into the grad buffer, for all model-to-
main & main-to-model operations.
This method creates three ranges:
- The param's range within the entire grad buffer (i.e., world index).
- The param's range within the DP rank's local view of the grad buffer.
- The param's range within itself (i.e., its shard).
"""
# Param range map.
param_world_index_map = model._grad_buffer_param_index_map[dtype]
param_range_map = {}
for param, param_world_indexes in param_world_index_map.items():
# Param range.
param_world_start, param_world_end = param_world_indexes
param_local_start = max(
0,
param_world_start - gbuf_world_range.start)
param_local_end = min(
gbuf_world_range.size,
param_world_end - gbuf_world_range.start)
# Add param, if within local gbuf range.
if param_local_end > param_local_start:
param_local_range = Range(param_local_start, param_local_end)
param_world_range = param_local_range.normalize(
param_local_start + gbuf_world_range.start)
sub_param_start = max(0, gbuf_world_range.start-param_world_start)
sub_param_range = param_local_range.normalize(sub_param_start)
param_range_map[param] = {
"gbuf_world" : param_world_range,
"gbuf_local" : param_local_range,
"param" : sub_param_range,
}
return param_range_map
@classmethod
def build_model_gbuf_range(cls, model, dtype):
"""
Build mapping between params and their grad buffers.
This method does the initial setup for the method above. This setup
includes determining the shard ranges into the DDP's grad buffer for
each data-parallel (DP) rank. Each DP rank keeps range info for
all other DP ranks, for the purpose of creating args for
reduce-scatter and all-gather.
"""
data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_world_size = mpu.get_data_parallel_world_size()
# Grad buffer range.
grad_buffer = model._grad_buffers[dtype]
gbuf_size = grad_buffer.numel
max_gbuf_range_size = int(math.ceil(gbuf_size / data_parallel_world_size))
# All world ranges. (i.e., across all data parallel ranks)
gbuf_world_all_ranges = []
for r in range(data_parallel_world_size):
gbuf_world_start = r * max_gbuf_range_size
gbuf_world_end = min(gbuf_size, gbuf_world_start+max_gbuf_range_size)
gbuf_world_range = Range(gbuf_world_start, gbuf_world_end)
gbuf_world_all_ranges.append(gbuf_world_range)
# Local DP's ranges.
gbuf_world_range = gbuf_world_all_ranges[data_parallel_rank]
gbuf_local_range = gbuf_world_range.normalize()
# Get each param's ranges.
param_range_map = cls.build_model_gbuf_param_range_map(model,
dtype,
gbuf_world_range)
# Group into dict.
data = {
"local" : gbuf_local_range,
"world" : gbuf_world_range,
"world_all" : gbuf_world_all_ranges,
"param_map" : param_range_map,
"max_range_size" : max_gbuf_range_size,
}
return data
@classmethod
def build_model_gbuf_range_map(cls, model):
"""
Create param-to-grad-buffer mappings, for grad buffer data types
within a specific virtual model.
"""
return {
dtype : cls.build_model_gbuf_range(model, dtype)
for dtype in model._grad_buffers
}
@classmethod
def build_model_param_gbuf_map(cls, model_gbuf_ranges):
"""
Create a reverse of the model_gbuf_ranges, for referencing in
opposite direction.
"""
param_gbuf_map = {}
for model_index, model_gbuf_range_map in enumerate(model_gbuf_ranges):
for dtype, gbuf_range_map in model_gbuf_range_map.items():
for param, param_range_map in gbuf_range_map["param_map"].items():
param_gbuf_map[param] = (model_index, dtype)
return param_gbuf_map
@classmethod
def build_optimizer_group_ranges(cls, param_groups, model_gbuf_ranges):
"""
Create optimizer groups.
Given the set of parameter shard ranges that are owned by the current
data-parallel (DP) rank, gather the set of parameters that will be
used (in the method below) to create the current DP's optimizer
groups.
"""
num_groups = len(param_groups)
# Param group map.
param_group_map = {}
for group_index, group in enumerate(param_groups):
for param in group["params"]:
assert param.requires_grad
param_group_map[param] = group_index
# Optimizer group ranges.
group_ranges = [ {"params": []} for _ in param_groups ]
for model_gbuf_range_map in model_gbuf_ranges:
for dtype, gbuf_range_map in model_gbuf_range_map.items():
for param in gbuf_range_map["param_map"]:
group_index = param_group_map[param]
group_range = group_ranges[group_index]
group_range["params"].append(param)
# Squeeze zero-size group ranges.
for group_index, group_range in enumerate(group_ranges):
group_range["orig_group"] = param_groups[group_index]
group_ranges = [ g for g in group_ranges if len(g["params"]) > 0 ]
return group_ranges
@classmethod
def build_model_and_main_param_groups(cls,
model_gbuf_ranges,
param_gbuf_map,
opt_group_ranges):
"""
Create main parameter groups needed for the optimizer step.
These groups encompass both: 1) groups used by this class, for
reducing/gather, and 2) groups used by the inner optimizer for the
parameter update. Given that the conceptual grad buffer partitioning
(created in earlier method) doesn't respect parameter boundaries,
the optimizer operates on shards of the model parameters, rather than
the full parameters.
"""
# Parameter groups:
# model_float16_groups: original float16 parameters
# model_fp32_groups: original fp32 parameters
# shard_float16_groups: shards of original float16 parameters
# shard_fp32_groups: shards of original fp32 parameters
# shard_fp32_from_float16_groups: fp32 copy of float16 parameters
model_float16_groups = []
model_fp32_groups = []
shard_float16_groups = []
shard_fp32_groups = []
shard_fp32_from_float16_groups = []
# Allocate (or slice) each group's param shard.
for group_index, group_range in enumerate(opt_group_ranges):
# Params of this group.
model_float16_params_this_group = []
model_fp32_params_this_group = []
shard_float16_params_this_group = []
shard_fp32_params_this_group = []
shard_fp32_from_float16_params_this_group = []
model_float16_groups.append(model_float16_params_this_group)
model_fp32_groups.append(model_fp32_params_this_group)
shard_float16_groups.append(shard_float16_params_this_group)
shard_fp32_groups.append(shard_fp32_params_this_group)
shard_fp32_from_float16_groups.append(
shard_fp32_from_float16_params_this_group)
for model_param in group_range["params"]:
assert model_param.requires_grad
model_index, dtype = param_gbuf_map[model_param]
gbuf_range = model_gbuf_ranges[model_index][dtype]
param_range = gbuf_range["param_map"][model_param]["param"]
# fp16, bf16 params.
if model_param.type() in ['torch.cuda.HalfTensor',
'torch.cuda.BFloat16Tensor']:
# Clone model -> main.
shard_model_param = model_param.detach().view(-1) \
[param_range.start:param_range.end]
shard_main_param = shard_model_param.clone().float()
mpu.copy_tensor_model_parallel_attributes(
shard_model_param, model_param)
mpu.copy_tensor_model_parallel_attributes(
shard_main_param, model_param)
if hasattr(model_param, 'shared'):
shard_model_param.shared = model_param.shared
shard_main_param.shared = model_param.shared
# Add to group.
model_float16_params_this_group.append(model_param)
shard_float16_params_this_group.append(shard_model_param)
shard_fp32_from_float16_params_this_group.append(shard_main_param)
# fp32 params.
elif model_param.type() == 'torch.cuda.FloatTensor':
shard_model_param = model_param.view(-1) \
[param_range.start:param_range.end]
model_fp32_params_this_group.append(model_param)
shard_fp32_params_this_group.append(shard_model_param)
mpu.copy_tensor_model_parallel_attributes(
shard_model_param, model_param)
if hasattr(model_param, 'shared'):
shard_model_param.shared = model_param.shared
else:
raise TypeError('Wrapped parameters must be one of '
'torch.cuda.FloatTensor, '
'torch.cuda.HalfTensor, or '
'torch.cuda.BFloat16Tensor. '
'Received {}'.format(param.type()))
# Update optimizer's params.
group_range["orig_group"]["params"] = [
*shard_fp32_params_this_group,
*shard_fp32_from_float16_params_this_group,
]
return (
model_float16_groups,
model_fp32_groups,
shard_float16_groups,
shard_fp32_groups,
shard_fp32_from_float16_groups,
)
def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
fp16, bf16, grad_scaler, models):
"""
See top of class definition for argument descriptions.
The steps in this method create the core mapping between DDP grad
buffers, parameters, and parameter shard ranges, that is needed for
converting between model param indexes and main parameter shard
indexes. This method also updates the optimizer parameter groups
with the newly created shards.
"""
super().__init__(
optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
fp16, bf16, grad_scaler, models)
# Verify that contiguous buffers are being used.
# - Note: this should already be checked in arguments.py.
assert use_contiguous_buffers_in_local_ddp
# Model grad buffer ranges.
self.model_gbuf_ranges = []
for model_index, model in enumerate(self.models):
self.model_gbuf_ranges.append(self.build_model_gbuf_range_map(model))
self.model_param_gbuf_map = \
self.build_model_param_gbuf_map(self.model_gbuf_ranges)
# Optimizer ranges.
self.opt_group_ranges = self.build_optimizer_group_ranges(
self.optimizer.param_groups,
self.model_gbuf_ranges)
# Allocate main param shards.
(
self.model_float16_groups,
self.model_fp32_groups,
self.shard_float16_groups,
self.shard_fp32_groups,
self.shard_fp32_from_float16_groups,
) = self.build_model_and_main_param_groups(self.model_gbuf_ranges,
self.model_param_gbuf_map,
self.opt_group_ranges)
# Update optimizer groups.
# - Also, leverage state_dict() and load_state_dict() to
# recast preexisting per-param state tensors.
self.optimizer.param_groups = \
[ g["orig_group"] for g in self.opt_group_ranges ]
self.optimizer.load_state_dict(self.optimizer.state_dict())
def get_model_param_range_map(self, param):
"""
Given a model param, get the index sub-range of the param that this
data-parallel rank owns.
"""
model_index, dtype = self.model_param_gbuf_map[param]
gbuf_range_map = self.model_gbuf_ranges[model_index][dtype]
param_range_map = gbuf_range_map["param_map"][param]
return param_range_map
def get_model_parallel_group(self):
"""
With the distributed optimizer, the model parallel group is the
entire world.
"""
return None
def state_dict(self):
"""
The state dict must contain the fp32-from-float16 shards.
"""
state_dict = {}
state_dict['optimizer'] = self.optimizer.state_dict()
if self.grad_scaler:
state_dict['grad_scaler'] = self.grad_scaler.state_dict()
state_dict['shard_fp32_from_float16_groups'] = \
self.shard_fp32_from_float16_groups
return state_dict
def load_state_dict(self, state_dict):
"""
Load the state dict.
"""
# Optimizer.
optimizer_key = 'optimizer'
if optimizer_key not in state_dict:
optimizer_key = 'optimizer_state_dict'
print_rank_0('***WARNING*** loading optimizer from '
'an old checkpoint ...')
self.optimizer.load_state_dict(state_dict[optimizer_key])
# Grad scaler.
if 'grad_scaler' not in state_dict:
print_rank_0('***WARNING*** found an old checkpoint, will not '
'load grad scaler ...')
else:
if self.grad_scaler:
self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
else:
print_rank_0('***WARNING*** fould the grad scaler in the '
'checkpoint but it is None in the class. '
'Skipping loading grad scaler ...')
# Copy data for the main params.
for current_group, saved_group in zip(
self.shard_fp32_from_float16_groups,
state_dict["shard_fp32_from_float16_groups"]):
for current_param, saved_param in zip(current_group, saved_group):
current_param.data.copy_(saved_param.data)
def zero_grad(self, set_to_none=True):
"""
Zero grads.
We only need to zero the model related parameters, i.e.,
model_float16_groups & model_fp32_groups. We additionally zero
the remaining groups as a memory optimization to reduce
fragmentation; in the case of set_to_none==True, the space
used by this field can be safely deallocated at this point.
"""
for groups in (
self.model_float16_groups,
self.model_fp32_groups,
self.shard_float16_groups, # grad empty/unused here?
self.shard_fp32_groups, # throws grad-access warning
self.shard_fp32_from_float16_groups):
for group in groups:
_zero_grad_group_helper(group, set_to_none)
def get_model_grad_buffer_dp_views(self):
"""
Get shard views of each of the DDP's grad buffers.
In this nested list, the top level is grouped by the virtual model
index and the grad buffer's data type. The sub-level is a list of
shards of that grad buffer, where each shard in the list represents
a contiguous view of the grad buffer, that is owned by a data-parallel
rank. The shard boundary does not respect parameter boundaries, and
so the elements of some parameters are split across data parallel
ranks.
Additionally, return references to the entire grad buffers, for use
in _reduce_scatter_base and _all_gather_base.
"""
data_parallel_world_size = mpu.get_data_parallel_world_size()
# Grad buffer views.
gbuf_view_items = []
for model_index, model in enumerate(self.models):
for dtype, gbuf in model._grad_buffers.items():
assert gbuf.numel_padded % data_parallel_world_size == 0
shard_size = int(gbuf.numel_padded / data_parallel_world_size)
gbuf_views = [gbuf.data[(r*shard_size):((r+1)*shard_size)]
for r in range(data_parallel_world_size)]
gbuf_view_items.append((model_index, dtype, gbuf.data, gbuf_views))
return gbuf_view_items
def reduce_model_grads(self, args, timers):
"""
Reduce-scatter model grads.
The DDP's grad buffer is used for the reduce-scatter, and thus no
tensors are dynamically allocated.
Note: this is a different order of reduction, versus the non-
distributed optimizer, which reduces: 1) layernorm grads, 2) all
grads, 3) embedding grads.
"""
# All-reduce layer-norm grads (for sequence parallelism).
timers('backward-layernorm-all-reduce').start()
self.allreduce_layernorm_grads(args)
timers('backward-layernorm-all-reduce').stop()
# All-reduce embedding grads.
timers('backward-embedding-all-reduce').start()
self.allreduce_embedding_grads(args)
timers('backward-embedding-all-reduce').stop()
# Reduce-scatter setup.
timers('backward-params-all-reduce').start()
data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_world_size = mpu.get_data_parallel_world_size()
data_parallel_group = mpu.get_data_parallel_group()
# Scale grad buffers by '1 / data_parallel_world_size'.
for model in self.models:
for dtype, gbuf in model._grad_buffers.items():
gbuf.data /= data_parallel_world_size
# Reduce-scatter all grads.
gbuf_view_items = self.get_model_grad_buffer_dp_views()
for index, (model_index, dtype, gbuf, gbuf_views) \
in enumerate(gbuf_view_items):
torch.distributed._reduce_scatter_base(
gbuf_views[data_parallel_rank],
gbuf,
group = data_parallel_group,
)
timers('backward-params-all-reduce').stop()
def gather_model_params(self, args, timers):
"""
All-gather updated model params.
The DDP's grad buffer is used for the all-gather, and thus no
tensors are dynamically allocated. After the all-gather, the params
can be copied from param.main_grad to param.
"""
timers('backward-params-all-gather').start()
data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_group = mpu.get_data_parallel_group()
# All-gather updated main params.
# - All grad buffer views are guaranteed to have the same num elements
# across all data parallel ranks, with grad buffer padding that is done
# in distributed.py. Thus, all sub-views will have consistent start/end
# indexes across data parallel ranks.
gbuf_view_items = self.get_model_grad_buffer_dp_views()
for index, (model_index, dtype, gbuf, gbuf_views) \
in enumerate(gbuf_view_items):
torch.distributed._all_gather_base(
gbuf,
gbuf_views[data_parallel_rank],
group = data_parallel_group,
)
# Each model param now contains its updated values in its
# '.main_grad' field.
for model in self.models:
for dtype, param_map in model._grad_buffer_param_index_map.items():
for param in param_map:
param.detach().copy_(param.main_grad)
timers('backward-params-all-gather').stop()
def _collect_main_grad_data_for_unscaling(self):
"""
Note: this should be equivalent to the float-16 optimizer's method,
but writtent differently, so the two should be combined.
"""
return [
param.grad.data
for group in self.optimizer.param_groups
for param in group["params"]
]
def _get_model_and_main_params_data_float16(self):
"""
Get aligned list of model and main params.
"""
model_data = []
main_data = []
for model_group, main_group in zip(self.shard_float16_groups,
self.shard_fp32_from_float16_groups):
for model_param, main_param in zip(model_group, main_group):
model_data.append(model_param.data)
main_data.append(main_param.data)
return model_data, main_data
def _copy_model_grads_to_main_grads(self):
"""
Copy model grads to main grads.
Since this step follows a reduce-scatter through the DDP's grad
buffer, this method is responsible for copying the updated grads
from the grad buffer to the main shard's grad field.
"""
# Utility method for copying group grads.
def copy_group_grads(model_groups, shard_main_groups):
for model_group, shard_main_group in zip(model_groups,
shard_main_groups):
for model_param, shard_main_param in zip(model_group,
shard_main_group):
param_range_map = self.get_model_param_range_map(model_param)
param_range = param_range_map["param"]
assert param_range.size == shard_main_param.nelement()
model_grad = model_param.main_grad
shard_model_grad = model_grad.view(-1) \
[param_range.start:param_range.end]
shard_main_param.grad = shard_model_grad.float()
# Copy model groups to shard groups.
copy_group_grads(self.model_float16_groups,
self.shard_fp32_from_float16_groups)
copy_group_grads(self.model_fp32_groups,
self.shard_fp32_groups)
def _copy_main_params_to_model_params(self):
"""
Copy main params to model params.
Since this step is followed by an all-gather through the DDP's grad
buffer, this method is responsible for copying the updated params
from the main shards into the correct position in the grad buffer.
"""
# Utility method for copying group params.
def copy_group_params(shard_main_groups, model_groups):
for shard_main_group, model_group in zip(shard_main_groups,
model_groups):
for shard_main_param, model_param in zip(shard_main_group,
model_group):
param_range_map = self.get_model_param_range_map(model_param)
param_range = param_range_map["param"]
assert param_range.size == shard_main_param.nelement()
model_grad = model_param.main_grad
shard_model_grad = model_grad.view(-1) \
[param_range.start:param_range.end]
shard_model_grad.data.copy_(shard_main_param)
# Copy shard groups to model groups.
copy_group_params(self.shard_fp32_from_float16_groups,
self.model_float16_groups)
copy_group_params(self.shard_fp32_groups,
self.model_fp32_groups)
......@@ -17,15 +17,20 @@
from abc import ABC
from abc import abstractmethod
import torch
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
from megatron.utils import unwrap_model
from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
......@@ -69,7 +74,8 @@ class MegatronOptimizer(ABC):
def __init__(self, optimizer, clip_grad,
log_num_zeros_in_grad,
params_have_main_grad,
use_contiguous_buffers_in_local_ddp):
use_contiguous_buffers_in_local_ddp,
models):
"""Input optimizer is the base optimizer for example Adam."""
self.optimizer = optimizer
......@@ -80,10 +86,15 @@ class MegatronOptimizer(ABC):
self.params_have_main_grad = params_have_main_grad
self.use_contiguous_buffers_in_local_ddp = use_contiguous_buffers_in_local_ddp
# 'models' are retained for access to the contiguous grad buffers.
# (see distributed optimizer)
self.models = models
if self.use_contiguous_buffers_in_local_ddp:
assert self.params_have_main_grad, \
"use of contiguous buffer requires that params have main grad"
def get_parameters(self):
params = []
for param_group in self.optimizer.param_groups:
......@@ -92,14 +103,42 @@ class MegatronOptimizer(ABC):
return params
def get_main_grads_for_grad_norm(self):
# Filter parameters based on:
# - grad should not be none
# - parameter should not be shared
# - should not be a replica due to tensor model parallelism
params = self.get_parameters()
grads_for_norm = []
for param in params:
grad = param.grad
grad_not_none = grad is not None
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
if grad_not_none and is_not_shared and is_not_tp_duplicate:
grads_for_norm.append(grad)
return grads_for_norm
def get_model_parallel_group(self):
"""Default returned here, but the distributed optimizer overrides this."""
return mpu.get_model_parallel_group()
def clip_grad_norm(self, clip_grad):
params = self.get_parameters()
return clip_grad_norm_fp32(params, clip_grad)
grads_for_norm = self.get_main_grads_for_grad_norm()
return clip_grad_norm_fp32(
params, grads_for_norm, clip_grad,
model_parallel_group=self.get_model_parallel_group())
def count_zeros(self):
params = self.get_parameters()
return count_zeros_fp32(params)
return count_zeros_fp32(params,
model_parallel_group=self.get_model_parallel_group())
@abstractmethod
......@@ -118,11 +157,6 @@ class MegatronOptimizer(ABC):
return self.get_loss_scale() * loss
@abstractmethod
def step(self):
pass
@abstractmethod
def reload_model_params(self):
"""Refreshes any internal state from the current model parameters.
......@@ -166,9 +200,119 @@ class MegatronOptimizer(ABC):
param_groups = property(_get_param_groups, _set_param_groups)
@abstractmethod
def step(self, args, timers):
pass
class Float16OptimizerWithFloat16Params(MegatronOptimizer):
"""Float16 optimizer for fp16 and bf16 data types.
def gather_model_params(self, args, timers):
"""
For the case of a non-distributed-optimizer, there is nothing to
do here.
"""
pass
def allreduce_word_embedding_grads(self, args):
"""
All-reduce word embedding grads.
Reduce grads across first and last stages to ensure that word_embeddings
parameters stay in sync. This should only run for models that support
pipelined model parallelism (BERT and GPT-2).
"""
if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
mpu.get_pipeline_model_parallel_world_size() > 1:
if mpu.is_pipeline_first_stage(ignore_virtual=True):
unwrapped_model = self.models[0]
elif mpu.is_pipeline_last_stage(ignore_virtual=True):
unwrapped_model = self.models[-1]
else: # We do not support the interleaved schedule for T5 yet.
unwrapped_model = self.models[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
if unwrapped_model.share_word_embeddings:
word_embeddings_weight = unwrapped_model.word_embeddings_weight()
if args.DDP_impl == 'local':
grad = word_embeddings_weight.main_grad
else:
grad = word_embeddings_weight.grad
torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
def allreduce_position_embedding_grads(self, args):
"""
All-reduce position_embeddings grad across first (encoder) and
split (decoder) stages to ensure that position embeddings parameters
stay in sync. This should only run for T5 models with pipeline
parallelism.
"""
if mpu.is_rank_in_position_embedding_group() and \
mpu.get_pipeline_model_parallel_world_size() > 1 and \
args.pipeline_model_parallel_split_rank is not None:
unwrapped_model = self.models[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
assert args.DDP_impl == 'local', \
'T5 model is only supported with local DDP mode'
grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
def allreduce_embedding_grads(self, args):
"""All-reduce both word and position embeddings."""
self.allreduce_word_embedding_grads(args)
self.allreduce_position_embedding_grads(args)
def allreduce_layernorm_grads(self, args):
"""All-reduce layernorm grads (for sequence parallelism)."""
# All-reduce layernorm parameters across model parallel nodes
# when sequence parallelism is used
if mpu.get_tensor_model_parallel_world_size() > 1 and \
args.sequence_parallel:
grads = []
for model_module in self.models:
unwrapped_model = unwrap_model(
model_module, (torchDDP, LocalDDP, Float16Module))
for param in unwrapped_model.parameters():
if getattr(param, 'sequence_parallel', False):
grad = param.main_grad if args.DDP_impl == 'local' else param.grad
grads.append(grad.data)
coalesced = _flatten_dense_tensors(grads)
torch.distributed.all_reduce(
coalesced, group=mpu.get_tensor_model_parallel_group())
for buf, synced in zip(grads, _unflatten_dense_tensors(
coalesced, grads)):
buf.copy_(synced)
def reduce_model_grads(self, args, timers):
"""All-reduce all grads, and all-reduce embeddings."""
# All-reduce layer-norm grads (for sequence parallelism).
timers('backward-layernorm-all-reduce').start()
self.allreduce_layernorm_grads(args)
timers('backward-layernorm-all-reduce').stop()
# All-reduce if needed.
if args.DDP_impl == 'local':
timers('backward-params-all-reduce').start()
for model in self.models:
model.allreduce_gradients()
timers('backward-params-all-reduce').stop()
# All-reduce embedding grads.
timers('backward-embedding-all-reduce').start()
self.allreduce_embedding_grads(args)
timers('backward-embedding-all-reduce').stop()
class MixedPrecisionOptimizer(MegatronOptimizer):
"""Base class for both the float-16 and the distributed optimizer.
Arguments:
optimizer: base optimizer such as Adam or SGD
......@@ -184,27 +328,36 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
to do gradient accumulation and all-reduces in float32
and as a result we store those gradients in the main_grad.
Note that main grad is not necessarily in float32.
use_contiguous_buffers_in_local_ddp: if true, the local DDP model
is using a contiguous buffer to hold the model grads.
fp16: if true, the model is running in fp16.
bf16: if true, the model is running in bfloat16.
grad_scaler: used for scaling gradients. Note that this can be
None. This case happens when `bf16 = True` and we don't
use any loss scale. Note that for `bf16 = True`, we can have
a constnat gradient scaler. Also for `bf16 = False`, we
always require a grad scaler.
models: list of models (i.e., the virtual pipelining models). This
is used by the distributed optimizer for mapping parameters.
"""
def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
bf16, grad_scaler):
fp16, bf16, grad_scaler,
models):
super(Float16OptimizerWithFloat16Params, self).__init__(
super().__init__(
optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_local_ddp)
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
models)
self.fp16 = fp16
self.bf16 = bf16
self.grad_scaler = grad_scaler
# None grad scaler is only supported for bf16.
if self.grad_scaler is None:
assert self.bf16, 'fp16 expects a grad scaler.'
assert not self.fp16, 'fp16 expects a grad scaler.'
# Tensor used to determine if a nan/if has happend.
# Any non-zero value indicates inf/nan.
......@@ -225,6 +378,131 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
if self.grad_scaler is None:
self._scale_one = torch.cuda.FloatTensor([1.0])
def get_loss_scale(self):
if self.grad_scaler is None:
return self._scale_one
return self.grad_scaler.scale
def reload_model_params(self):
self._copy_model_params_to_main_params()
def _unscale_main_grads_and_check_for_nan(self):
# Collect main grads.
main_grads = self._collect_main_grad_data_for_unscaling()
# Reset found inf.
self.found_inf.fill_(0.0)
# Unscale and set found inf/nan
torch._amp_foreach_non_finite_check_and_unscale_(
main_grads, self.found_inf, self.grad_scaler.inv_scale)
# Update across all model parallel instances.
torch.distributed.all_reduce(self.found_inf,
op=torch.distributed.ReduceOp.MAX,
group=self.get_model_parallel_group())
# Check for nan.
found_inf_flag = (self.found_inf.item() > 0)
return found_inf_flag
@torch.no_grad()
def step(self, args, timers):
# Copy gradients from model params to main params.
timers('optimizer-copy-to-main-grad').start()
self._copy_model_grads_to_main_grads()
timers('optimizer-copy-to-main-grad').stop()
# Do unscale, check for inf, and update grad scaler only for
# the case that grad scaler is provided.
if self.grad_scaler:
# Unscale and check for inf/nan.
timers('optimizer-unscale-and-check-inf').start()
found_inf_flag = self._unscale_main_grads_and_check_for_nan()
timers('optimizer-unscale-and-check-inf').stop()
# We are done with scaling gradients
# so we can update the loss scale.
self.grad_scaler.update(found_inf_flag)
# If we found inf/nan, skip the update.
if found_inf_flag:
return False, None, None
# Clip the main gradients.
timers('optimizer-clip-main-grad').start()
grad_norm = None
if self.clip_grad > 0.0:
grad_norm = self.clip_grad_norm(self.clip_grad)
timers('optimizer-clip-main-grad').stop()
# Count the zeros in the grads.
timers('optimizer-count-zeros').start()
num_zeros_in_grad = self.count_zeros() if \
self.log_num_zeros_in_grad else None
timers('optimizer-count-zeros').stop()
# Step the optimizer.
timers('optimizer-inner-step').start()
self.optimizer.step()
timers('optimizer-inner-step').stop()
# Update params from main params.
timers('optimizer-copy-main-to-model-params').start()
self._copy_main_params_to_model_params()
timers('optimizer-copy-main-to-model-params').stop()
# Successful update.
return True, grad_norm, num_zeros_in_grad
class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
"""Float16 optimizer for fp16 and bf16 data types.
Arguments:
optimizer: base optimizer such as Adam or SGD
clip_grad: clip gradeints with this global L2 norm. Note
that clipping is ignored if clip_grad == 0
log_num_zeros_in_grad: return number of zeros in the gradients.
params_have_main_grad: flag indicating if parameters have
a `main_grad` field. If this is set, we are assuming
that the model parameters are store in the `main_grad`
field instead of the typical `grad` field. This happens
for the DDP cases where there is a continuous buffer
holding the gradients. For example for bfloat16, we want
to do gradient accumulation and all-reduces in float32
and as a result we store those gradients in the main_grad.
Note that main grad is not necessarily in float32.
use_contiguous_buffers_in_local_ddp: if true, the local DDP model
is using a contiguous buffer to hold the model grads.
fp16: if true, the model is running in fp16.
bf16: if true, the model is running in bfloat16.
grad_scaler: used for scaling gradients. Note that this can be
None. This case happens when `bf16 = True` and we don't
use any loss scale. Note that for `bf16 = True`, we can have
a constnat gradient scaler. Also for `bf16 = False`, we
always require a grad scaler.
models: list of models (i.e., the virtual pipelining models). This
is used by the distributed optimizer for mapping parameters.
"""
def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
fp16, bf16, grad_scaler, models):
super().__init__(
optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
fp16, bf16, grad_scaler, models)
# ======================
# main parameter stuff
# ======================
......@@ -259,6 +537,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
main_param.shared = param.shared
# Replace the optimizer params with the new fp32 copy.
param_group['params'][i] = main_param
fp32_from_float16_params_this_group.append(main_param)
# Reset existing state dict key to the new main param.
if param in self.optimizer.state:
......@@ -296,10 +575,34 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
_zero_grad_group_helper(group, set_to_none)
def get_loss_scale(self):
if self.grad_scaler is None:
return self._scale_one
return self.grad_scaler.scale
def _collect_main_grad_data_for_unscaling(self):
main_grads = []
# fp32 params from float16 ones.
for main_group in self.fp32_from_float16_groups:
for main_param in main_group:
if main_param.grad is not None:
main_grads.append(main_param.grad.data)
# Append fp32 parameters.
for main_group in self.fp32_from_fp32_groups:
for main_param in main_group:
if main_param.grad is not None:
main_grads.append(main_param.grad.data)
return main_grads
def _get_model_and_main_params_data_float16(self):
model_data = []
main_data = []
for model_group, main_group in zip(self.float16_groups,
self.fp32_from_float16_groups):
for model_param, main_param in zip(model_group, main_group):
model_data.append(model_param.data)
main_data.append(main_param.data)
return model_data, main_data
def _copy_model_grads_to_main_grads(self):
......@@ -333,43 +636,6 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
if not self.use_contiguous_buffers_in_local_ddp:
model_param.main_grad = None
def _unscale_main_grads_and_check_for_nan(self):
main_grads = []
# fp32 params fromm float16 ones.
for main_group in self.fp32_from_float16_groups:
for main_param in main_group:
if main_param.grad is not None:
main_grads.append(main_param.grad.data)
# Append fp32 parameters.
for main_group in self.fp32_from_fp32_groups:
for main_param in main_group:
if main_param.grad is not None:
main_grads.append(main_param.grad.data)
# Reset found inf.
self.found_inf.fill_(0.0)
# Unscale and set found inf/nan
torch._amp_foreach_non_finite_check_and_unscale_(
main_grads, self.found_inf, self.grad_scaler.inv_scale)
# Update across all model parallel instances.
torch.distributed.all_reduce(self.found_inf,
op=torch.distributed.ReduceOp.MAX,
group=mpu.get_model_parallel_group())
# Check for nan.
found_inf_flag = (self.found_inf.item() > 0)
return found_inf_flag
def _get_model_and_main_params_data_float16(self):
model_data = []
main_data = []
for model_group, main_group in zip(self.float16_groups,
self.fp32_from_float16_groups):
for model_param, main_param in zip(model_group, main_group):
model_data.append(model_param.data)
main_data.append(main_param.data)
return model_data, main_data
def _copy_main_params_to_model_params(self):
# Only needed for the float16 params.
......@@ -385,60 +651,6 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
overflow_buf=self._dummy_overflow_buf)
def reload_model_params(self):
self._copy_model_params_to_main_params()
@torch.no_grad()
def step(self):
timers = get_timers()
# Copy gradients from model params to main params.
timers('optimizer-copy-to-main-grad').start()
self._copy_model_grads_to_main_grads()
timers('optimizer-copy-to-main-grad').stop()
# Do unscale, check for inf, and update grad scaler only for
# the case that grad scaler is provided.
if self.grad_scaler:
# Unscale and check for inf/nan.
timers('optimizer-unscale-and-check-inf').start()
found_inf_flag = self._unscale_main_grads_and_check_for_nan()
timers('optimizer-unscale-and-check-inf').stop()
# We are done with scaling gradients
# so we can update the loss scale.
self.grad_scaler.update(found_inf_flag)
# If we found inf/nan, skip the update.
if found_inf_flag:
return False, None, None
# Clip the main gradients.
timers('optimizer-clip-main-grad').start()
grad_norm = None
if self.clip_grad > 0.0:
grad_norm = self.clip_grad_norm(self.clip_grad)
timers('optimizer-clip-main-grad').stop()
# count the zeros in the grads
num_zeros_in_grad = self.count_zeros() if \
self.log_num_zeros_in_grad else None
# Step the optimizer.
self.optimizer.step()
# Update params from main params.
timers('optimizer-copy-main-to-model-params').start()
self._copy_main_params_to_model_params()
timers('optimizer-copy-main-to-model-params').stop()
# Successful update.
return True, grad_norm, num_zeros_in_grad
def state_dict(self):
state_dict = {}
state_dict['optimizer'] = self.optimizer.state_dict()
......@@ -480,17 +692,18 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
current_param.data.copy_(saved_param.data)
class FP32Optimizer(MegatronOptimizer):
def __init__(self, optimizer, clip_grad,
log_num_zeros_in_grad,
params_have_main_grad,
use_contiguous_buffers_in_local_ddp):
use_contiguous_buffers_in_local_ddp,
models):
super(FP32Optimizer, self).__init__(
optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_local_ddp)
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
models)
self._scale = torch.cuda.FloatTensor([1.0])
......@@ -507,11 +720,12 @@ class FP32Optimizer(MegatronOptimizer):
@torch.no_grad()
def step(self):
def step(self, args, timers):
"""Clip gradients (if needed) and step the base optimizer.
Always return successful since there is no overflow."""
# Copy main_grads to grads.
timers('optimizer-copy-to-main-grad').start()
if self.params_have_main_grad:
for param_group in self.optimizer.param_groups:
for param in param_group['params']:
......@@ -522,18 +736,25 @@ class FP32Optimizer(MegatronOptimizer):
# persist and therefore should not be deallocated.)
if not self.use_contiguous_buffers_in_local_ddp:
param.main_grad = None
timers('optimizer-copy-to-main-grad').stop()
# Clip gradients.
timers('optimizer-clip-main-grad').start()
grad_norm = None
if self.clip_grad > 0.0:
grad_norm = self.clip_grad_norm(self.clip_grad)
timers('optimizer-clip-main-grad').stop()
# count the zeros in the grads
timers('optimizer-count-zeros').start()
num_zeros_in_grad = self.count_zeros() if \
self.log_num_zeros_in_grad else None
timers('optimizer-count-zeros').stop()
# Update parameters.
timers('optimizer-inner-step').start()
self.optimizer.step()
timers('optimizer-inner-step').stop()
# No overflow for FP32 optimizer.
return True, grad_norm, num_zeros_in_grad
......
......@@ -23,7 +23,6 @@ import time
_TRAIN_START_TIME = time.time()
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from megatron import get_args
from megatron import get_signal_handler
......@@ -365,12 +364,11 @@ def setup_model_and_optimizer(model_provider_func,
args = get_args()
model = get_model(model_provider_func, model_type)
unwrapped_model = unwrap_model(model,
(torchDDP, LocalDDP, Float16Module))
optimizer = get_megatron_optimizer(unwrapped_model, no_wd_decay_cond,
scale_lr_cond, lr_mult)
optimizer = get_megatron_optimizer(model, no_wd_decay_cond,
scale_lr_cond, lr_mult)
opt_param_scheduler = get_optimizer_param_scheduler(optimizer)
if args.load is not None:
......@@ -413,97 +411,44 @@ def train_step(forward_step_func, data_iterator,
partition.zero_grad_buffer()
optimizer.zero_grad()
# Forward pass.
forward_backward_func = get_forward_backward_func()
losses_reduced = forward_backward_func(
forward_step_func, data_iterator, model,
optimizer, timers, forward_only=False)
# Empty unused memory
# Empty unused memory.
if args.empty_unused_memory_level >= 1:
torch.cuda.empty_cache()
# All-reduce layernorm parameters across model parallel nodes
# when sequence parallelism is used
if mpu.get_tensor_model_parallel_world_size() > 1 and \
args.sequence_parallel:
grads = []
for model_module in model:
unwrapped_model = unwrap_model(
model_module, (torchDDP, LocalDDP, Float16Module))
for param in unwrapped_model.parameters():
if getattr(param, 'sequence_parallel', False):
grad = param.main_grad if args.DDP_impl == 'local' else param.grad
grads.append(grad.data)
coalesced = _flatten_dense_tensors(grads)
torch.distributed.all_reduce(
coalesced, group=mpu.get_tensor_model_parallel_group())
for buf, synced in zip(grads, _unflatten_dense_tensors(
coalesced, grads)):
buf.copy_(synced)
# All-reduce if needed.
if args.DDP_impl == 'local':
timers('backward-params-all-reduce').start()
for model_module in model:
model_module.allreduce_gradients()
timers('backward-params-all-reduce').stop()
# All-reduce word_embeddings' grad across first and last stages to ensure
# that word_embeddings parameters stay in sync.
# This should only run for models that support pipelined model parallelism
# (BERT and GPT-2).
timers('backward-embedding-all-reduce').start()
if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
mpu.get_pipeline_model_parallel_world_size() > 1:
if mpu.is_pipeline_first_stage(ignore_virtual=True):
unwrapped_model = model[0]
elif mpu.is_pipeline_last_stage(ignore_virtual=True):
unwrapped_model = model[-1]
else: # We do not support the interleaved schedule for T5 yet.
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
if unwrapped_model.share_word_embeddings:
word_embeddings_weight = unwrapped_model.word_embeddings_weight()
if args.DDP_impl == 'local':
grad = word_embeddings_weight.main_grad
else:
grad = word_embeddings_weight.grad
torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
# All-reduce position_embeddings grad across first (encoder) and split (decoder)
# stages to ensure that position embeddings parameters stay in sync.
# This should only run for T5 models with pipeline parallelism
if mpu.is_rank_in_position_embedding_group() and \
mpu.get_pipeline_model_parallel_world_size() > 1 and \
args.pipeline_model_parallel_split_rank is not None:
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
assert args.DDP_impl == 'local', \
'T5 model is only supported with local DDP mode'
grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
timers('backward-embedding-all-reduce').stop()
# Reduce gradients.
timers('backward-reduce-model-grads').start()
optimizer.reduce_model_grads(args, timers)
timers('backward-reduce-model-grads').stop()
# Vision gradients.
if args.vision_pretraining and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0],
(torchDDP, LocalDDP, Float16Module))
unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)
# Update parameters.
timers('optimizer').start()
update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers)
timers('optimizer').stop()
# Gather params.
if update_successful:
timers('backward-gather-model-params').start()
optimizer.gather_model_params(args, timers)
timers('backward-gather-model-params').stop()
# Vision momentum.
if args.vision_pretraining and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0],
(torchDDP, LocalDDP, Float16Module))
unwrapped_model.update_momentum(args.curr_iteration)
# Update learning rate.
if update_successful:
increment = get_num_microbatches() * \
......@@ -514,7 +459,7 @@ def train_step(forward_step_func, data_iterator,
else:
skipped_iter = 1
# Empty unused memory
# Empty unused memory.
if args.empty_unused_memory_level >= 2:
torch.cuda.empty_cache()
......@@ -581,10 +526,15 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
add_to_logging('backward-send-forward-recv')
add_to_logging('backward-send-backward-recv')
add_to_logging('backward-params-all-reduce')
add_to_logging('backward-layernorm-all-reduce')
add_to_logging('backward-embedding-all-reduce')
add_to_logging('backward-reduce-model-grads')
add_to_logging('backward-gather-model-params')
add_to_logging('optimizer-copy-to-main-grad')
add_to_logging('optimizer-unscale-and-check-inf')
add_to_logging('optimizer-clip-main-grad')
add_to_logging('optimizer-count-zeros')
add_to_logging('optimizer-inner-step')
add_to_logging('optimizer-copy-main-to-model-params')
add_to_logging('optimizer')
add_to_logging('batch-generator')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment