Commit 1a2cb60c authored by Jared Casper

Merge branch 'bfloat_with_fp32_grad_acc' into 'main'

Bfloat with fp32 grad acc

See merge request ADLR/megatron-lm!247
parents 87b8b9dc b4bc51b1
......@@ -129,11 +129,26 @@ def parse_args(extra_args_provider=None, defaults={},
# Parameters dtype.
args.params_dtype = torch.float
if args.fp16:
assert not args.bf16
args.params_dtype = torch.half
if args.bf16:
assert not args.fp16
args.params_dtype = torch.bfloat16
# No fusion is supported for bfloat16 for now
assert not args.masked_softmax_fusion
assert not args.bias_gelu_fusion
assert not args.bias_dropout_fusion
if args.rank == 0:
print('using {} for parameters ...'.format(args.params_dtype),
flush=True)
# If we do gradient accumulation and all-reduces in fp32, we need
# local DDP and we should enable use-contiguous-buffers-in-ddp.
if args.accumulate_allreduce_grads_in_fp32:
assert args.DDP_impl == 'local'
args.use_contiguous_buffers_in_ddp = True
if args.dataloader_type is None:
args.dataloader_type = 'single'
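For reference, a minimal standalone sketch of the dtype-selection logic above (a hypothetical helper, not part of this diff):

import torch

def select_params_dtype(fp16: bool, bf16: bool) -> torch.dtype:
    # fp16 and bf16 are mutually exclusive, mirroring the asserts above.
    assert not (fp16 and bf16)
    if fp16:
        return torch.half
    if bf16:
        return torch.bfloat16
    return torch.float

assert select_params_dtype(fp16=False, bf16=True) is torch.bfloat16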
......@@ -204,8 +219,8 @@ def parse_args(extra_args_provider=None, defaults={},
if args.fp16_lm_cross_entropy:
assert args.fp16, 'lm cross entropy in fp16 is only supported in fp16 mode.'
if args.fp32_residual_connection:
assert args.fp16, \
'residual connection in fp32 only supported when using fp16.'
assert args.fp16 or args.bf16, \
'residual connection in fp32 only supported when using fp16 or bf16.'
# Activation checkpointing.
if args.distribute_checkpointed_activations:
assert args.checkpoint_activations, \
......@@ -528,6 +543,8 @@ def _add_mixed_precision_args(parser):
group.add_argument('--fp16', action='store_true',
help='Run model in fp16 mode.')
group.add_argument('--bf16', action='store_true',
help='Run model in bfloat16 mode.')
group.add_argument('--loss-scale', type=float, default=None,
help='Static loss scaling, positive power of 2 '
'values can improve fp16 convergence. If None, dynamic'
......@@ -549,8 +566,9 @@ def _add_mixed_precision_args(parser):
help='Run attention masking and softmax in fp32. '
'This flag is ignored unless '
'--no-query-key-layer-scaling is specified.')
group.add_argument('--fp32-allreduce', action='store_true',
help='All-reduce in fp32')
group.add_argument('--accumulate-allreduce-grads-in-fp32',
action='store_true',
help='Gradient accumulation and all-reduce in fp32.')
group.add_argument('--fp16-lm-cross-entropy', action='store_true',
help='Move the cross entropy unreduced loss calculation'
'for lm head to fp16.')
......@@ -577,6 +595,9 @@ def _add_distributed_args(parser):
choices=['local', 'torch'],
help='which DistributedDataParallel implementation '
'to use.')
group.add_argument('--use-contiguous-buffers-in-ddp', action='store_true',
help='If set, use a contiguous buffer in DDP. Note that '
'this option only works with local DDP.')
group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
help='Use scatter/gather to optimize communication of tensors in pipeline',
dest='scatter_gather_tensors_in_pipeline')
......
......@@ -16,11 +16,13 @@
_LAYER_NORM = None
def import_layernorm(fp32_residual_connection):
def import_layernorm(fp32_residual_connection, bf16):
global _LAYER_NORM
if not _LAYER_NORM:
if fp32_residual_connection:
if bf16:
from torch.nn import LayerNorm
elif fp32_residual_connection:
from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
else:
from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
......@@ -39,6 +41,6 @@ from .gpt_model import (GPTModel,
GPTModelIntermediateStage,
GPTModelLastStage)
from .language_model import get_language_model
from .module import FP16Module
from .module import Float16Module
......@@ -78,7 +78,7 @@ class BertLMHead(MegatronModule):
self.parallel_output = parallel_output
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
LayerNorm = import_layernorm(args.fp32_residual_connection)
LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)
self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
self.gelu = torch.nn.functional.gelu
if args.openai_gelu:
......
......@@ -13,100 +13,206 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
from abc import abstractmethod
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
import torch.distributed as dist
from torch.nn.modules import Module
from torch.autograd import Variable
from megatron import get_args
from megatron import mpu
from .module import MegatronModule
class DistributedDataParallel(MegatronModule):
def __init__(self, module):
super(DistributedDataParallel, self).__init__()
self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
class MemoryBuffer:
def __init__(self, numel, dtype):
self.numel = numel
self.dtype = dtype
self.data = torch.zeros(self.numel,
dtype=self.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
def zero(self):
"""Reset the buffer to zero."""
self.data.zero_()
def get(self, shape, start_index):
"""Return a tensor with the input `shape` as a view into the
1-D data starting at `start_index`."""
end_index = start_index + shape.numel()
assert end_index <= self.numel, \
'requested tensor is out of the buffer range.'
buffer_tensor = self.data[start_index:end_index]
buffer_tensor = buffer_tensor.view(shape)
return buffer_tensor
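For illustration, a hypothetical standalone sketch (CPU-only, not part of this diff) of the view-into-flat-storage pattern that MemoryBuffer.get implements; writes through a view land in the shared flat buffer:

import torch

flat = torch.zeros(12, dtype=torch.float, requires_grad=False)

def get_view(shape, start_index):
    # Same logic as MemoryBuffer.get: slice the flat storage, then view.
    end_index = start_index + shape.numel()
    assert end_index <= flat.numel(), \
        'requested tensor is out of the buffer range.'
    return flat[start_index:end_index].view(shape)

a = get_view(torch.Size([2, 3]), start_index=0)  # elements 0..5
b = get_view(torch.Size([6]), start_index=6)     # elements 6..11
a.fill_(1.0)                                     # writes through to `flat`
assert flat[:6].sum().item() == 6.0 and b.sum().item() == 0.0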
class DistributedDataParallelBase(MegatronModule, ABC):
"""Abstract class for DDP."""
def __init__(self, module):
super(DistributedDataParallelBase, self).__init__()
# Keep a pointer to the model.
self.module = module
self.data_parallel_group = mpu.get_data_parallel_group()
def allreduce_params(reduce_after=True, no_scale=False, fp32_allreduce=False):
if(self.needs_reduction):
self.needs_reduction = False
buckets = {}
for name, param in self.module.named_parameters():
if param.requires_grad and param.grad is not None:
tp = (param.data.type())
if tp not in buckets:
buckets[tp] = []
buckets[tp].append(param)
if self.warn_on_half:
if torch.cuda.HalfTensor in buckets:
print("WARNING: gloo dist backend for half parameters may be extremely slow." +
" It is recommended to use the NCCL backend in this case.")
self.warn_on_half = False
for tp in buckets:
bucket = buckets[tp]
grads = [param.grad.data for param in bucket]
coalesced = _flatten_dense_tensors(grads)
if fp32_allreduce:
coalesced = coalesced.float()
if not no_scale and not reduce_after:
coalesced /= dist.get_world_size(group=self.data_parallel_group)
dist.all_reduce(coalesced, group=self.data_parallel_group)
torch.cuda.synchronize()
if not no_scale and reduce_after:
coalesced /= dist.get_world_size(group=self.data_parallel_group)
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
self.hook_handles = []
self.hooks = []
for param in list(self.module.parameters()):
def allreduce_hook(*unused):
Variable._execution_engine.queue_callback(allreduce_params)
# handle = param.register_hook(allreduce_hook)
# self.hooks.append(allreduce_hook)
# self.hook_handles.append(handle)
self.allreduce_params = allreduce_params
@abstractmethod
def allreduce_gradients(self):
pass
def forward(self, *inputs, **kwargs):
self.needs_reduction = True
return self.module(*inputs, **kwargs)
def state_dict(self, destination=None, prefix='', keep_vars=False):
#[h.remove() for h in self.hook_handles]
sd = self.module.state_dict(destination, prefix, keep_vars)
# for handle, hook in zip(self.hook_handles, self.hooks):
# d = handle.hooks_dict_ref()
# d[handle.id] = hook
return self.module.state_dict(destination, prefix, keep_vars)
return sd
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
return self.module.state_dict_for_save_checkpoint(destination, prefix,
keep_vars)
def load_state_dict(self, state_dict, strict=True):
self.module.load_state_dict(state_dict, strict=strict)
'''
def _sync_buffers(self):
buffers = list(self.module._all_buffers())
if len(buffers) > 0:
# cross-node buffer sync
flat_buffers = _flatten_dense_tensors(buffers)
dist.broadcast(flat_buffers, 0)
for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)):
buf.copy_(synced)
def train(self, mode=True):
# Clear NCCL communicator and CUDA event cache of the default group ID,
# These cache will be recreated at the later call. This is currently a
# work-around for a potential NCCL deadlock.
if dist._backend == dist.dist_backend.NCCL:
dist._clear_group_cache()
super(DistributedDataParallel, self).train(mode)
self.module.train(mode)
'''
class DistributedDataParallel(DistributedDataParallelBase):
"""DDP with contiguous buffers options to storre and accumulate gradients.
This class:
- has the potential to reduce memory fragmentation.
- provides the option to do the gradient accumulation
in a type other than the params type (for example fp32)
Arguments:
module: input model.
accumulate_allreduce_grads_in_fp32: if true do the gradient accumulation
and the gradient all-reduce all in float32. If this option is
true, we require `use_contiguous_buffers` to be true too.
use_contiguous_buffers: if true, use a contiguous buffer to store the
gradients.
"""
def __init__(self, module,
accumulate_allreduce_grads_in_fp32,
use_contiguous_buffers):
super(DistributedDataParallel, self).__init__(module)
self.accumulate_allreduce_grads_in_fp32 \
= accumulate_allreduce_grads_in_fp32
self.use_contiguous_buffers = use_contiguous_buffers
# If we are using fp32-accumulate-allreduce explicitly,
# we need the main grads in a contiguous buffer.
if self.accumulate_allreduce_grads_in_fp32:
assert self.use_contiguous_buffers
# ===================================
# Rest of this part applies only to
# the case we use contiguous buffers.
# ===================================
self._grad_buffers = None
if self.use_contiguous_buffers:
self._grad_buffers = {}
# Simple function to define buffer type.
def _get_buffer_type(param):
return torch.float if \
self.accumulate_allreduce_grads_in_fp32 else param.dtype
# First calculate total number of elements per type.
type_num_elements = {}
for param in self.module.parameters():
if param.requires_grad:
dtype = _get_buffer_type(param)
type_num_elements[dtype] = type_num_elements.get(dtype, 0) \
+ param.data.nelement()
# Allocate the buffer.
for dtype, num_elements in type_num_elements.items():
self._grad_buffers[dtype] = MemoryBuffer(num_elements, dtype)
# Assume the back prop order is the reverse of the params order,
# store the start index for the gradients.
for param in self.module.parameters():
if param.requires_grad:
dtype = _get_buffer_type(param)
type_num_elements[dtype] -= param.data.nelement()
param.main_grad = self._grad_buffers[dtype].get(
param.data.shape, type_num_elements[dtype])
# Backward hook.
# Accumulation function for the gradients. We need
# to store them so they don't go out of scope.
self.grad_accs = []
# Loop over all the parameters in the model.
for param in self.module.parameters():
if param.requires_grad:
# Expand so we get access to grad_fn.
param_tmp = param.expand_as(param)
# Get the gradient accumulator function.
grad_acc = param_tmp.grad_fn.next_functions[0][0]
grad_acc.register_hook(self._make_param_hook(param))
self.grad_accs.append(grad_acc)
def _make_param_hook(self, param):
"""Create the all-reduce hook for backprop."""
# Hook used for back-prop.
def param_hook(*unused):
# Add the gradient to the buffer.
if param.grad is not None:
param.main_grad.add_(param.grad.data)
# Now we can deallocate grad memory.
param.grad = None
return param_hook
def zero_grad_buffer(self):
"""Set the grad buffer data to zero. Needs to be called at the
beginning of each iteration."""
assert self._grad_buffers is not None, 'buffers are not initialized.'
for _, buffer_ in self._grad_buffers.items():
buffer_.zero()
def allreduce_gradients(self):
"""Reduce gradients across data parallel ranks."""
# If we have buffers, simply reduce the data in the buffer.
if self._grad_buffers is not None:
for _, buffer_ in self._grad_buffers.items():
buffer_.data /= mpu.get_data_parallel_world_size()
torch.distributed.all_reduce(
buffer_.data, group=mpu.get_data_parallel_group())
else:
# Otherwise, bucketize and all-reduce
buckets = {}
# Pack the buckets.
for param in self.module.parameters():
if param.requires_grad and param.grad is not None:
tp = param.data.type()
if tp not in buckets:
buckets[tp] = []
buckets[tp].append(param)
param.main_grad = param.grad
# For each bucket, all-reduce and copy all-reduced grads.
for tp in buckets:
bucket = buckets[tp]
grads = [param.grad.data for param in bucket]
coalesced = _flatten_dense_tensors(grads)
coalesced /= mpu.get_data_parallel_world_size()
torch.distributed.all_reduce(
coalesced, group=mpu.get_data_parallel_group())
for buf, synced in zip(grads, _unflatten_dense_tensors(
coalesced, grads)):
buf.copy_(synced)
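For illustration, a hypothetical single-process sketch of the main_grad hook mechanism used above (no process group required): a hook on each parameter's gradient accumulator folds the fresh gradient into a persistent fp32 main_grad and frees param.grad right away:

import torch

layer = torch.nn.Linear(4, 4)
grad_accs = []
for param in layer.parameters():
    param.main_grad = torch.zeros_like(param, dtype=torch.float)
    # Expand so we get access to grad_fn, then grab the accumulator node.
    param_tmp = param.expand_as(param)
    grad_acc = param_tmp.grad_fn.next_functions[0][0]

    def make_hook(param):
        def hook(*unused):
            param.main_grad.add_(param.grad.data.float())  # accumulate in fp32
            param.grad = None  # deallocate the model-dtype grad immediately
        return hook

    grad_acc.register_hook(make_hook(param))
    grad_accs.append(grad_acc)  # keep references so the hooks stay alive

layer(torch.randn(2, 4)).sum().backward()
for param in layer.parameters():
    assert param.grad is None and param.main_grad.abs().sum().item() > 0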
......@@ -25,6 +25,7 @@ from megatron import mpu
_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor)
......@@ -109,6 +110,7 @@ class MegatronModule(torch.nn.Module):
"this needs to be handled manually. If you are training "
"something is definitely wrong.")
def conversion_helper(val, conversion):
"""Apply conversion to val. Recursively apply conversion if `val`
is a nested tuple/list structure."""
......@@ -120,44 +122,56 @@ def conversion_helper(val, conversion):
return rtn
def fp32_to_fp16(val):
"""Convert fp32 `val` to fp16"""
def fp32_to_float16(val, float16_convertor):
"""Convert fp32 `val` to fp16/bf16"""
def half_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data
if isinstance(val_typecheck, _FLOAT_TYPES):
val = val.half()
val = float16_convertor(val)
return val
return conversion_helper(val, half_conversion)
def fp16_to_fp32(val):
"""Convert fp16 `val` to fp32"""
def float16_to_fp32(val):
"""Convert fp16/bf16 `val` to fp32"""
def float_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data
if isinstance(val_typecheck, _HALF_TYPES):
if isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)):
val = val.float()
return val
return conversion_helper(val, float_conversion)
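For illustration, a hypothetical standalone sketch of the recursion that conversion_helper performs, which lets the converters above handle arbitrarily nested tuple/list pipeline inputs:

import torch

def convert(val, fn):
    # Mirrors conversion_helper: recurse into tuples/lists, convert leaves.
    if isinstance(val, (tuple, list)):
        return type(val)(convert(v, fn) for v in val)
    return fn(val)

def to_bf16(val):
    # Only fp32 tensors are converted, mirroring the _FLOAT_TYPES check.
    if torch.is_tensor(val) and val.dtype == torch.float32:
        return val.bfloat16()
    return val

nested = (torch.ones(2), [torch.zeros(3), 7])
out = convert(nested, to_bf16)
assert out[0].dtype == torch.bfloat16
assert out[1][1] == 7  # non-tensor leaves pass through unchanged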
class FP16Module(MegatronModule):
class Float16Module(MegatronModule):
def __init__(self, module, args):
super(Float16Module, self).__init__()
if args.fp16:
self.add_module('module', module.half())
def float16_convertor(val):
return val.half()
elif args.bf16:
self.add_module('module', module.bfloat16())
def float16_convertor(val):
return val.bfloat16()
else:
raise Exception('should not be here')
def __init__(self, module):
super(FP16Module, self).__init__()
self.add_module('module', module.half())
self.float16_convertor = float16_convertor
def forward(self, *inputs, **kwargs):
if mpu.is_pipeline_first_stage():
inputs = fp32_to_fp16(inputs)
inputs = fp32_to_float16(inputs, self.float16_convertor)
outputs = self.module(*inputs, **kwargs)
if mpu.is_pipeline_last_stage():
outputs = fp16_to_fp32(outputs)
outputs = float16_to_fp32(outputs)
return outputs
......
......@@ -397,8 +397,11 @@ class ParallelTransformerLayer(MegatronModule):
self.apply_residual_connection_post_layernorm \
= args.apply_residual_connection_post_layernorm
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
# Layernorm on the input data.
LayerNorm = import_layernorm(args.fp32_residual_connection)
LayerNorm = import_layernorm(self.fp32_residual_connection, self.bf16)
self.input_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
......@@ -440,6 +443,8 @@ class ParallelTransformerLayer(MegatronModule):
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
if self.bf16 and self.fp32_residual_connection:
layernorm_output = layernorm_output.bfloat16()
# Self attention.
attention_output, attention_bias = \
self.self_attention(layernorm_output,
......@@ -478,6 +483,8 @@ class ParallelTransformerLayer(MegatronModule):
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
if self.bf16 and self.fp32_residual_connection:
layernorm_output = layernorm_output.bfloat16()
if self.layer_type == LayerType.decoder:
attention_output, attention_bias = \
......@@ -500,6 +507,8 @@ class ParallelTransformerLayer(MegatronModule):
# Layer norm post the decoder attention
layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
if self.bf16 and self.fp32_residual_connection:
layernorm_output = layernorm_output.bfloat16()
# MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output)
......@@ -533,6 +542,7 @@ class ParallelTransformer(MegatronModule):
super(ParallelTransformer, self).__init__()
args = get_args()
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
# Store activation checkpoiting flag.
......@@ -578,7 +588,8 @@ class ParallelTransformer(MegatronModule):
if mpu.is_pipeline_last_stage():
# Final layer norm before output.
LayerNorm = import_layernorm(args.fp32_residual_connection)
LayerNorm = import_layernorm(self.fp32_residual_connection,
self.bf16)
self.final_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
......@@ -665,6 +676,8 @@ class ParallelTransformer(MegatronModule):
# Reverting data format change [s b h] --> [b s h].
hidden_states = hidden_states.transpose(0, 1).contiguous()
output = self.final_layernorm(hidden_states)
if self.bf16 and self.fp32_residual_connection:
output = output.bfloat16()
else:
output = hidden_states
if get_key_value:
......
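For illustration, a hypothetical standalone sketch of the bf16 + fp32-residual-connection handling added above: the layernorm parameters and arithmetic stay in fp32 (see the module_.float() conversion in get_model below), and only the output is cast back to bf16 for the attention/MLP blocks:

import torch

ln = torch.nn.LayerNorm(8).float()              # fp32 layernorm parameters
hidden_states = torch.randn(2, 8).float()       # fp32 residual stream
layernorm_output = ln(hidden_states)            # fp32 in, fp32 out
layernorm_output = layernorm_output.bfloat16()  # cast back, as in the diff
assert layernorm_output.dtype == torch.bfloat16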
......@@ -20,7 +20,7 @@ from megatron import get_args
from megatron.model import import_layernorm
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import FP16OptimizerWithFP16Params, FP32Optimizer
from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
def _get_params_for_weight_decay_optimization(modules):
......@@ -28,7 +28,7 @@ def _get_params_for_weight_decay_optimization(modules):
Layernorms and biases will have no weight decay but the rest will.
"""
args = get_args()
LayerNorm = import_layernorm(args.fp32_residual_connection)
LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)
weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
......@@ -67,24 +67,45 @@ def get_megatron_optimizer(model):
momentum=args.sgd_momentum)
else:
raise Exception('{} optimizer is not supported.'.format(
args.optimizer))
args.optimizer))
if args.fp16:
# Determine whether the params have main-grad field.
params_have_main_grad = False
if args.DDP_impl == 'local':
params_have_main_grad = True
if args.fp16 or args.bf16:
# Grad scaler:
# if loss-scale is provided, instantiate the constant scaler.
# if we are using fp16 and loss-scale is not present, use a
# dynamic scaler.
# otherwise we are running in bf16 with no loss-scale so
# leave it as None.
grad_scaler = None
# Constant loss scale.
if args.loss_scale:
grad_scaler = ConstantGradScaler(args.loss_scale)
# Dynamic loss scale.
else:
grad_scaler = DynamicGradScaler(
initial_scale=args.initial_loss_scale,
min_scale=args.min_loss_scale,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=args.loss_scale_window,
hysteresis=args.hysteresis)
if args.fp16:
grad_scaler = DynamicGradScaler(
initial_scale=args.initial_loss_scale,
min_scale=args.min_loss_scale,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=args.loss_scale_window,
hysteresis=args.hysteresis)
# Megatron optimizer.
return FP16OptimizerWithFP16Params(optimizer, grad_scaler,
args.clip_grad, args.log_num_zeros_in_grad)
return Float16OptimizerWithFloat16Params(optimizer,
args.clip_grad,
args.log_num_zeros_in_grad,
params_have_main_grad,
args.bf16,
grad_scaler)
# FP32.
return FP32Optimizer(optimizer, args.clip_grad, args.log_num_zeros_in_grad)
return FP32Optimizer(optimizer, args.clip_grad,
args.log_num_zeros_in_grad,
params_have_main_grad)
......@@ -46,24 +46,37 @@ def _zero_grad_group_helper(group, set_to_none):
def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
"""Use multi-tensor-applier to copy values from one list to another."""
"""Use multi-tensor-applier to copy values from one list to another.
We don't have a bfloat16 implementation, so for now, if the overflow_buf
is not provided, we fall back to a simple loop copy to be compatible
with bfloat16."""
if overflow_buf:
overflow_buf.fill_(0)
# Scaling with factor `1.0` is equivalent to copy.
multi_tensor_applier(amp_C.multi_tensor_scale,
overflow_buf,
[this, that],
1.0)
else:
overflow_buf = torch.cuda.IntTensor([0])
# Scaling with factor `1.0` is equivalent to copy.
multi_tensor_applier(amp_C.multi_tensor_scale,
overflow_buf,
[this, that],
1.0)
for this_, that_ in zip(this, that):
that_.copy_(this_)
class MegatronOptimizer(ABC):
def __init__(self, optimizer):
def __init__(self, optimizer, clip_grad,
log_num_zeros_in_grad,
params_have_main_grad):
"""Input optimizer is the base optimizer for example Adam."""
self.optimizer = optimizer
assert self.optimizer, 'no optimizer is provided.'
# Set gradient clipping and logging params.
self.clip_grad = clip_grad
self.log_num_zeros_in_grad = log_num_zeros_in_grad
self.params_have_main_grad = params_have_main_grad
def get_parameters(self):
params = []
......@@ -72,31 +85,38 @@ class MegatronOptimizer(ABC):
params.append(param)
return params
def clip_grad_norm(self, clip_grad):
params = self.get_parameters()
return clip_grad_norm_fp32(params, clip_grad)
def count_zeros(self):
params = self.get_parameters()
return count_zeros_fp32(params)
@abstractmethod
def zero_grad(self, set_to_none=True):
pass
@abstractmethod
def get_loss_scale(self):
"""The output should be a cuda tensor of size 1."""
pass
def scale_loss(self, loss):
"""Simple scaling."""
return self.get_loss_scale() * loss
@abstractmethod
def step(self):
pass
@abstractmethod
def reload_model_params(self):
"""Refreshes any internal state from the current model parameters.
......@@ -106,14 +126,17 @@ class MegatronOptimizer(ABC):
with main parameters, the main parameters need to also be updated."""
pass
@abstractmethod
def state_dict(self):
pass
@abstractmethod
def load_state_dict(self, state_dict):
pass
# Promote state so it can be retrieved or set via
# "optimizer_instance.state"
def _get_state(self):
......@@ -124,6 +147,7 @@ class MegatronOptimizer(ABC):
state = property(_get_state, _set_state)
# Promote param_groups so it can be retrieved or set via
# "optimizer_instance.param_groups"
# (for example, to adjust the learning rate)
......@@ -137,50 +161,90 @@ class MegatronOptimizer(ABC):
class FP16OptimizerWithFP16Params(MegatronOptimizer):
def __init__(self, optimizer, grad_scaler, clip_grad, log_num_zeros_in_grad):
super(FP16OptimizerWithFP16Params, self).__init__(optimizer)
class Float16OptimizerWithFloat16Params(MegatronOptimizer):
"""Float16 optimizer for fp16 and bf16 data types.
Arguments:
optimizer: base optimizer such as Adam or SGD
clip_grad: clip gradients with this global L2 norm. Note
that clipping is ignored if clip_grad == 0
log_num_zeros_in_grad: return number of zeros in the gradients.
params_have_main_grad: flag indicating if parameters have
a `main_grad` field. If this is set, we are assuming
that the model gradients are stored in the `main_grad`
field instead of the typical `grad` field. This happens
for the DDP cases where there is a contiguous buffer
holding the gradients. For example, for bfloat16 we want
to do gradient accumulation and all-reduces in float32,
so we store those gradients in the `main_grad` field.
Note that `main_grad` is not necessarily in float32.
bf16: if true, the model is running in bfloat16.
grad_scaler: used for scaling gradients. Note that this can be
None. This case happens when `bf16 = True` and we don't
use any loss scale. Note that for `bf16 = True`, we can have
a constant gradient scaler. Also for `bf16 = False`, we
always require a grad scaler.
"""
def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, bf16, grad_scaler):
super(Float16OptimizerWithFloat16Params, self).__init__(
optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad)
self.bf16 = bf16
self.grad_scaler = grad_scaler
self.clip_grad = clip_grad
self.log_num_zeros_in_grad = log_num_zeros_in_grad
# None grad scaler is only supported for bf16.
if self.grad_scaler is None:
assert self.bf16, 'fp16 expects a grad scaler.'
# Tensor used to determine if a nan/inf has happened.
# Any non-zero value indicates inf/nan.
self.found_inf = torch.cuda.FloatTensor([0.0])
# Note that this is only needed when a grad scaler is present;
# we still record nan/inf when running bfloat16 with a grad scaler.
if self.grad_scaler:
self.found_inf = torch.cuda.FloatTensor([0.0])
# Dummy tensor needed for apex multi-apply tensor.
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
# For bfloat16, we don't have a multi-tensor apply implementation,
# so for now we set it to None and the multi-tensor apply is bypassed.
if bf16:
self._dummy_overflow_buf = None
else:
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
# In case grad scaler is not passed, define the unity scale.
if self.grad_scaler is None:
self._scale_one = torch.cuda.FloatTensor([1.0])
# ======================
# main parameter stuff
# ======================
# Three groups of parameters:
# fp16_groups: original fp16 parameters
# fp32_from_fp16_groups: fp32 copy of fp16 parameters
# float16_groups: original float16 parameters
# fp32_from_float16_groups: fp32 copy of float16 parameters
# fp32_from_fp32_groups: original fp32 parameters
self.fp16_groups = []
self.fp32_from_fp16_groups = []
self.float16_groups = []
self.fp32_from_float16_groups = []
self.fp32_from_fp32_groups = []
# For all the groups in the original optimizer:
for param_group in self.optimizer.param_groups:
fp16_params_this_group = []
float16_params_this_group = []
fp32_params_this_group = []
fp32_from_fp16_params_this_group = []
fp32_from_float16_params_this_group = []
# For all the parameters in this group:
for i, param in enumerate(param_group['params']):
if param.requires_grad:
# fp16 params:
if param.type() == 'torch.cuda.HalfTensor':
fp16_params_this_group.append(param)
# float16 params:
if param.type() in ['torch.cuda.HalfTensor',
'torch.cuda.BFloat16Tensor']:
float16_params_this_group.append(param)
# Create a copy
main_param = param.detach().clone().float()
# Store grads
main_param.requires_grad = True
# Copy tensor model parallel attributes.
mpu.copy_tensor_model_parallel_attributes(main_param,
param)
......@@ -188,7 +252,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
main_param.shared = param.shared
# Replace the optimizer params with the new fp32 copy.
param_group['params'][i] = main_param
fp32_from_fp16_params_this_group.append(main_param)
fp32_from_float16_params_this_group.append(main_param)
# Reset existing state dict key to the new main param.
if param in self.optimizer.state:
self.optimizer.state[main_param] \
......@@ -200,13 +264,15 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
param_group['params'][i] = param
else:
raise TypeError("Wrapped parameters must be either "
"torch.cuda.FloatTensor or "
"torch.cuda.HalfTensor. "
"Received {}".format(param.type()))
self.fp16_groups.append(fp16_params_this_group)
self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
raise TypeError('Wrapped parameters must be one of '
'torch.cuda.FloatTensor, '
'torch.cuda.HalfTensor, or '
'torch.cuda.BFloat16Tensor. '
'Received {}'.format(param.type()))
self.float16_groups.append(float16_params_this_group)
self.fp32_from_float16_groups.append(
fp32_from_float16_params_this_group)
self.fp32_from_fp32_groups.append(fp32_params_this_group)
# Leverage state_dict() and load_state_dict() to
......@@ -216,37 +282,40 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
def zero_grad(self, set_to_none=True):
"""We only need to zero the model related parameters, i.e.,
fp16_groups & fp32_from_fp32_groups."""
for group in self.fp16_groups:
float16_groups & fp32_from_fp32_groups."""
for group in self.float16_groups:
_zero_grad_group_helper(group, set_to_none)
for group in self.fp32_from_fp32_groups:
_zero_grad_group_helper(group, set_to_none)
def get_loss_scale(self):
if self.grad_scaler is None:
return self._scale_one
return self.grad_scaler.scale
def _copy_model_grads_to_main_grads(self):
# This only needs to be done for the fp16 group.
model_grads = []
main_grads = []
for model_group, main_group in zip(self.fp16_groups,
self.fp32_from_fp16_groups):
# This only needs to be done for the float16 group.
for model_group, main_group in zip(self.float16_groups,
self.fp32_from_float16_groups):
for model_param, main_param in zip(model_group, main_group):
if model_param.grad is not None:
if main_param.grad is None:
main_param.grad = torch.empty_like(main_param)
model_grads.append(model_param.grad.data)
main_grads.append(main_param.grad.data)
_multi_tensor_copy_this_to_that(this=model_grads, that=main_grads,
overflow_buf=self._dummy_overflow_buf)
if self.params_have_main_grad:
main_param.grad = model_param.main_grad.float()
else:
if model_param.grad is not None:
main_param.grad = model_param.grad.float()
# For fp32 grads, we need to reset the grads to main grad.
if self.params_have_main_grad:
for model_group in self.fp32_from_fp32_groups:
for model_param in model_group:
model_param.grad = model_param.main_grad
def _unscale_main_grads_and_check_for_nan(self):
main_grads = []
# fp32 params from fp16 ones.
for main_group in self.fp32_from_fp16_groups:
# fp32 params from float16 ones.
for main_group in self.fp32_from_float16_groups:
for main_param in main_group:
if main_param.grad is not None:
main_grads.append(main_param.grad.data)
......@@ -270,11 +339,11 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
return found_inf_flag
def _get_model_and_main_params_data_fp16(self):
def _get_model_and_main_params_data_float16(self):
model_data = []
main_data = []
for model_group, main_group in zip(self.fp16_groups,
self.fp32_from_fp16_groups):
for model_group, main_group in zip(self.float16_groups,
self.fp32_from_float16_groups):
for model_param, main_param in zip(model_group, main_group):
model_data.append(model_param.data)
main_data.append(main_param.data)
......@@ -282,15 +351,15 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
def _copy_main_params_to_model_params(self):
# Only needed for the fp16 params.
model_data, main_data = self._get_model_and_main_params_data_fp16()
# Only needed for the float16 params.
model_data, main_data = self._get_model_and_main_params_data_float16()
_multi_tensor_copy_this_to_that(this=main_data, that=model_data,
overflow_buf=self._dummy_overflow_buf)
def _copy_model_params_to_main_params(self):
# Only needed for the fp16 params.
model_data, main_data = self._get_model_and_main_params_data_fp16()
# Only needed for the float16 params.
model_data, main_data = self._get_model_and_main_params_data_float16()
_multi_tensor_copy_this_to_that(this=model_data, that=main_data,
overflow_buf=self._dummy_overflow_buf)
......@@ -298,6 +367,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
def reload_model_params(self):
self._copy_model_params_to_main_params()
@torch.no_grad()
def step(self):
......@@ -308,18 +378,22 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
self._copy_model_grads_to_main_grads()
timers('optimizer-copy-to-main-grad').stop()
# Unscale and check for inf/nan.
timers('optimizer-unscale-and-check-inf').start()
found_inf_flag = self._unscale_main_grads_and_check_for_nan()
timers('optimizer-unscale-and-check-inf').stop()
# Do unscale, check for inf, and update grad scaler only for
# the case that grad scaler is provided.
if self.grad_scaler:
# We are done with scaling gradients
# so we can update the loss scale.
self.grad_scaler.update(found_inf_flag)
# Unscale and check for inf/nan.
timers('optimizer-unscale-and-check-inf').start()
found_inf_flag = self._unscale_main_grads_and_check_for_nan()
timers('optimizer-unscale-and-check-inf').stop()
# If we found inf/nan, skip the update.
if found_inf_flag:
return False, None, None
# We are done with scaling gradients
# so we can update the loss scale.
self.grad_scaler.update(found_inf_flag)
# If we found inf/nan, skip the update.
if found_inf_flag:
return False, None, None
# Clip the main gradients.
timers('optimizer-clip-main-grad').start()
......@@ -329,7 +403,8 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
timers('optimizer-clip-main-grad').stop()
# count the zeros in the grads
num_zeros_in_grad = self.count_zeros() if self.log_num_zeros_in_grad else None
num_zeros_in_grad = self.count_zeros() if \
self.log_num_zeros_in_grad else None
# Step the optimizer.
self.optimizer.step()
......@@ -346,8 +421,9 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
def state_dict(self):
state_dict = {}
state_dict['optimizer'] = self.optimizer.state_dict()
state_dict['grad_scaler'] = self.grad_scaler.state_dict()
state_dict['fp32_from_fp16_params'] = self.fp32_from_fp16_groups
if self.grad_scaler:
state_dict['grad_scaler'] = self.grad_scaler.state_dict()
state_dict['fp32_from_fp16_params'] = self.fp32_from_float16_groups
return state_dict
......@@ -365,15 +441,20 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
print_rank_0('***WARNING*** found an old checkpoint, will not '
'load grad scaler ...')
else:
self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
if self.grad_scaler:
self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
else:
print_rank_0('***WARNING*** found the grad scaler in the '
'checkpoint but it is None in the class. '
'Skipping loading grad scaler ...')
# Copy data for the main params.
fp32_from_fp16_params_key = 'fp32_from_fp16_params'
if fp32_from_fp16_params_key not in state_dict:
fp32_from_fp16_params_key = 'fp32_from_fp16'
fp32_from_float16_params_key = 'fp32_from_fp16_params'
if fp32_from_float16_params_key not in state_dict:
fp32_from_float16_params_key = 'fp32_from_fp16'
for current_group, saved_group in zip(
self.fp32_from_fp16_groups,
state_dict[fp32_from_fp16_params_key]):
self.fp32_from_float16_groups,
state_dict[fp32_from_float16_params_key]):
for current_param, saved_param in zip(current_group, saved_group):
current_param.data.copy_(saved_param.data)
......@@ -381,11 +462,14 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
class FP32Optimizer(MegatronOptimizer):
def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad):
def __init__(self, optimizer, clip_grad,
log_num_zeros_in_grad,
params_have_main_grad):
super(FP32Optimizer, self).__init__(
optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad)
super(FP32Optimizer, self).__init__(optimizer)
self.clip_grad = clip_grad
self.log_num_zeros_in_grad = log_num_zeros_in_grad
self._scale = torch.cuda.FloatTensor([1.0])
......@@ -405,13 +489,20 @@ class FP32Optimizer(MegatronOptimizer):
"""Clip gradients (if needed) and step the base optimizer.
Always returns success since there is no overflow."""
# Copy main_grads to grads.
if self.params_have_main_grad:
for param_group in self.optimizer.param_groups:
for param in param_group['params']:
param.grad = param.main_grad
# Clip gradients.
grad_norm = None
if self.clip_grad > 0.0:
grad_norm = self.clip_grad_norm(self.clip_grad)
# count the zeros in the grads
num_zeros_in_grad = self.count_zeros() if self.log_num_zeros_in_grad else None
num_zeros_in_grad = self.count_zeros() if \
self.log_num_zeros_in_grad else None
# Update parameters.
self.optimizer.step()
......
......@@ -37,9 +37,8 @@ from megatron import print_rank_0
from megatron import print_rank_last
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.model import FP16Module
from megatron.model import Float16Module
from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
from megatron.learning_rates import AnnealingLR
......@@ -54,6 +53,7 @@ from megatron.schedules import forward_backward_pipelining_with_interleaving
from megatron.utils import report_memory
def print_datetime(string):
"""Note that this call will sync across all ranks."""
torch.distributed.barrier()
......@@ -222,8 +222,18 @@ def get_model(model_provider_func):
model_module.cuda(torch.cuda.current_device())
# Float16 (fp16/bf16) conversion.
if args.fp16:
model = [FP16Module(model_module) for model_module in model]
if args.fp16 or args.bf16:
model = [Float16Module(model_module, args) for model_module in model]
# For now, the layer norm does not support float32 input with bf16 output.
# To handle this, we move the layernorm parameters to fp32 and cast the
# output of the layernorm operation back to bf16.
if args.bf16 and args.fp32_residual_connection:
from megatron.model import import_layernorm
LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)
for model_ in model:
for module_ in model_.modules():
if isinstance(module_, LayerNorm):
module_.float()
if args.DDP_impl == 'torch':
i = torch.cuda.current_device()
......@@ -231,8 +241,12 @@ def get_model(model_provider_func):
process_group=mpu.get_data_parallel_group())
for model_module in model]
return model
if args.DDP_impl == 'local':
model = [LocalDDP(model_module) for model_module in model]
model = [LocalDDP(model_module,
args.accumulate_allreduce_grads_in_fp32,
args.use_contiguous_buffers_in_ddp)
for model_module in model]
return model
raise NotImplementedError('Unknown DDP implementation specified: {}. '
......@@ -289,7 +303,7 @@ def setup_model_and_optimizer(model_provider_func):
model = get_model(model_provider_func)
unwrapped_model = unwrap_model(model,
(torchDDP, LocalDDP, FP16Module))
(torchDDP, LocalDDP, Float16Module))
optimizer = get_megatron_optimizer(unwrapped_model)
lr_scheduler = get_learning_rate_scheduler(optimizer)
......@@ -308,9 +322,7 @@ def setup_model_and_optimizer(model_provider_func):
args.iteration = 0
# We only support local DDP with multiple micro-batches.
if len(model) > 1:
assert args.DDP_impl == 'local'
if mpu.get_pipeline_model_parallel_world_size() > 1:
if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
assert args.DDP_impl == 'local'
# get model without FP16 and/or TorchDDP wrappers
......@@ -331,7 +343,11 @@ def train_step(forward_step_func, data_iterator,
timers = get_timers()
# Set grad to zero.
optimizer.zero_grad()
if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_ddp:
for partition in model:
partition.zero_grad_buffer()
else:
optimizer.zero_grad()
if mpu.get_pipeline_model_parallel_world_size() > 1:
if args.virtual_pipeline_model_parallel_size is not None:
......@@ -351,8 +367,7 @@ def train_step(forward_step_func, data_iterator,
if args.DDP_impl == 'local':
timers('backward-params-all-reduce').start()
for model_module in model:
model_module.allreduce_params(reduce_after=False,
fp32_allreduce=args.fp32_allreduce)
model_module.allreduce_gradients()
timers('backward-params-all-reduce').stop()
# All-reduce word_embeddings' grad across first and last stages to ensure
......@@ -368,12 +383,15 @@ def train_step(forward_step_func, data_iterator,
elif mpu.is_pipeline_last_stage(ignore_virtual=True):
unwrapped_model = model[-1]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, FP16Module))
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
if unwrapped_model.share_word_embeddings:
word_embeddings_weight = unwrapped_model.word_embeddings_weight()
torch.distributed.all_reduce(word_embeddings_weight.grad,
group=mpu.get_embedding_group())
if args.DDP_impl == 'local':
grad = word_embeddings_weight.main_grad
else:
grad = word_embeddings_weight.grad
torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
timers('backward-embedding-all-reduce').stop()
# Update parameters.
......
......@@ -48,13 +48,20 @@ def unwrap_model(model, module_instances=(torchDDP)):
def calc_params_l2_norm(model):
"""Calculate l2 norm of parameters """
args = get_args()
if not isinstance(model, list):
model = [model]
# Remove duplicate params.
params_data = []
for param in model.parameters():
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
if is_not_shared and is_not_tp_duplicate:
params_data.append(param.data)
for model_ in model:
for param in model_.parameters():
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
if is_not_shared and is_not_tp_duplicate:
if args.bf16:
params_data.append(param.data.float())
else:
params_data.append(param.data)
# Calculate norm
dummy_overflow_buf = torch.cuda.IntTensor([0])
norm, _ = multi_tensor_applier(
......
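For illustration, a hypothetical unfused sketch of the same norm computation with plain torch ops (the multi-tensor applier above is the fused fast path); bf16 params are upcast to fp32 first so the norm is accumulated in full precision:

import torch

params_data = [torch.randn(3, 3, dtype=torch.bfloat16),
               torch.randn(5, dtype=torch.bfloat16)]
per_tensor = [p.float().norm(2) for p in params_data]  # upcast, as in the bf16 branch
norm = torch.stack(per_tensor).norm(2)                 # global L2 over all params
# In the full function, the squared norm would then be all-reduced across
# model-parallel ranks before taking the square root.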