Commit b4bc51b1 authored by Mohammad Shoeybi, committed by Jared Casper

Bfloat with fp32 grad acc

parent 87b8b9dc
......@@ -129,11 +129,26 @@ def parse_args(extra_args_provider=None, defaults={},
# Parameters dtype.
args.params_dtype = torch.float
if args.fp16:
assert not args.bf16
args.params_dtype = torch.half
if args.bf16:
assert not args.fp16
args.params_dtype = torch.bfloat16
# No fusion is supported for bfloat16 for now
assert not args.masked_softmax_fusion
assert not args.bias_gelu_fusion
assert not args.bias_dropout_fusion
if args.rank == 0:
print('using {} for parameters ...'.format(args.params_dtype),
flush=True)
# If we do accumulation and all-reduces in fp32, we need local DDP
# and we should also set use-contiguous-buffers-in-ddp.
if args.accumulate_allreduce_grads_in_fp32:
assert args.DDP_impl == 'local'
args.use_contiguous_buffers_in_ddp = True
if args.dataloader_type is None:
args.dataloader_type = 'single'
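A small illustration (not part of the commit) of why gradient accumulation is done in fp32 when training in bf16: with only ~8 mantissa bits, repeatedly adding small gradient contributions into a bf16 accumulator loses precision, while an fp32 accumulator does not.

import torch

grad_chunk = torch.full((1,), 1e-3, dtype=torch.bfloat16)

acc_bf16 = torch.zeros(1, dtype=torch.bfloat16)
acc_fp32 = torch.zeros(1, dtype=torch.float32)

for _ in range(1000):
    acc_bf16 += grad_chunk           # accumulation in bf16 stalls as the sum grows
    acc_fp32 += grad_chunk.float()   # accumulation in fp32 stays close to 1.0

print('bf16 accumulator:', acc_bf16.item())   # noticeably below 1.0
print('fp32 accumulator:', acc_fp32.item())   # ~1.0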
......@@ -204,8 +219,8 @@ def parse_args(extra_args_provider=None, defaults={},
if args.fp16_lm_cross_entropy:
assert args.fp16, 'lm cross entropy in fp16 only supported in fp16 mode.'
if args.fp32_residual_connection:
assert args.fp16, \
'residual connection in fp32 only supported when using fp16.'
assert args.fp16 or args.bf16, \
'residual connection in fp32 only supported when using fp16 or bf16.'
# Activation checkpointing.
if args.distribute_checkpointed_activations:
assert args.checkpoint_activations, \
......@@ -528,6 +543,8 @@ def _add_mixed_precision_args(parser):
group.add_argument('--fp16', action='store_true',
help='Run model in fp16 mode.')
group.add_argument('--bf16', action='store_true',
help='Run model in bfloat16 mode.')
group.add_argument('--loss-scale', type=float, default=None,
help='Static loss scaling, positive power of 2 '
'values can improve fp16 convergence. If None, dynamic'
......@@ -549,8 +566,9 @@ def _add_mixed_precision_args(parser):
help='Run attention masking and softmax in fp32. '
'This flag is ignored unless '
'--no-query-key-layer-scaling is specified.')
group.add_argument('--fp32-allreduce', action='store_true',
help='All-reduce in fp32')
group.add_argument('--accumulate-allreduce-grads-in-fp32',
action='store_true',
help='Gradient accumulation and all-reduce in fp32.')
group.add_argument('--fp16-lm-cross-entropy', action='store_true',
help='Move the cross entropy unreduced loss calculation'
'for lm head to fp16.')
......@@ -577,6 +595,9 @@ def _add_distributed_args(parser):
choices=['local', 'torch'],
help='which DistributedDataParallel implementation '
'to use.')
group.add_argument('--use-contiguous-buffers-in-ddp', action='store_true',
help='If set, use contiguous buffers in DDP. Note that '
'this option only works with local DDP.')
group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
help='Use scatter/gather to optimize communication of tensors in pipeline',
dest='scatter_gather_tensors_in_pipeline')
......
......@@ -16,11 +16,13 @@
_LAYER_NORM = None
def import_layernorm(fp32_residual_connection):
def import_layernorm(fp32_residual_connection, bf16):
global _LAYER_NORM
if not _LAYER_NORM:
if fp32_residual_connection:
if bf16:
from torch.nn import LayerNorm
elif fp32_residual_connection:
from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
else:
from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
......@@ -39,6 +41,6 @@ from .gpt_model import (GPTModel,
GPTModelIntermediateStage,
GPTModelLastStage)
from .language_model import get_language_model
from .module import FP16Module
from .module import Float16Module
......@@ -78,7 +78,7 @@ class BertLMHead(MegatronModule):
self.parallel_output = parallel_output
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
LayerNorm = import_layernorm(args.fp32_residual_connection)
LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)
self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
self.gelu = torch.nn.functional.gelu
if args.openai_gelu:
......
......@@ -13,100 +13,206 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
from abc import abstractmethod
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
import torch.distributed as dist
from torch.nn.modules import Module
from torch.autograd import Variable
from megatron import get_args
from megatron import mpu
from .module import MegatronModule
class DistributedDataParallel(MegatronModule):
def __init__(self, module):
super(DistributedDataParallel, self).__init__()
self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
class MemoryBuffer:
def __init__(self, numel, dtype):
self.numel = numel
self.dtype = dtype
self.data = torch.zeros(self.numel,
dtype=self.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
def zero(self):
"""Reset the buffer to zero."""
self.data.zero_()
def get(self, shape, start_index):
"""Return a tensor with the input `shape` as a view into the
1-D data starting at `start_index`."""
end_index = start_index + shape.numel()
assert end_index <= self.numel, \
'requested tensor is out of the buffer range.'
buffer_tensor = self.data[start_index:end_index]
buffer_tensor = buffer_tensor.view(shape)
return buffer_tensor
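A minimal usage sketch of the MemoryBuffer above (assuming a CUDA device is available, since the buffer is allocated on the current GPU): the shaped views handed out by get() alias the same flat storage, which is how the contiguous-buffer DDP below assigns each parameter's main_grad.

import torch

buf = MemoryBuffer(numel=6, dtype=torch.float)
a = buf.get(torch.Size([2, 2]), start_index=0)   # view over elements 0..3
b = buf.get(torch.Size([2]), start_index=4)      # view over elements 4..5
a.fill_(1.0)
b.fill_(2.0)
print(buf.data)   # tensor([1., 1., 1., 1., 2., 2.], device='cuda:0')
buf.zero()        # resets every view in one call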
class DistributedDataParallelBase(MegatronModule, ABC):
"""Abstract class for DDP."""
def __init__(self, module):
super(DistributedDataParallelBase, self).__init__()
# Keep a pointer to the model.
self.module = module
self.data_parallel_group = mpu.get_data_parallel_group()
def allreduce_params(reduce_after=True, no_scale=False, fp32_allreduce=False):
if(self.needs_reduction):
self.needs_reduction = False
buckets = {}
for name, param in self.module.named_parameters():
if param.requires_grad and param.grad is not None:
tp = (param.data.type())
if tp not in buckets:
buckets[tp] = []
buckets[tp].append(param)
if self.warn_on_half:
if torch.cuda.HalfTensor in buckets:
print("WARNING: gloo dist backend for half parameters may be extremely slow." +
" It is recommended to use the NCCL backend in this case.")
self.warn_on_half = False
for tp in buckets:
bucket = buckets[tp]
grads = [param.grad.data for param in bucket]
coalesced = _flatten_dense_tensors(grads)
if fp32_allreduce:
coalesced = coalesced.float()
if not no_scale and not reduce_after:
coalesced /= dist.get_world_size(group=self.data_parallel_group)
dist.all_reduce(coalesced, group=self.data_parallel_group)
torch.cuda.synchronize()
if not no_scale and reduce_after:
coalesced /= dist.get_world_size(group=self.data_parallel_group)
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
self.hook_handles = []
self.hooks = []
for param in list(self.module.parameters()):
def allreduce_hook(*unused):
Variable._execution_engine.queue_callback(allreduce_params)
# handle = param.register_hook(allreduce_hook)
# self.hooks.append(allreduce_hook)
# self.hook_handles.append(handle)
self.allreduce_params = allreduce_params
@abstractmethod
def allreduce_gradients(self):
pass
def forward(self, *inputs, **kwargs):
self.needs_reduction = True
return self.module(*inputs, **kwargs)
def state_dict(self, destination=None, prefix='', keep_vars=False):
#[h.remove() for h in self.hook_handles]
sd = self.module.state_dict(destination, prefix, keep_vars)
# for handle, hook in zip(self.hook_handles, self.hooks):
# d = handle.hooks_dict_ref()
# d[handle.id] = hook
return self.module.state_dict(destination, prefix, keep_vars)
return sd
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
return self.module.state_dict_for_save_checkpoint(destination, prefix,
keep_vars)
def load_state_dict(self, state_dict, strict=True):
self.module.load_state_dict(state_dict, strict=strict)
'''
def _sync_buffers(self):
buffers = list(self.module._all_buffers())
if len(buffers) > 0:
# cross-node buffer sync
flat_buffers = _flatten_dense_tensors(buffers)
dist.broadcast(flat_buffers, 0)
for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)):
buf.copy_(synced)
def train(self, mode=True):
# Clear NCCL communicator and CUDA event cache of the default group ID,
# These cache will be recreated at the later call. This is currently a
# work-around for a potential NCCL deadlock.
if dist._backend == dist.dist_backend.NCCL:
dist._clear_group_cache()
super(DistributedDataParallel, self).train(mode)
self.module.train(mode)
'''
class DistributedDataParallel(DistributedDataParallelBase):
"""DDP with contiguous buffers options to storre and accumulate gradients.
This class:
- has the potential to reduce memory fragmentation.
- provides the option to do the gradient accumulation
in a type other than the params type (for example fp32)
Arguments:
module: input model.
accumulate_allreduce_grads_in_fp32: if true do the gradient accumulation
and the gradient all-reduce all in float32. If this option is
true, we require `use_contiguous_buffers` to be true too.
use_contiguous_buffers: if true, use a contiguous buffer to store the
gradients.
"""
def __init__(self, module,
accumulate_allreduce_grads_in_fp32,
use_contiguous_buffers):
super(DistributedDataParallel, self).__init__(module)
self.accumulate_allreduce_grads_in_fp32 \
= accumulate_allreduce_grads_in_fp32
self.use_contiguous_buffers = use_contiguous_buffers
# If we are using fp32-accumulate-allreduce explicitly
# this means we need main grads in a contiguous buffer.
if self.accumulate_allreduce_grads_in_fp32:
assert self.use_contiguous_buffers
# ===================================
# Rest of this part applies only to
# the case we use contiguous buffers.
# ===================================
self._grad_buffers = None
if self.use_contiguous_buffers:
self._grad_buffers = {}
# Simple function to define buffer type.
def _get_buffer_type(param):
return torch.float if \
self.accumulate_allreduce_grads_in_fp32 else param.dtype
# First calculate total number of elements per type.
type_num_elements = {}
for param in self.module.parameters():
if param.requires_grad:
dtype = _get_buffer_type(param)
type_num_elements[dtype] = type_num_elements.get(dtype, 0) \
+ param.data.nelement()
# Allocate the buffer.
for dtype, num_elements in type_num_elements.items():
self._grad_buffers[dtype] = MemoryBuffer(num_elements, dtype)
# Assume the back-prop order is the reverse of the params order,
# store the start index for the gradients.
for param in self.module.parameters():
if param.requires_grad:
dtype = _get_buffer_type(param)
type_num_elements[dtype] -= param.data.nelement()
param.main_grad = self._grad_buffers[dtype].get(
param.data.shape, type_num_elements[dtype])
# Backward hook.
# Accumulation function for the gradients. We need
# to store them so they don't go out of scope.
self.grad_accs = []
# Loop over all the parameters in the model.
for param in self.module.parameters():
if param.requires_grad:
# Expand so we get access to grad_fn.
param_tmp = param.expand_as(param)
# Get the gradient accumulator function.
grad_acc = param_tmp.grad_fn.next_functions[0][0]
grad_acc.register_hook(self._make_param_hook(param))
self.grad_accs.append(grad_acc)
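A standalone sketch (independent of Megatron) of the hook mechanism used in the constructor above: expanding a parameter exposes a grad_fn whose next_functions entry is the AccumulateGrad node, and a hook registered on that node fires once the gradient for that parameter is ready in each backward pass.

import torch

layer = torch.nn.Linear(4, 4)
grad_accs = []   # keep references so the hooks are not garbage collected

for param in layer.parameters():
    param_tmp = param.expand_as(param)                  # view with a grad_fn
    grad_acc = param_tmp.grad_fn.next_functions[0][0]   # AccumulateGrad node

    def hook(*unused, param=param):
        print('gradient ready for parameter of shape', tuple(param.shape))

    grad_acc.register_hook(hook)
    grad_accs.append(grad_acc)

layer(torch.randn(2, 4)).sum().backward()   # prints once per parameter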
def _make_param_hook(self, param):
"""Create the all-reduce hook for backprop."""
# Hook used for back-prop.
def param_hook(*unused):
# Add the gradient to the buffer.
if param.grad.data is not None:
param.main_grad.add_(param.grad.data)
# Now we can deallocate grad memory.
param.grad = None
return param_hook
def zero_grad_buffer(self):
"""Set the grad buffer data to zero. Needs to be called at the
beginning of each iteration."""
assert self._grad_buffers is not None, 'buffers are not initialized.'
for _, buffer_ in self._grad_buffers.items():
buffer_.zero()
def allreduce_gradients(self):
"""Reduce gradients across data parallel ranks."""
# If we have buffers, simply reduce the data in the buffer.
if self._grad_buffers is not None:
for _, buffer_ in self._grad_buffers.items():
buffer_.data /= mpu.get_data_parallel_world_size()
torch.distributed.all_reduce(
buffer_.data, group=mpu.get_data_parallel_group())
else:
# Otherwise, bucketize and all-reduce
buckets = {}
# Pack the buckets.
for param in self.module.parameters():
if param.requires_grad and param.grad is not None:
tp = param.data.type()
if tp not in buckets:
buckets[tp] = []
buckets[tp].append(param)
param.main_grad = param.grad
# For each bucket, all-reduce and copy all-reduced grads.
for tp in buckets:
bucket = buckets[tp]
grads = [param.grad.data for param in bucket]
coalesced = _flatten_dense_tensors(grads)
coalesced /= mpu.get_data_parallel_world_size()
torch.distributed.all_reduce(
coalesced, group=mpu.get_data_parallel_group())
for buf, synced in zip(grads, _unflatten_dense_tensors(
coalesced, grads)):
buf.copy_(synced)
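A schematic usage sketch of the wrapper above (not runnable standalone: it assumes an initialized data-parallel process group plus a `module`, a `batches` iterable, and an `optimizer`). With contiguous buffers, gradients accumulate into the fp32 main_grad views, the buffer is zeroed explicitly each iteration, and one all-reduce per dtype buffer replaces the per-bucket reduction.

model = DistributedDataParallel(module,
                                accumulate_allreduce_grads_in_fp32=True,
                                use_contiguous_buffers=True)

for batch in batches:
    model.zero_grad_buffer()       # instead of optimizer.zero_grad()
    loss = model(batch).sum()      # placeholder loss
    loss.backward()                # hooks add grads into param.main_grad
    model.allreduce_gradients()    # single all-reduce over the fp32 buffer
    optimizer.step()               # optimizer consumes param.main_grad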
......@@ -25,6 +25,7 @@ from megatron import mpu
_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor)
......@@ -109,6 +110,7 @@ class MegatronModule(torch.nn.Module):
"this needs to be handled manually. If you are training "
"something is definitely wrong.")
def conversion_helper(val, conversion):
"""Apply conversion to val. Recursively apply conversion if `val`
is a nested tuple/list structure."""
......@@ -120,44 +122,56 @@ def conversion_helper(val, conversion):
return rtn
def fp32_to_fp16(val):
"""Convert fp32 `val` to fp16"""
def fp32_to_float16(val, float16_convertor):
"""Convert fp32 `val` to fp16/bf16"""
def half_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data
if isinstance(val_typecheck, _FLOAT_TYPES):
val = val.half()
val = float16_convertor(val)
return val
return conversion_helper(val, half_conversion)
def fp16_to_fp32(val):
"""Convert fp16 `val` to fp32"""
def float16_to_fp32(val):
"""Convert fp16/bf16 `val` to fp32"""
def float_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data
if isinstance(val_typecheck, _HALF_TYPES):
if isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)):
val = val.float()
return val
return conversion_helper(val, float_conversion)
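A small example of the conversion helpers above (assuming they are in scope; CPU tensors are used for brevity): nested tuples/lists are converted element by element using the convertor chosen at wrap time.

import torch

def to_half(val):   # stands in for the float16_convertor picked in Float16Module
    return val.half()

inputs = (torch.randn(2), [torch.randn(3), torch.randn(1)])
converted = fp32_to_float16(inputs, to_half)
restored = float16_to_fp32(converted)
print(converted[0].dtype)   # torch.float16
print(restored[0].dtype)    # torch.float32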
class FP16Module(MegatronModule):
class Float16Module(MegatronModule):
def __init__(self, module, args):
super(Float16Module, self).__init__()
if args.fp16:
self.add_module('module', module.half())
def float16_convertor(val):
return val.half()
elif args.bf16:
self.add_module('module', module.bfloat16())
def float16_convertor(val):
return val.bfloat16()
else:
raise Exception('should not be here')
def __init__(self, module):
super(FP16Module, self).__init__()
self.add_module('module', module.half())
self.float16_convertor = float16_convertor
def forward(self, *inputs, **kwargs):
if mpu.is_pipeline_first_stage():
inputs = fp32_to_fp16(inputs)
inputs = fp32_to_float16(inputs, self.float16_convertor)
outputs = self.module(*inputs, **kwargs)
if mpu.is_pipeline_last_stage():
outputs = fp16_to_fp32(outputs)
outputs = float16_to_fp32(outputs)
return outputs
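A hypothetical construction example for Float16Module (parameter casting only; calling forward() additionally needs an initialized Megatron mpu for the pipeline-stage checks, and `args` here is a stand-in namespace rather than the real parsed arguments):

from types import SimpleNamespace
import torch

args = SimpleNamespace(fp16=False, bf16=True)
wrapped = Float16Module(torch.nn.Linear(8, 8), args)
print(next(wrapped.parameters()).dtype)   # torch.bfloat16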
......
......@@ -397,8 +397,11 @@ class ParallelTransformerLayer(MegatronModule):
self.apply_residual_connection_post_layernorm \
= args.apply_residual_connection_post_layernorm
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
# Layernorm on the input data.
LayerNorm = import_layernorm(args.fp32_residual_connection)
LayerNorm = import_layernorm(self.fp32_residual_connection, self.bf16)
self.input_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
......@@ -440,6 +443,8 @@ class ParallelTransformerLayer(MegatronModule):
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
if self.bf16 and self.fp32_residual_connection:
layernorm_output = layernorm_output.bfloat16()
# Self attention.
attention_output, attention_bias = \
self.self_attention(layernorm_output,
......@@ -478,6 +483,8 @@ class ParallelTransformerLayer(MegatronModule):
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
if self.bf16 and self.fp32_residual_connection:
layernorm_output = layernorm_output.bfloat16()
if self.layer_type == LayerType.decoder:
attention_output, attention_bias = \
......@@ -500,6 +507,8 @@ class ParallelTransformerLayer(MegatronModule):
# Layer norm post the decoder attention
layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
if self.bf16 and self.fp32_residual_connection:
layernorm_output = layernorm_output.bfloat16()
# MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output)
......@@ -533,6 +542,7 @@ class ParallelTransformer(MegatronModule):
super(ParallelTransformer, self).__init__()
args = get_args()
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
# Store activation checkpointing flag.
......@@ -578,7 +588,8 @@ class ParallelTransformer(MegatronModule):
if mpu.is_pipeline_last_stage():
# Final layer norm before output.
LayerNorm = import_layernorm(args.fp32_residual_connection)
LayerNorm = import_layernorm(self.fp32_residual_connection,
self.bf16)
self.final_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
......@@ -665,6 +676,8 @@ class ParallelTransformer(MegatronModule):
# Reverting data format change [s b h] --> [b s h].
hidden_states = hidden_states.transpose(0, 1).contiguous()
output = self.final_layernorm(hidden_states)
if self.bf16 and self.fp32_residual_connection:
output = output.bfloat16()
else:
output = hidden_states
if get_key_value:
......
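A standalone sketch of the bf16 plus fp32-residual-connection pattern added above (a Linear stands in for the attention/MLP blocks): the residual stream stays in fp32, torch.nn.LayerNorm runs with fp32 parameters and fp32 input, and only its output is cast back to bf16 for the bf16 compute.

import torch

hidden = torch.randn(4, 2, 16, dtype=torch.float32)   # fp32 residual stream
layernorm = torch.nn.LayerNorm(16).float()            # layernorm params kept in fp32
block = torch.nn.Linear(16, 16).bfloat16()            # stand-in for attention/MLP

ln_out = layernorm(hidden).bfloat16()   # fp32 in, cast output back to bf16
out = block(ln_out)                     # bf16 compute
residual = hidden + out.float()         # residual accumulated in fp32
print(out.dtype, residual.dtype)        # torch.bfloat16 torch.float32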
......@@ -20,7 +20,7 @@ from megatron import get_args
from megatron.model import import_layernorm
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import FP16OptimizerWithFP16Params, FP32Optimizer
from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
def _get_params_for_weight_decay_optimization(modules):
......@@ -28,7 +28,7 @@ def _get_params_for_weight_decay_optimization(modules):
Layernorms and biases will have no weight decay but the rest will.
"""
args = get_args()
LayerNorm = import_layernorm(args.fp32_residual_connection)
LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)
weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
......@@ -67,24 +67,45 @@ def get_megatron_optimizer(model):
momentum=args.sgd_momentum)
else:
raise Exception('{} optimizer is not supported.'.format(
args.optimizer))
args.optimizer))
if args.fp16:
# Determine whether the params have a main-grad field.
params_have_main_grad = False
if args.DDP_impl == 'local':
params_have_main_grad = True
if args.fp16 or args.bf16:
# Grad scaler:
# if loss-scale is provided, instantiate the constant scaler.
# if we are using fp16 and loss-scale is not present, use a
# dynamic scaler.
# otherwise we are running in bf16 with no loss-scale so
# leave it as None.
grad_scaler = None
# Constant loss scale.
if args.loss_scale:
grad_scaler = ConstantGradScaler(args.loss_scale)
# Dynamic loss scale.
else:
grad_scaler = DynamicGradScaler(
initial_scale=args.initial_loss_scale,
min_scale=args.min_loss_scale,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=args.loss_scale_window,
hysteresis=args.hysteresis)
if args.fp16:
grad_scaler = DynamicGradScaler(
initial_scale=args.initial_loss_scale,
min_scale=args.min_loss_scale,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=args.loss_scale_window,
hysteresis=args.hysteresis)
# Megatron optimizer.
return FP16OptimizerWithFP16Params(optimizer, grad_scaler,
args.clip_grad, args.log_num_zeros_in_grad)
return Float16OptimizerWithFloat16Params(optimizer,
args.clip_grad,
args.log_num_zeros_in_grad,
params_have_main_grad,
args.bf16,
grad_scaler)
# FP32.
return FP32Optimizer(optimizer, args.clip_grad, args.log_num_zeros_in_grad)
return FP32Optimizer(optimizer, args.clip_grad,
args.log_num_zeros_in_grad,
params_have_main_grad)
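The Float16OptimizerWithFloat16Params changes themselves are collapsed below, so the following is only a hypothetical, heavily simplified sketch of what `params_have_main_grad` enables: when local DDP maintains fp32 main_grad buffers, the optimizer can copy those directly into its fp32 main parameters instead of upcasting bf16/fp16 `.grad` tensors.

def copy_model_grads_to_main_grads(model_params, main_params,
                                   params_have_main_grad):
    """Hypothetical helper: model_params are the bf16/fp16 parameters,
    main_params their fp32 copies kept by the optimizer."""
    for model_param, main_param in zip(model_params, main_params):
        if params_have_main_grad:
            main_param.grad = model_param.main_grad.float()
        else:
            main_param.grad = model_param.grad.float()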
This diff is collapsed.
......@@ -37,9 +37,8 @@ from megatron import print_rank_0
from megatron import print_rank_last
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.model import FP16Module
from megatron.model import Float16Module
from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
from megatron.learning_rates import AnnealingLR
......@@ -54,6 +53,7 @@ from megatron.schedules import forward_backward_pipelining_with_interleaving
from megatron.utils import report_memory
def print_datetime(string):
"""Note that this call will sync across all ranks."""
torch.distributed.barrier()
......@@ -222,8 +222,18 @@ def get_model(model_provider_func):
model_module.cuda(torch.cuda.current_device())
# Fp16 conversion.
if args.fp16:
model = [FP16Module(model_module) for model_module in model]
if args.fp16 or args.bf16:
model = [Float16Module(model_module, args) for model_module in model]
# For now, the layer norm does not support fp32 input with bf16 output.
# To handle this, we move the layernorm parameters to fp32 and cast the
# output of the layernorm operation back to bf16.
if args.bf16 and args.fp32_residual_connection:
from megatron.model import import_layernorm
LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)
for model_ in model:
for module_ in model_.modules():
if isinstance(module_, LayerNorm):
module_.float()
if args.DDP_impl == 'torch':
i = torch.cuda.current_device()
......@@ -231,8 +241,12 @@ def get_model(model_provider_func):
process_group=mpu.get_data_parallel_group())
for model_module in model]
return model
if args.DDP_impl == 'local':
model = [LocalDDP(model_module) for model_module in model]
model = [LocalDDP(model_module,
args.accumulate_allreduce_grads_in_fp32,
args.use_contiguous_buffers_in_ddp)
for model_module in model]
return model
raise NotImplementedError('Unknown DDP implementation specified: {}. '
......@@ -289,7 +303,7 @@ def setup_model_and_optimizer(model_provider_func):
model = get_model(model_provider_func)
unwrapped_model = unwrap_model(model,
(torchDDP, LocalDDP, FP16Module))
(torchDDP, LocalDDP, Float16Module))
optimizer = get_megatron_optimizer(unwrapped_model)
lr_scheduler = get_learning_rate_scheduler(optimizer)
......@@ -308,9 +322,7 @@ def setup_model_and_optimizer(model_provider_func):
args.iteration = 0
# We only support local DDP with multiple micro-batches.
if len(model) > 1:
assert args.DDP_impl == 'local'
if mpu.get_pipeline_model_parallel_world_size() > 1:
if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
assert args.DDP_impl == 'local'
# get model without FP16 and/or TorchDDP wrappers
......@@ -331,7 +343,11 @@ def train_step(forward_step_func, data_iterator,
timers = get_timers()
# Set grad to zero.
optimizer.zero_grad()
if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_ddp:
for partition in model:
partition.zero_grad_buffer()
else:
optimizer.zero_grad()
if mpu.get_pipeline_model_parallel_world_size() > 1:
if args.virtual_pipeline_model_parallel_size is not None:
......@@ -351,8 +367,7 @@ def train_step(forward_step_func, data_iterator,
if args.DDP_impl == 'local':
timers('backward-params-all-reduce').start()
for model_module in model:
model_module.allreduce_params(reduce_after=False,
fp32_allreduce=args.fp32_allreduce)
model_module.allreduce_gradients()
timers('backward-params-all-reduce').stop()
# All-reduce word_embeddings' grad across first and last stages to ensure
......@@ -368,12 +383,15 @@ def train_step(forward_step_func, data_iterator,
elif mpu.is_pipeline_last_stage(ignore_virtual=True):
unwrapped_model = model[-1]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, FP16Module))
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
if unwrapped_model.share_word_embeddings:
word_embeddings_weight = unwrapped_model.word_embeddings_weight()
torch.distributed.all_reduce(word_embeddings_weight.grad,
group=mpu.get_embedding_group())
if args.DDP_impl == 'local':
grad = word_embeddings_weight.main_grad
else:
grad = word_embeddings_weight.grad
torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
timers('backward-embedding-all-reduce').stop()
# Update parameters.
......
......@@ -48,13 +48,20 @@ def unwrap_model(model, module_instances=(torchDDP)):
def calc_params_l2_norm(model):
"""Calculate l2 norm of parameters """
args = get_args()
if not isinstance(model, list):
model = [model]
# Remove duplicate params.
params_data = []
for param in model.parameters():
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
if is_not_shared and is_not_tp_duplicate:
params_data.append(param.data)
for model_ in model:
for param in model_.parameters():
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
if is_not_shared and is_not_tp_duplicate:
if args.bf16:
params_data.append(param.data.float())
else:
params_data.append(param.data)
# Calculate norm
dummy_overflow_buf = torch.cuda.IntTensor([0])
norm, _ = multi_tensor_applier(
......
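A runnable sketch of the per-rank part of this norm computation, using plain PyTorch in place of the (truncated) multi_tensor_applier call: bf16 parameters are upcast to fp32 so the squared sum is accumulated at full precision (the subsequent model-parallel reduction is omitted).

import torch

params_data = [torch.randn(10).bfloat16().float(),   # bf16 params upcast to fp32
               torch.randn(5)]                        # fp32 params used as-is
norm = torch.sqrt(sum((p * p).sum() for p in params_data))
print(norm)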