Commit 1a2cb60c authored by Jared Casper

Merge branch 'bfloat_with_fp32_grad_acc' into 'main'

Bfloat with fp32 grad acc

See merge request ADLR/megatron-lm!247
parents 87b8b9dc b4bc51b1
@@ -129,11 +129,26 @@ def parse_args(extra_args_provider=None, defaults={},
    # Parameters dtype.
    args.params_dtype = torch.float
    if args.fp16:
        assert not args.bf16
        args.params_dtype = torch.half
    if args.bf16:
        assert not args.fp16
        args.params_dtype = torch.bfloat16
        # No fusion is supported for bfloat16 for now.
        assert not args.masked_softmax_fusion
        assert not args.bias_gelu_fusion
        assert not args.bias_dropout_fusion
    if args.rank == 0:
        print('using {} for parameters ...'.format(args.params_dtype),
              flush=True)

    # If we do accumulation and all-reduces in fp32, we need to have
    # local DDP and we should set the use-contiguous-buffers-in-ddp.
    if args.accumulate_allreduce_grads_in_fp32:
        assert args.DDP_impl == 'local'
        args.use_contiguous_buffers_in_ddp = True

    if args.dataloader_type is None:
        args.dataloader_type = 'single'
@@ -204,8 +219,8 @@ def parse_args(extra_args_provider=None, defaults={},
    if args.fp16_lm_cross_entropy:
        assert args.fp16, 'lm cross entropy in fp16 is only supported in fp16 mode.'
    if args.fp32_residual_connection:
        assert args.fp16 or args.bf16, \
            'residual connection in fp32 only supported when using fp16 or bf16.'
    # Activation checkpointing.
    if args.distribute_checkpointed_activations:
        assert args.checkpoint_activations, \
@@ -528,6 +543,8 @@ def _add_mixed_precision_args(parser):
    group.add_argument('--fp16', action='store_true',
                       help='Run model in fp16 mode.')
    group.add_argument('--bf16', action='store_true',
                       help='Run model in bfloat16 mode.')
    group.add_argument('--loss-scale', type=float, default=None,
                       help='Static loss scaling, positive power of 2 '
                       'values can improve fp16 convergence. If None, dynamic'
@@ -549,8 +566,9 @@ def _add_mixed_precision_args(parser):
                       help='Run attention masking and softmax in fp32. '
                       'This flag is ignored unless '
                       '--no-query-key-layer-scaling is specified.')
    group.add_argument('--accumulate-allreduce-grads-in-fp32',
                       action='store_true',
                       help='Gradient accumulation and all-reduce in fp32.')
    group.add_argument('--fp16-lm-cross-entropy', action='store_true',
                       help='Move the cross entropy unreduced loss calculation '
                       'for lm head to fp16.')
@@ -577,6 +595,9 @@ def _add_distributed_args(parser):
                       choices=['local', 'torch'],
                       help='which DistributedDataParallel implementation '
                       'to use.')
    group.add_argument('--use-contiguous-buffers-in-ddp', action='store_true',
                       help='If set, use contiguous buffer in DDP. Note that '
                       'this option only works with local DDP.')
    group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
                       help='Use scatter/gather to optimize communication of tensors in pipeline',
                       dest='scatter_gather_tensors_in_pipeline')
...
@@ -16,11 +16,13 @@
_LAYER_NORM = None


def import_layernorm(fp32_residual_connection, bf16):
    global _LAYER_NORM
    if not _LAYER_NORM:
        if bf16:
            from torch.nn import LayerNorm
        elif fp32_residual_connection:
            from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
        else:
            from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
@@ -39,6 +41,6 @@ from .gpt_model import (GPTModel,
                        GPTModelIntermediateStage,
                        GPTModelLastStage)
from .language_model import get_language_model
from .module import Float16Module
@@ -78,7 +78,7 @@ class BertLMHead(MegatronModule):
        self.parallel_output = parallel_output

        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
        LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)
        self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
        self.gelu = torch.nn.functional.gelu
        if args.openai_gelu:
...
@@ -13,100 +13,206 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from abc import abstractmethod

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from megatron import mpu
from .module import MegatronModule


class MemoryBuffer:

    def __init__(self, numel, dtype):
        self.numel = numel
        self.dtype = dtype
        self.data = torch.zeros(self.numel,
                                dtype=self.dtype,
                                device=torch.cuda.current_device(),
                                requires_grad=False)

    def zero(self):
        """Reset the buffer to zero."""
        self.data.zero_()

    def get(self, shape, start_index):
        """Return a tensor with the input `shape` as a view into the
        1-D data starting at `start_index`."""
        end_index = start_index + shape.numel()
        assert end_index <= self.numel, \
            'requested tensor is out of the buffer range.'
        buffer_tensor = self.data[start_index:end_index]
        buffer_tensor = buffer_tensor.view(shape)
        return buffer_tensor
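For illustration only (not part of the diff), a minimal sketch of how MemoryBuffer is meant to be used: carve parameter-shaped views out of one flat allocation, so zeroing the buffer zeroes every gradient view at once. This assumes a CUDA device, since the buffer is allocated on the current GPU.

# Illustrative sketch: one flat fp32 buffer with two parameter-shaped views.
import torch

buf = MemoryBuffer(numel=12, dtype=torch.float)
grad_a = buf.get(torch.Size([2, 3]), start_index=0)   # view over elements 0..5
grad_b = buf.get(torch.Size([6]), start_index=6)      # view over elements 6..11
grad_a.fill_(1.0)
buf.zero()                                            # zeroes grad_a and grad_b too
assert grad_a.sum().item() == 0.0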
class DistributedDataParallelBase(MegatronModule, ABC):
    """Abstract class for DDP."""

    def __init__(self, module):
        super(DistributedDataParallelBase, self).__init__()
        # Keep a pointer to the model.
        self.module = module

    @abstractmethod
    def allreduce_gradients(self):
        pass

    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        return self.module.state_dict(destination, prefix, keep_vars)

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        return self.module.state_dict_for_save_checkpoint(destination, prefix,
                                                          keep_vars)

    def load_state_dict(self, state_dict, strict=True):
        self.module.load_state_dict(state_dict, strict=strict)
class DistributedDataParallel(DistributedDataParallelBase):
    """DDP with contiguous buffer options to store and accumulate gradients.
    This class:
        - has the potential to reduce memory fragmentation.
        - provides the option to do the gradient accumulation
          in a type other than the params type (for example fp32).

    Arguments:
        module: input model.
        accumulate_allreduce_grads_in_fp32: if true, do the gradient accumulation
            and the gradient all-reduce all in float32. If this option is
            true, we require `use_contiguous_buffers` to be true too.
        use_contiguous_buffers: if true, use a contiguous buffer to store the
            gradients.
    """

    def __init__(self, module,
                 accumulate_allreduce_grads_in_fp32,
                 use_contiguous_buffers):

        super(DistributedDataParallel, self).__init__(module)

        self.accumulate_allreduce_grads_in_fp32 \
            = accumulate_allreduce_grads_in_fp32
        self.use_contiguous_buffers = use_contiguous_buffers
        # If we are using fp32-accumulate-allreduce explicitly,
        # this means we need main grads in a contiguous buffer.
        if self.accumulate_allreduce_grads_in_fp32:
            assert self.use_contiguous_buffers

        # ===================================
        # Rest of this part applies only to
        # the case we use contiguous buffers.
        # ===================================
        self._grad_buffers = None
        if self.use_contiguous_buffers:
            self._grad_buffers = {}

            # Simple function to define buffer type.
            def _get_buffer_type(param):
                return torch.float if \
                    self.accumulate_allreduce_grads_in_fp32 else param.dtype

            # First calculate total number of elements per type.
            type_num_elements = {}
            for param in self.module.parameters():
                if param.requires_grad:
                    dtype = _get_buffer_type(param)
                    type_num_elements[dtype] = type_num_elements.get(dtype, 0) \
                                               + param.data.nelement()

            # Allocate the buffer.
            for dtype, num_elements in type_num_elements.items():
                self._grad_buffers[dtype] = MemoryBuffer(num_elements, dtype)

            # Assume the back prop order is the reverse of the params order,
            # store the start index for the gradients.
            for param in self.module.parameters():
                if param.requires_grad:
                    dtype = _get_buffer_type(param)
                    type_num_elements[dtype] -= param.data.nelement()
                    param.main_grad = self._grad_buffers[dtype].get(
                        param.data.shape, type_num_elements[dtype])

            # Backward hook.
            # Accumulation functions for the gradients. We need
            # to store them so they don't go out of scope.
            self.grad_accs = []
            # Loop over all the parameters in the model.
            for param in self.module.parameters():
                if param.requires_grad:
                    # Expand so we get access to grad_fn.
                    param_tmp = param.expand_as(param)
                    # Get the gradient accumulator function.
                    grad_acc = param_tmp.grad_fn.next_functions[0][0]
                    grad_acc.register_hook(self._make_param_hook(param))
                    self.grad_accs.append(grad_acc)
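The `expand_as` trick above grabs each leaf parameter's persistent AccumulateGrad node so a hook can run right after that parameter's gradient is produced. A standalone toy version of the same mechanism (illustrative names, not Megatron code):

# Toy demonstration of the AccumulateGrad-hook trick: the hook fires once the
# parameter's .grad is ready, accumulates it into a separate buffer, and then
# frees .grad so its memory can be reused.
import torch

w = torch.nn.Parameter(torch.ones(3))
w.main_grad = torch.zeros(3)                       # accumulation buffer
grad_acc = w.expand_as(w).grad_fn.next_functions[0][0]

def hook(*unused):
    w.main_grad.add_(w.grad)
    w.grad = None

grad_acc.register_hook(hook)
(w * torch.arange(3.0)).sum().backward()
print(w.main_grad)                                 # tensor([0., 1., 2.])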
    def _make_param_hook(self, param):
        """Create the all-reduce hook for backprop."""
        # Hook used for back-prop.
        def param_hook(*unused):
            # Add the gradient to the buffer.
            if param.grad.data is not None:
                param.main_grad.add_(param.grad.data)
                # Now we can deallocate grad memory.
                param.grad = None
        return param_hook

    def zero_grad_buffer(self):
        """Set the grad buffer data to zero. Needs to be called at the
        beginning of each iteration."""
        assert self._grad_buffers is not None, 'buffers are not initialized.'
        for _, buffer_ in self._grad_buffers.items():
            buffer_.zero()

    def allreduce_gradients(self):
        """Reduce gradients across data parallel ranks."""
        # If we have buffers, simply reduce the data in the buffer.
        if self._grad_buffers is not None:
            for _, buffer_ in self._grad_buffers.items():
                buffer_.data /= mpu.get_data_parallel_world_size()
                torch.distributed.all_reduce(
                    buffer_.data, group=mpu.get_data_parallel_group())
        else:
            # Otherwise, bucketize and all-reduce.
            buckets = {}
            # Pack the buckets.
            for param in self.module.parameters():
                if param.requires_grad and param.grad is not None:
                    tp = param.data.type()
                    if tp not in buckets:
                        buckets[tp] = []
                    buckets[tp].append(param)
                    param.main_grad = param.grad

            # For each bucket, all-reduce and copy all-reduced grads.
            for tp in buckets:
                bucket = buckets[tp]
                grads = [param.grad.data for param in bucket]
                coalesced = _flatten_dense_tensors(grads)
                coalesced /= mpu.get_data_parallel_world_size()
                torch.distributed.all_reduce(
                    coalesced, group=mpu.get_data_parallel_group())
                for buf, synced in zip(grads, _unflatten_dense_tensors(
                        coalesced, grads)):
                    buf.copy_(synced)
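A hedged sketch of how this wrapper is driven each iteration (`my_module` and `batch` are placeholders; torch.distributed and the mpu data-parallel group must already be initialized, with the model on the current GPU):

# Sketch only: per-iteration bookkeeping with the contiguous fp32 grad buffer.
ddp_model = DistributedDataParallel(my_module,
                                    accumulate_allreduce_grads_in_fp32=True,
                                    use_contiguous_buffers=True)
ddp_model.zero_grad_buffer()       # instead of zeroing .grad tensors
loss = ddp_model(batch).sum()      # forward through the wrapped module
loss.backward()                    # hooks accumulate into param.main_grad
ddp_model.allreduce_gradients()    # average the fp32 buffer across DP ranks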
@@ -25,6 +25,7 @@ from megatron import mpu

_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor)
@@ -109,6 +110,7 @@ class MegatronModule(torch.nn.Module):
                             "this needs to be handled manually. If you are training "
                             "something is definitely wrong.")


def conversion_helper(val, conversion):
    """Apply conversion to val. Recursively apply conversion if `val`
    is a nested tuple/list structure."""
@@ -120,44 +122,56 @@ def conversion_helper(val, conversion):
    return rtn


def fp32_to_float16(val, float16_convertor):
    """Convert fp32 `val` to fp16/bf16."""
    def half_conversion(val):
        val_typecheck = val
        if isinstance(val_typecheck, (Parameter, Variable)):
            val_typecheck = val.data
        if isinstance(val_typecheck, _FLOAT_TYPES):
            val = float16_convertor(val)
        return val
    return conversion_helper(val, half_conversion)


def float16_to_fp32(val):
    """Convert fp16/bf16 `val` to fp32."""
    def float_conversion(val):
        val_typecheck = val
        if isinstance(val_typecheck, (Parameter, Variable)):
            val_typecheck = val.data
        if isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)):
            val = val.float()
        return val
    return conversion_helper(val, float_conversion)


class Float16Module(MegatronModule):

    def __init__(self, module, args):
        super(Float16Module, self).__init__()

        if args.fp16:
            self.add_module('module', module.half())
            def float16_convertor(val):
                return val.half()
        elif args.bf16:
            self.add_module('module', module.bfloat16())
            def float16_convertor(val):
                return val.bfloat16()
        else:
            raise Exception('should not be here')

        self.float16_convertor = float16_convertor

    def forward(self, *inputs, **kwargs):
        if mpu.is_pipeline_first_stage():
            inputs = fp32_to_float16(inputs, self.float16_convertor)
        outputs = self.module(*inputs, **kwargs)
        if mpu.is_pipeline_last_stage():
            outputs = float16_to_fp32(outputs)
        return outputs
...
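These conversion helpers recurse through nested tuples/lists, so pipeline inputs such as `(tokens, (attention_mask, position_ids))` are converted element-wise. A simplified standalone equivalent, for illustration only (not the real `conversion_helper`):

# Simplified restatement of the recursive conversion, illustration only.
import torch

def convert_nested(val, conversion):
    if not isinstance(val, (tuple, list)):
        return conversion(val)
    return tuple(convert_nested(v, conversion) for v in val)

x = (torch.randn(2), (torch.randn(3), torch.randn(4)))
y = convert_nested(x, lambda t: t.bfloat16() if t.dtype == torch.float32 else t)
print(y[0].dtype, y[1][0].dtype)   # torch.bfloat16 torch.bfloat16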
@@ -397,8 +397,11 @@ class ParallelTransformerLayer(MegatronModule):
        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the input data.
        LayerNorm = import_layernorm(self.fp32_residual_connection, self.bf16)
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)
@@ -440,6 +443,8 @@ class ParallelTransformerLayer(MegatronModule):

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        if self.bf16 and self.fp32_residual_connection:
            layernorm_output = layernorm_output.bfloat16()

        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(layernorm_output,
@@ -478,6 +483,8 @@ class ParallelTransformerLayer(MegatronModule):

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)
        if self.bf16 and self.fp32_residual_connection:
            layernorm_output = layernorm_output.bfloat16()

        if self.layer_type == LayerType.decoder:
            attention_output, attention_bias = \
@@ -500,6 +507,8 @@ class ParallelTransformerLayer(MegatronModule):

            # Layer norm post the decoder attention.
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
            if self.bf16 and self.fp32_residual_connection:
                layernorm_output = layernorm_output.bfloat16()

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)
@@ -533,6 +542,7 @@ class ParallelTransformer(MegatronModule):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Store activation checkpointing flag.
@@ -578,7 +588,8 @@ class ParallelTransformer(MegatronModule):

        if mpu.is_pipeline_last_stage():
            # Final layer norm before output.
            LayerNorm = import_layernorm(self.fp32_residual_connection,
                                         self.bf16)
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon)
@@ -665,6 +676,8 @@ class ParallelTransformer(MegatronModule):
                # Reverting data format change [s b h] --> [b s h].
                hidden_states = hidden_states.transpose(0, 1).contiguous()
            output = self.final_layernorm(hidden_states)
            if self.bf16 and self.fp32_residual_connection:
                output = output.bfloat16()
        else:
            output = hidden_states
        if get_key_value:
...
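A minimal sketch of the numeric pattern these transformer edits implement, using toy layers rather than the Megatron classes: the residual stream stays in fp32, the layernorm runs with fp32 parameters (`torch.nn.LayerNorm`, as selected by `import_layernorm` when bf16 is enabled), and its output is cast back to bf16 for the bf16 matmuls.

# Illustrative only: fp32 residual stream with bf16 compute.
import torch

hidden = torch.randn(4, 8)                      # residual stream kept in fp32
ln = torch.nn.LayerNorm(8).float()              # fp32 layernorm parameters
proj = torch.nn.Linear(8, 8).bfloat16()         # bf16 weights

ln_out = ln(hidden).bfloat16()                  # cast back to bf16 after layernorm
out = proj(ln_out)                              # bf16 matmul
hidden = hidden + out.float()                   # residual add stays in fp32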
@@ -20,7 +20,7 @@ from megatron import get_args
from megatron.model import import_layernorm

from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer


def _get_params_for_weight_decay_optimization(modules):
@@ -28,7 +28,7 @@ def _get_params_for_weight_decay_optimization(modules):
    Layernorms and biases will have no weight decay but the rest will.
    """
    args = get_args()
    LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)

    weight_decay_params = {'params': []}
    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
@@ -69,12 +69,26 @@ def get_megatron_optimizer(model):
        raise Exception('{} optimizer is not supported.'.format(
            args.optimizer))

    # Determine whether the params have a main-grad field.
    params_have_main_grad = False
    if args.DDP_impl == 'local':
        params_have_main_grad = True

    if args.fp16 or args.bf16:
        # Grad scaler:
        #    if loss-scale is provided, instantiate the constant scaler.
        #    if we are using fp16 and loss-scale is not present, use a
        #       dynamic scaler.
        #    otherwise we are running in bf16 with no loss-scale, so
        #       leave it as None.
        grad_scaler = None
        # Constant loss scale.
        if args.loss_scale:
            grad_scaler = ConstantGradScaler(args.loss_scale)
        # Dynamic loss scale.
        else:
            if args.fp16:
                grad_scaler = DynamicGradScaler(
                    initial_scale=args.initial_loss_scale,
                    min_scale=args.min_loss_scale,
@@ -82,9 +96,16 @@ def get_megatron_optimizer(model):
                    backoff_factor=0.5,
                    growth_interval=args.loss_scale_window,
                    hysteresis=args.hysteresis)

        # Megatron optimizer.
        return Float16OptimizerWithFloat16Params(optimizer,
                                                 args.clip_grad,
                                                 args.log_num_zeros_in_grad,
                                                 params_have_main_grad,
                                                 args.bf16,
                                                 grad_scaler)

    # FP32.
    return FP32Optimizer(optimizer, args.clip_grad,
                         args.log_num_zeros_in_grad,
                         params_have_main_grad)
@@ -46,24 +46,37 @@ def _zero_grad_group_helper(group, set_to_none):

def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
    """Use multi-tensor-applier to copy values from one list to another.
    We don't have a bfloat16 implementation, so for now if the overflow_buf
    is not provided, we default back to a simple loop copy to be compatible
    with bfloat16."""
    if overflow_buf:
        overflow_buf.fill_(0)
        # Scaling with factor `1.0` is equivalent to copy.
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             overflow_buf,
                             [this, that],
                             1.0)
    else:
        for this_, that_ in zip(this, that):
            that_.copy_(this_)
class MegatronOptimizer(ABC):

    def __init__(self, optimizer, clip_grad,
                 log_num_zeros_in_grad,
                 params_have_main_grad):
        """Input optimizer is the base optimizer, for example Adam."""
        self.optimizer = optimizer
        assert self.optimizer, 'no optimizer is provided.'
        # Set gradient clipping and logging params.
        self.clip_grad = clip_grad
        self.log_num_zeros_in_grad = log_num_zeros_in_grad
        self.params_have_main_grad = params_have_main_grad

    def get_parameters(self):
        params = []
@@ -72,31 +85,38 @@ class MegatronOptimizer(ABC):
                params.append(param)
        return params

    def clip_grad_norm(self, clip_grad):
        params = self.get_parameters()
        return clip_grad_norm_fp32(params, clip_grad)

    def count_zeros(self):
        params = self.get_parameters()
        return count_zeros_fp32(params)

    @abstractmethod
    def zero_grad(self, set_to_none=True):
        pass

    @abstractmethod
    def get_loss_scale(self):
        """The output should be a cuda tensor of size 1."""
        pass

    def scale_loss(self, loss):
        """Simple scaling."""
        return self.get_loss_scale() * loss

    @abstractmethod
    def step(self):
        pass

    @abstractmethod
    def reload_model_params(self):
        """Refreshes any internal state from the current model parameters.
@@ -106,14 +126,17 @@ class MegatronOptimizer(ABC):
        with main parameters, the main parameters need to also be updated."""
        pass

    @abstractmethod
    def state_dict(self):
        pass

    @abstractmethod
    def load_state_dict(self, state_dict):
        pass

    # Promote state so it can be retrieved or set via
    # "optimizer_instance.state"
    def _get_state(self):
@@ -124,6 +147,7 @@ class MegatronOptimizer(ABC):

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via
    # "optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
@@ -137,50 +161,90 @@

class Float16OptimizerWithFloat16Params(MegatronOptimizer):
    """Float16 optimizer for fp16 and bf16 data types.

    Arguments:
        optimizer: base optimizer such as Adam or SGD.
        clip_grad: clip gradients with this global L2 norm. Note
            that clipping is ignored if clip_grad == 0.
        log_num_zeros_in_grad: return number of zeros in the gradients.
        params_have_main_grad: flag indicating if parameters have
            a `main_grad` field. If this is set, we are assuming
            that the model parameter gradients are stored in the `main_grad`
            field instead of the typical `grad` field. This happens
            for the DDP cases where there is a contiguous buffer
            holding the gradients. For example, for bfloat16 we want
            to do gradient accumulation and all-reduces in float32
            and as a result we store those gradients in main_grad.
            Note that main grad is not necessarily in float32.
        bf16: if true, the model is running in bfloat16.
        grad_scaler: used for scaling gradients. Note that this can be
            None. This case happens when `bf16 = True` and we don't
            use any loss scale. Note that for `bf16 = True`, we can have
            a constant gradient scaler. Also, for `bf16 = False`, we
            always require a grad scaler.
    """

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, bf16, grad_scaler):
        super(Float16OptimizerWithFloat16Params, self).__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad)

        self.bf16 = bf16
        self.grad_scaler = grad_scaler
        # None grad scaler is only supported for bf16.
        if self.grad_scaler is None:
            assert self.bf16, 'fp16 expects a grad scaler.'

        # Tensor used to determine if an inf/nan has happened.
        # Any non-zero value indicates inf/nan.
        # Note that we keep this only for the cases that a grad scaler is used.
        # We still record nan/inf if we have a bfloat16 with a grad scaler.
        if self.grad_scaler:
            self.found_inf = torch.cuda.FloatTensor([0.0])

        # Dummy tensor needed for apex multi-apply tensor.
        # For bfloat, we don't have multi-tensor apply, and for now
        # we set it to None so the multi-tensor apply gets ignored.
        if bf16:
            self._dummy_overflow_buf = None
        else:
            self._dummy_overflow_buf = torch.cuda.IntTensor([0])

        # In case grad scaler is not passed, define the unity scale.
        if self.grad_scaler is None:
            self._scale_one = torch.cuda.FloatTensor([1.0])

        # ======================
        # main parameter stuff
        # ======================

        # Three groups of parameters:
        #   float16_groups: original float16 parameters
        #   fp32_from_float16_groups: fp32 copy of float16 parameters
        #   fp32_from_fp32_groups: original fp32 parameters
        self.float16_groups = []
        self.fp32_from_float16_groups = []
        self.fp32_from_fp32_groups = []

        # For all the groups in the original optimizer:
        for param_group in self.optimizer.param_groups:
            float16_params_this_group = []
            fp32_params_this_group = []
            fp32_from_float16_params_this_group = []
            # For all the parameters in this group:
            for i, param in enumerate(param_group['params']):
                if param.requires_grad:

                    # float16 params:
                    if param.type() in ['torch.cuda.HalfTensor',
                                        'torch.cuda.BFloat16Tensor']:
                        float16_params_this_group.append(param)
                        # Create a copy.
                        main_param = param.detach().clone().float()
                        # Store grads.
                        main_param.requires_grad = True
                        # Copy tensor model parallel attributes.
                        mpu.copy_tensor_model_parallel_attributes(main_param,
                                                                  param)
@@ -188,7 +252,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
                        main_param.shared = param.shared
                        # Replace the optimizer params with the new fp32 copy.
                        param_group['params'][i] = main_param
                        fp32_from_float16_params_this_group.append(main_param)
                        # Reset existing state dict key to the new main param.
                        if param in self.optimizer.state:
                            self.optimizer.state[main_param] \
@@ -200,13 +264,15 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
                        param_group['params'][i] = param

                    else:
                        raise TypeError('Wrapped parameters must be one of '
                                        'torch.cuda.FloatTensor, '
                                        'torch.cuda.HalfTensor, or '
                                        'torch.cuda.BFloat16Tensor. '
                                        'Received {}'.format(param.type()))

            self.float16_groups.append(float16_params_this_group)
            self.fp32_from_float16_groups.append(
                fp32_from_float16_params_this_group)
            self.fp32_from_fp32_groups.append(fp32_params_this_group)

        # Leverage state_dict() and load_state_dict() to
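A toy restatement of the master-weight pattern set up above (illustrative, not the Megatron class): keep an fp32 copy per float16 parameter, let the base optimizer step the fp32 copy, then push it back into the model parameter.

# Toy illustration of fp32 master weights for a single bf16 parameter.
import torch

param = torch.nn.Parameter(torch.randn(4, dtype=torch.bfloat16))
main_param = param.detach().clone().float()     # fp32 master copy
main_param.requires_grad = True
base_opt = torch.optim.Adam([main_param], lr=1e-3)

main_param.grad = torch.randn(4)                # would come from param.main_grad
base_opt.step()
param.data.copy_(main_param.data)               # copy updated weights back to bf16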
@@ -216,37 +282,40 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):

    def zero_grad(self, set_to_none=True):
        """We only need to zero the model related parameters, i.e.,
        float16_groups & fp32_from_fp32_groups."""
        for group in self.float16_groups:
            _zero_grad_group_helper(group, set_to_none)
        for group in self.fp32_from_fp32_groups:
            _zero_grad_group_helper(group, set_to_none)

    def get_loss_scale(self):
        if self.grad_scaler is None:
            return self._scale_one
        return self.grad_scaler.scale

    def _copy_model_grads_to_main_grads(self):
        # This only needs to be done for the float16 group.
        for model_group, main_group in zip(self.float16_groups,
                                           self.fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                if self.params_have_main_grad:
                    main_param.grad = model_param.main_grad.float()
                else:
                    if model_param.grad is not None:
                        main_param.grad = model_param.grad.float()

        # For fp32 grads, we need to reset the grads to main grad.
        if self.params_have_main_grad:
            for model_group in self.fp32_from_fp32_groups:
                for model_param in model_group:
                    model_param.grad = model_param.main_grad

    def _unscale_main_grads_and_check_for_nan(self):
        main_grads = []
        # fp32 params from float16 ones.
        for main_group in self.fp32_from_float16_groups:
            for main_param in main_group:
                if main_param.grad is not None:
                    main_grads.append(main_param.grad.data)
@@ -270,11 +339,11 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
        return found_inf_flag

    def _get_model_and_main_params_data_float16(self):
        model_data = []
        main_data = []
        for model_group, main_group in zip(self.float16_groups,
                                           self.fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                model_data.append(model_param.data)
                main_data.append(main_param.data)
@@ -282,15 +351,15 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):

    def _copy_main_params_to_model_params(self):
        # Only needed for the float16 params.
        model_data, main_data = self._get_model_and_main_params_data_float16()
        _multi_tensor_copy_this_to_that(this=main_data, that=model_data,
                                        overflow_buf=self._dummy_overflow_buf)

    def _copy_model_params_to_main_params(self):
        # Only needed for the float16 params.
        model_data, main_data = self._get_model_and_main_params_data_float16()
        _multi_tensor_copy_this_to_that(this=model_data, that=main_data,
                                        overflow_buf=self._dummy_overflow_buf)
@@ -298,6 +367,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
    def reload_model_params(self):
        self._copy_model_params_to_main_params()

    @torch.no_grad()
    def step(self):
@@ -308,6 +378,10 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
        self._copy_model_grads_to_main_grads()
        timers('optimizer-copy-to-main-grad').stop()

        # Do unscale, check for inf, and update the grad scaler only for
        # the case that a grad scaler is provided.
        if self.grad_scaler:

            # Unscale and check for inf/nan.
            timers('optimizer-unscale-and-check-inf').start()
            found_inf_flag = self._unscale_main_grads_and_check_for_nan()
@@ -329,7 +403,8 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
        timers('optimizer-clip-main-grad').stop()

        # Count the zeros in the grads.
        num_zeros_in_grad = self.count_zeros() if \
                            self.log_num_zeros_in_grad else None

        # Step the optimizer.
        self.optimizer.step()
@@ -346,8 +421,9 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
    def state_dict(self):
        state_dict = {}
        state_dict['optimizer'] = self.optimizer.state_dict()
        if self.grad_scaler:
            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
        state_dict['fp32_from_fp16_params'] = self.fp32_from_float16_groups
        return state_dict
@@ -365,15 +441,20 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
            print_rank_0('***WARNING*** found an old checkpoint, will not '
                         'load grad scaler ...')
        else:
            if self.grad_scaler:
                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
            else:
                print_rank_0('***WARNING*** found the grad scaler in the '
                             'checkpoint but it is None in the class. '
                             'Skipping loading grad scaler ...')

        # Copy data for the main params.
        fp32_from_float16_params_key = 'fp32_from_fp16_params'
        if fp32_from_float16_params_key not in state_dict:
            fp32_from_float16_params_key = 'fp32_from_fp16'
        for current_group, saved_group in zip(
                self.fp32_from_float16_groups,
                state_dict[fp32_from_float16_params_key]):
            for current_param, saved_param in zip(current_group, saved_group):
                current_param.data.copy_(saved_param.data)
@@ -381,11 +462,14 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):

class FP32Optimizer(MegatronOptimizer):

    def __init__(self, optimizer, clip_grad,
                 log_num_zeros_in_grad,
                 params_have_main_grad):
        super(FP32Optimizer, self).__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad)

        self._scale = torch.cuda.FloatTensor([1.0])
@@ -405,13 +489,20 @@ class FP32Optimizer(MegatronOptimizer):
        """Clip gradients (if needed) and step the base optimizer.
        Always return successful since there is no overflow."""

        # Copy main_grads to grads.
        if self.params_have_main_grad:
            for param_group in self.optimizer.param_groups:
                for param in param_group['params']:
                    param.grad = param.main_grad

        # Clip gradients.
        grad_norm = None
        if self.clip_grad > 0.0:
            grad_norm = self.clip_grad_norm(self.clip_grad)

        # Count the zeros in the grads.
        num_zeros_in_grad = self.count_zeros() if \
                            self.log_num_zeros_in_grad else None

        # Update parameters.
        self.optimizer.step()
...
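For context on why this merge request routes bf16 gradients through fp32 buffers: bf16 keeps only about 8 bits of mantissa, so repeatedly adding small contributions into a bf16 accumulator silently drops them. A quick check (not part of the change):

# bf16 vs fp32 accumulation of many small increments.
import torch

small = 1e-3
acc_bf16 = torch.ones(1, dtype=torch.bfloat16)
acc_fp32 = torch.ones(1, dtype=torch.float32)
for _ in range(1000):
    acc_bf16 += small
    acc_fp32 += small
print(acc_bf16.item(), acc_fp32.item())   # bf16 stays at 1.0, fp32 reaches ~2.0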
@@ -37,9 +37,8 @@ from megatron import print_rank_0
from megatron import print_rank_last
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.model import Float16Module
from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
from megatron.learning_rates import AnnealingLR
@@ -54,6 +53,7 @@ from megatron.schedules import forward_backward_pipelining_with_interleaving
from megatron.utils import report_memory


def print_datetime(string):
    """Note that this call will sync across all ranks."""
    torch.distributed.barrier()
@@ -222,8 +222,18 @@ def get_model(model_provider_func):
        model_module.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16 or args.bf16:
        model = [Float16Module(model_module, args) for model_module in model]

    # For now, the layer norm does not support fp32 input with bf16 output.
    # For this, we move the layernorm parameters to fp32 and cast the output
    # of the layernorm operation back to bf16.
    if args.bf16 and args.fp32_residual_connection:
        from megatron.model import import_layernorm
        LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)
        for model_ in model:
            for module_ in model_.modules():
                if isinstance(module_, LayerNorm):
                    module_.float()

    if args.DDP_impl == 'torch':
        i = torch.cuda.current_device()
@@ -231,8 +241,12 @@ def get_model(model_provider_func):
                          process_group=mpu.get_data_parallel_group())
                 for model_module in model]
        return model

    if args.DDP_impl == 'local':
        model = [LocalDDP(model_module,
                          args.accumulate_allreduce_grads_in_fp32,
                          args.use_contiguous_buffers_in_ddp)
                 for model_module in model]
        return model

    raise NotImplementedError('Unknown DDP implementation specified: {}. '
@@ -289,7 +303,7 @@ def setup_model_and_optimizer(model_provider_func):
    model = get_model(model_provider_func)

    unwrapped_model = unwrap_model(model,
                                   (torchDDP, LocalDDP, Float16Module))
    optimizer = get_megatron_optimizer(unwrapped_model)

    lr_scheduler = get_learning_rate_scheduler(optimizer)
@@ -308,9 +322,7 @@ def setup_model_and_optimizer(model_provider_func):
        args.iteration = 0

    # We only support local DDP with multiple micro-batches.
    if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
        assert args.DDP_impl == 'local'

    # get model without FP16 and/or TorchDDP wrappers
@@ -331,6 +343,10 @@ def train_step(forward_step_func, data_iterator,
    timers = get_timers()

    # Set grad to zero.
    if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_ddp:
        for partition in model:
            partition.zero_grad_buffer()
    else:
        optimizer.zero_grad()

    if mpu.get_pipeline_model_parallel_world_size() > 1:
@@ -351,8 +367,7 @@ def train_step(forward_step_func, data_iterator,
    if args.DDP_impl == 'local':
        timers('backward-params-all-reduce').start()
        for model_module in model:
            model_module.allreduce_gradients()
        timers('backward-params-all-reduce').stop()

    # All-reduce word_embeddings' grad across first and last stages to ensure
@@ -368,12 +383,15 @@ def train_step(forward_step_func, data_iterator,
    elif mpu.is_pipeline_last_stage(ignore_virtual=True):
        unwrapped_model = model[-1]
        unwrapped_model = unwrap_model(
            unwrapped_model, (torchDDP, LocalDDP, Float16Module))
        if unwrapped_model.share_word_embeddings:
            word_embeddings_weight = unwrapped_model.word_embeddings_weight()
            if args.DDP_impl == 'local':
                grad = word_embeddings_weight.main_grad
            else:
                grad = word_embeddings_weight.grad
            torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # Update parameters.
...
@@ -48,12 +48,19 @@ def unwrap_model(model, module_instances=(torchDDP)):

def calc_params_l2_norm(model):
    """Calculate l2 norm of parameters."""
    args = get_args()
    if not isinstance(model, list):
        model = [model]
    # Remove duplicate params.
    params_data = []
    for model_ in model:
        for param in model_.parameters():
            is_not_shared = param_is_not_shared(param)
            is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
            if is_not_shared and is_not_tp_duplicate:
                if args.bf16:
                    params_data.append(param.data.float())
                else:
                    params_data.append(param.data)
    # Calculate norm.
    dummy_overflow_buf = torch.cuda.IntTensor([0])
...
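A plain-PyTorch restatement of the upcast-then-norm path, for illustration only (the real code feeds the upcast tensors to apex's fused multi-tensor kernel rather than torch.norm): bf16 parameters are cast to fp32 before the global L2 norm is computed.

# Illustrative only: upcast bf16 params to fp32 before computing the global norm.
import torch

params = [torch.randn(10).bfloat16(), torch.randn(5).bfloat16()]
norm = torch.norm(torch.cat([p.float().view(-1) for p in params]), 2)
print(norm)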