"csrc/vscode:/vscode.git/clone" did not exist on "9ccee9c051cfabcdf2919fa2c1f69c11a72bf23d"
Commit b4bc51b1 authored by Mohammad Shoeybi's avatar Mohammad Shoeybi Committed by Jared Casper
Browse files

Bfloat with fp32 grad acc

parent 87b8b9dc
...@@ -129,11 +129,26 @@ def parse_args(extra_args_provider=None, defaults={}, ...@@ -129,11 +129,26 @@ def parse_args(extra_args_provider=None, defaults={},
# Parameters dtype. # Parameters dtype.
args.params_dtype = torch.float args.params_dtype = torch.float
if args.fp16: if args.fp16:
assert not args.bf16
args.params_dtype = torch.half args.params_dtype = torch.half
if args.bf16:
assert not args.fp16
args.params_dtype = torch.bfloat16
# No fusion is support for bfloat for now
assert not args.masked_softmax_fusion
assert not args.bias_gelu_fusion
assert not args.bias_dropout_fusion
if args.rank == 0: if args.rank == 0:
print('using {} for parameters ...'.format(args.params_dtype), print('using {} for parameters ...'.format(args.params_dtype),
flush=True) flush=True)
# If we do accumulation and all-reduces in fp32, we need to have
# local DDP and we should set the use-contiguous-buffers-in-ddp.
if args.accumulate_allreduce_grads_in_fp32:
assert args.DDP_impl == 'local'
args.use_contiguous_buffers_in_ddp = True
if args.dataloader_type is None: if args.dataloader_type is None:
args.dataloader_type = 'single' args.dataloader_type = 'single'
...@@ -204,8 +219,8 @@ def parse_args(extra_args_provider=None, defaults={}, ...@@ -204,8 +219,8 @@ def parse_args(extra_args_provider=None, defaults={},
if args.fp16_lm_cross_entropy: if args.fp16_lm_cross_entropy:
assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.' assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
if args.fp32_residual_connection: if args.fp32_residual_connection:
assert args.fp16, \ assert args.fp16 or args.bf16, \
'residual connection in fp32 only supported when using fp16.' 'residual connection in fp32 only supported when using fp16 or bf16.'
# Activation checkpointing. # Activation checkpointing.
if args.distribute_checkpointed_activations: if args.distribute_checkpointed_activations:
assert args.checkpoint_activations, \ assert args.checkpoint_activations, \
...@@ -528,6 +543,8 @@ def _add_mixed_precision_args(parser): ...@@ -528,6 +543,8 @@ def _add_mixed_precision_args(parser):
group.add_argument('--fp16', action='store_true', group.add_argument('--fp16', action='store_true',
help='Run model in fp16 mode.') help='Run model in fp16 mode.')
group.add_argument('--bf16', action='store_true',
help='Run model in bfloat16 mode.')
group.add_argument('--loss-scale', type=float, default=None, group.add_argument('--loss-scale', type=float, default=None,
help='Static loss scaling, positive power of 2 ' help='Static loss scaling, positive power of 2 '
'values can improve fp16 convergence. If None, dynamic' 'values can improve fp16 convergence. If None, dynamic'
...@@ -549,8 +566,9 @@ def _add_mixed_precision_args(parser): ...@@ -549,8 +566,9 @@ def _add_mixed_precision_args(parser):
help='Run attention masking and softmax in fp32. ' help='Run attention masking and softmax in fp32. '
'This flag is ignored unless ' 'This flag is ignored unless '
'--no-query-key-layer-scaling is specified.') '--no-query-key-layer-scaling is specified.')
group.add_argument('--fp32-allreduce', action='store_true', group.add_argument('--accumulate-allreduce-grads-in-fp32',
help='All-reduce in fp32') action='store_true',
help='Gradient accumulation and all-reduce in fp32.')
group.add_argument('--fp16-lm-cross-entropy', action='store_true', group.add_argument('--fp16-lm-cross-entropy', action='store_true',
help='Move the cross entropy unreduced loss calculation' help='Move the cross entropy unreduced loss calculation'
'for lm head to fp16.') 'for lm head to fp16.')
...@@ -577,6 +595,9 @@ def _add_distributed_args(parser): ...@@ -577,6 +595,9 @@ def _add_distributed_args(parser):
choices=['local', 'torch'], choices=['local', 'torch'],
help='which DistributedDataParallel implementation ' help='which DistributedDataParallel implementation '
'to use.') 'to use.')
group.add_argument('--use-contiguous-buffers-in-ddp', action='store_true',
help='If set, use contiguous buffer in DDP. Note that '
'this option only works woth local DDP.' )
group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false', group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
help='Use scatter/gather to optimize communication of tensors in pipeline', help='Use scatter/gather to optimize communication of tensors in pipeline',
dest='scatter_gather_tensors_in_pipeline') dest='scatter_gather_tensors_in_pipeline')
......
...@@ -16,11 +16,13 @@ ...@@ -16,11 +16,13 @@
_LAYER_NORM = None _LAYER_NORM = None
def import_layernorm(fp32_residual_connection): def import_layernorm(fp32_residual_connection, bf16):
global _LAYER_NORM global _LAYER_NORM
if not _LAYER_NORM: if not _LAYER_NORM:
if fp32_residual_connection: if bf16:
from torch.nn import LayerNorm
elif fp32_residual_connection:
from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
else: else:
from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
...@@ -39,6 +41,6 @@ from .gpt_model import (GPTModel, ...@@ -39,6 +41,6 @@ from .gpt_model import (GPTModel,
GPTModelIntermediateStage, GPTModelIntermediateStage,
GPTModelLastStage) GPTModelLastStage)
from .language_model import get_language_model from .language_model import get_language_model
from .module import FP16Module from .module import Float16Module
...@@ -78,7 +78,7 @@ class BertLMHead(MegatronModule): ...@@ -78,7 +78,7 @@ class BertLMHead(MegatronModule):
self.parallel_output = parallel_output self.parallel_output = parallel_output
self.dense = get_linear_layer(hidden_size, hidden_size, init_method) self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
LayerNorm = import_layernorm(args.fp32_residual_connection) LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)
self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
self.gelu = torch.nn.functional.gelu self.gelu = torch.nn.functional.gelu
if args.openai_gelu: if args.openai_gelu:
......
...@@ -13,100 +13,206 @@ ...@@ -13,100 +13,206 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from abc import ABC
from abc import abstractmethod
import torch import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
import torch.distributed as dist
from torch.nn.modules import Module
from torch.autograd import Variable
from megatron import get_args
from megatron import mpu from megatron import mpu
from .module import MegatronModule from .module import MegatronModule
class DistributedDataParallel(MegatronModule):
def __init__(self, module): class MemoryBuffer:
super(DistributedDataParallel, self).__init__()
self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False def __init__(self, numel, dtype):
self.numel = numel
self.dtype = dtype
self.data = torch.zeros(self.numel,
dtype=self.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
def zero(self):
"""Reset the buffer to zero."""
self.data.zero_()
def get(self, shape, start_index):
"""Return a tensor with the input `shape` as a view into the
1-D data starting at `start_index`."""
end_index = start_index + shape.numel()
assert end_index <= self.numel, \
'requested tensor is out of the buffer range.'
buffer_tensor = self.data[start_index:end_index]
buffer_tensor = buffer_tensor.view(shape)
return buffer_tensor
class DistributedDataParallelBase(MegatronModule, ABC):
"""Abstract class for DDP."""
def __init__(self, module):
super(DistributedDataParallelBase, self).__init__()
# Keep a pointer to the model.
self.module = module self.module = module
self.data_parallel_group = mpu.get_data_parallel_group()
def allreduce_params(reduce_after=True, no_scale=False, fp32_allreduce=False): @abstractmethod
if(self.needs_reduction): def allreduce_gradients(self):
self.needs_reduction = False pass
buckets = {}
for name, param in self.module.named_parameters():
if param.requires_grad and param.grad is not None:
tp = (param.data.type())
if tp not in buckets:
buckets[tp] = []
buckets[tp].append(param)
if self.warn_on_half:
if torch.cuda.HalfTensor in buckets:
print("WARNING: gloo dist backend for half parameters may be extremely slow." +
" It is recommended to use the NCCL backend in this case.")
self.warn_on_half = False
for tp in buckets:
bucket = buckets[tp]
grads = [param.grad.data for param in bucket]
coalesced = _flatten_dense_tensors(grads)
if fp32_allreduce:
coalesced = coalesced.float()
if not no_scale and not reduce_after:
coalesced /= dist.get_world_size(group=self.data_parallel_group)
dist.all_reduce(coalesced, group=self.data_parallel_group)
torch.cuda.synchronize()
if not no_scale and reduce_after:
coalesced /= dist.get_world_size(group=self.data_parallel_group)
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
self.hook_handles = []
self.hooks = []
for param in list(self.module.parameters()):
def allreduce_hook(*unused):
Variable._execution_engine.queue_callback(allreduce_params)
# handle = param.register_hook(allreduce_hook)
# self.hooks.append(allreduce_hook)
# self.hook_handles.append(handle)
self.allreduce_params = allreduce_params
def forward(self, *inputs, **kwargs): def forward(self, *inputs, **kwargs):
self.needs_reduction = True
return self.module(*inputs, **kwargs) return self.module(*inputs, **kwargs)
def state_dict(self, destination=None, prefix='', keep_vars=False): def state_dict(self, destination=None, prefix='', keep_vars=False):
#[h.remove() for h in self.hook_handles] return self.module.state_dict(destination, prefix, keep_vars)
sd = self.module.state_dict(destination, prefix, keep_vars)
# for handle, hook in zip(self.hook_handles, self.hooks):
# d = handle.hooks_dict_ref()
# d[handle.id] = hook
return sd
def state_dict_for_save_checkpoint(self, destination=None, prefix='', def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False): keep_vars=False):
return self.module.state_dict_for_save_checkpoint(destination, prefix, return self.module.state_dict_for_save_checkpoint(destination, prefix,
keep_vars) keep_vars)
def load_state_dict(self, state_dict, strict=True): def load_state_dict(self, state_dict, strict=True):
self.module.load_state_dict(state_dict, strict=strict) self.module.load_state_dict(state_dict, strict=strict)
'''
def _sync_buffers(self):
buffers = list(self.module._all_buffers()) class DistributedDataParallel(DistributedDataParallelBase):
if len(buffers) > 0: """DDP with contiguous buffers options to storre and accumulate gradients.
# cross-node buffer sync This class:
flat_buffers = _flatten_dense_tensors(buffers) - has the potential to reduce memory fragmentation.
dist.broadcast(flat_buffers, 0) - provides the option to do the gradient accumulation
for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)): in a type other than the params type (for example fp32)
buf.copy_(synced)
def train(self, mode=True): Arguments:
# Clear NCCL communicator and CUDA event cache of the default group ID, module: input model.
# These cache will be recreated at the later call. This is currently a accumulate_allreduce_grads_in_fp32: if true do the gradient accumulation
# work-around for a potential NCCL deadlock. and the gradient all-reduce all in in float32. If this option is
if dist._backend == dist.dist_backend.NCCL: true, we require `use_contiguous_buffers` to be true too.
dist._clear_group_cache() use_contiguous_buffers: if true, use a contiguous buffer to store the
super(DistributedDataParallel, self).train(mode) gradients.
self.module.train(mode) """
'''
def __init__(self, module,
accumulate_allreduce_grads_in_fp32,
use_contiguous_buffers):
super(DistributedDataParallel, self).__init__(module)
self.accumulate_allreduce_grads_in_fp32 \
= accumulate_allreduce_grads_in_fp32
self.use_contiguous_buffers = use_contiguous_buffers
# If we are using fp32-accumulate-allreduce explicitly
# this means we need main grads in a continous buffer.
if self.accumulate_allreduce_grads_in_fp32:
assert self.use_contiguous_buffers
# ===================================
# Rest of this part applies only to
# the case we use continuous buffers.
# ===================================
self._grad_buffers = None
if self.use_contiguous_buffers:
self._grad_buffers = {}
# Simple function to define buffer type.
def _get_buffer_type(param):
return torch.float if \
self.accumulate_allreduce_grads_in_fp32 else param.dtype
# First calculate total number of elements per type.
type_num_elements = {}
for param in self.module.parameters():
if param.requires_grad:
dtype = _get_buffer_type(param)
type_num_elements[dtype] = type_num_elements.get(dtype, 0) \
+ param.data.nelement()
# Allocate the buffer.
for dtype, num_elements in type_num_elements.items():
self._grad_buffers[dtype] = MemoryBuffer(num_elements, dtype)
# Assume the back prop order is reverse the params order,
# store the start index for the gradients.
for param in self.module.parameters():
if param.requires_grad:
dtype = _get_buffer_type(param)
type_num_elements[dtype] -= param.data.nelement()
param.main_grad = self._grad_buffers[dtype].get(
param.data.shape, type_num_elements[dtype])
# Backward hook.
# Accumalation function for the gradients. We need
# to store them so they don't go out of scope.
self.grad_accs = []
# Loop over all the parameters in the model.
for param in self.module.parameters():
if param.requires_grad:
# Expand so we get access to grad_fn.
param_tmp = param.expand_as(param)
# Get the gradient accumulator functtion.
grad_acc = param_tmp.grad_fn.next_functions[0][0]
grad_acc.register_hook(self._make_param_hook(param))
self.grad_accs.append(grad_acc)
def _make_param_hook(self, param):
"""Create the all-reduce hook for backprop."""
# Hook used for back-prop.
def param_hook(*unused):
# Add the gradient to the buffer.
if param.grad.data is not None:
param.main_grad.add_(param.grad.data)
# Now we can deallocate grad memory.
param.grad = None
return param_hook
def zero_grad_buffer(self):
"""Set the grad buffer data to zero. Needs to be called at the
begining of each iteration."""
assert self._grad_buffers is not None, 'buffers are not initialized.'
for _, buffer_ in self._grad_buffers.items():
buffer_.zero()
def allreduce_gradients(self):
"""Reduce gradients across data parallel ranks."""
# If we have buffers, simply reduce the data in the buffer.
if self._grad_buffers is not None:
for _, buffer_ in self._grad_buffers.items():
buffer_.data /= mpu.get_data_parallel_world_size()
torch.distributed.all_reduce(
buffer_.data, group=mpu.get_data_parallel_group())
else:
# Otherwise, bucketize and all-reduce
buckets = {}
# Pack the buckets.
for param in self.module.parameters():
if param.requires_grad and param.grad is not None:
tp = param.data.type()
if tp not in buckets:
buckets[tp] = []
buckets[tp].append(param)
param.main_grad = param.grad
# For each bucket, all-reduce and copy all-reduced grads.
for tp in buckets:
bucket = buckets[tp]
grads = [param.grad.data for param in bucket]
coalesced = _flatten_dense_tensors(grads)
coalesced /= mpu.get_data_parallel_world_size()
torch.distributed.all_reduce(
coalesced, group=mpu.get_data_parallel_group())
for buf, synced in zip(grads, _unflatten_dense_tensors(
coalesced, grads)):
buf.copy_(synced)
...@@ -25,6 +25,7 @@ from megatron import mpu ...@@ -25,6 +25,7 @@ from megatron import mpu
_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor) _FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor) _HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor)
...@@ -109,6 +110,7 @@ class MegatronModule(torch.nn.Module): ...@@ -109,6 +110,7 @@ class MegatronModule(torch.nn.Module):
"this needs to be handled manually. If you are training " "this needs to be handled manually. If you are training "
"something is definitely wrong.") "something is definitely wrong.")
def conversion_helper(val, conversion): def conversion_helper(val, conversion):
"""Apply conversion to val. Recursively apply conversion if `val` """Apply conversion to val. Recursively apply conversion if `val`
#is a nested tuple/list structure.""" #is a nested tuple/list structure."""
...@@ -120,44 +122,56 @@ def conversion_helper(val, conversion): ...@@ -120,44 +122,56 @@ def conversion_helper(val, conversion):
return rtn return rtn
def fp32_to_fp16(val): def fp32_to_float16(val, float16_convertor):
"""Convert fp32 `val` to fp16""" """Convert fp32 `val` to fp16/bf16"""
def half_conversion(val): def half_conversion(val):
val_typecheck = val val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)): if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data val_typecheck = val.data
if isinstance(val_typecheck, _FLOAT_TYPES): if isinstance(val_typecheck, _FLOAT_TYPES):
val = val.half() val = float16_convertor(val)
return val return val
return conversion_helper(val, half_conversion) return conversion_helper(val, half_conversion)
def fp16_to_fp32(val): def float16_to_fp32(val):
"""Convert fp16 `val` to fp32""" """Convert fp16/bf16 `val` to fp32"""
def float_conversion(val): def float_conversion(val):
val_typecheck = val val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)): if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data val_typecheck = val.data
if isinstance(val_typecheck, _HALF_TYPES): if isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)):
val = val.float() val = val.float()
return val return val
return conversion_helper(val, float_conversion) return conversion_helper(val, float_conversion)
class FP16Module(MegatronModule): class Float16Module(MegatronModule):
def __init__(self, module, args):
super(Float16Module, self).__init__()
if args.fp16:
self.add_module('module', module.half())
def float16_convertor(val):
return val.half()
elif args.bf16:
self.add_module('module', module.bfloat16())
def float16_convertor(val):
return val.bfloat16()
else:
raise Exception('should not be here')
def __init__(self, module): self.float16_convertor = float16_convertor
super(FP16Module, self).__init__()
self.add_module('module', module.half())
def forward(self, *inputs, **kwargs): def forward(self, *inputs, **kwargs):
if mpu.is_pipeline_first_stage(): if mpu.is_pipeline_first_stage():
inputs = fp32_to_fp16(inputs) inputs = fp32_to_float16(inputs, self.float16_convertor)
outputs = self.module(*inputs, **kwargs) outputs = self.module(*inputs, **kwargs)
if mpu.is_pipeline_last_stage(): if mpu.is_pipeline_last_stage():
outputs = fp16_to_fp32(outputs) outputs = float16_to_fp32(outputs)
return outputs return outputs
......
...@@ -397,8 +397,11 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -397,8 +397,11 @@ class ParallelTransformerLayer(MegatronModule):
self.apply_residual_connection_post_layernorm \ self.apply_residual_connection_post_layernorm \
= args.apply_residual_connection_post_layernorm = args.apply_residual_connection_post_layernorm
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
# Layernorm on the input data. # Layernorm on the input data.
LayerNorm = import_layernorm(args.fp32_residual_connection) LayerNorm = import_layernorm(self.fp32_residual_connection, self.bf16)
self.input_layernorm = LayerNorm( self.input_layernorm = LayerNorm(
args.hidden_size, args.hidden_size,
eps=args.layernorm_epsilon) eps=args.layernorm_epsilon)
...@@ -440,6 +443,8 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -440,6 +443,8 @@ class ParallelTransformerLayer(MegatronModule):
# Layer norm at the beginning of the transformer layer. # Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states) layernorm_output = self.input_layernorm(hidden_states)
if self.bf16 and self.fp32_residual_connection:
layernorm_output = layernorm_output.bfloat16()
# Self attention. # Self attention.
attention_output, attention_bias = \ attention_output, attention_bias = \
self.self_attention(layernorm_output, self.self_attention(layernorm_output,
...@@ -478,6 +483,8 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -478,6 +483,8 @@ class ParallelTransformerLayer(MegatronModule):
# Layer norm post the self attention. # Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input) layernorm_output = self.post_attention_layernorm(layernorm_input)
if self.bf16 and self.fp32_residual_connection:
layernorm_output = layernorm_output.bfloat16()
if self.layer_type == LayerType.decoder: if self.layer_type == LayerType.decoder:
attention_output, attention_bias = \ attention_output, attention_bias = \
...@@ -500,6 +507,8 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -500,6 +507,8 @@ class ParallelTransformerLayer(MegatronModule):
# Layer norm post the decoder attention # Layer norm post the decoder attention
layernorm_output = self.post_inter_attention_layernorm(layernorm_input) layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
if self.bf16 and self.fp32_residual_connection:
layernorm_output = layernorm_output.bfloat16()
# MLP. # MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output) mlp_output, mlp_bias = self.mlp(layernorm_output)
...@@ -533,6 +542,7 @@ class ParallelTransformer(MegatronModule): ...@@ -533,6 +542,7 @@ class ParallelTransformer(MegatronModule):
super(ParallelTransformer, self).__init__() super(ParallelTransformer, self).__init__()
args = get_args() args = get_args()
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection self.fp32_residual_connection = args.fp32_residual_connection
# Store activation checkpoiting flag. # Store activation checkpoiting flag.
...@@ -578,7 +588,8 @@ class ParallelTransformer(MegatronModule): ...@@ -578,7 +588,8 @@ class ParallelTransformer(MegatronModule):
if mpu.is_pipeline_last_stage(): if mpu.is_pipeline_last_stage():
# Final layer norm before output. # Final layer norm before output.
LayerNorm = import_layernorm(args.fp32_residual_connection) LayerNorm = import_layernorm(self.fp32_residual_connection,
self.bf16)
self.final_layernorm = LayerNorm( self.final_layernorm = LayerNorm(
args.hidden_size, args.hidden_size,
eps=args.layernorm_epsilon) eps=args.layernorm_epsilon)
...@@ -665,6 +676,8 @@ class ParallelTransformer(MegatronModule): ...@@ -665,6 +676,8 @@ class ParallelTransformer(MegatronModule):
# Reverting data format change [s b h] --> [b s h]. # Reverting data format change [s b h] --> [b s h].
hidden_states = hidden_states.transpose(0, 1).contiguous() hidden_states = hidden_states.transpose(0, 1).contiguous()
output = self.final_layernorm(hidden_states) output = self.final_layernorm(hidden_states)
if self.bf16 and self.fp32_residual_connection:
output = output.bfloat16()
else: else:
output = hidden_states output = hidden_states
if get_key_value: if get_key_value:
......
...@@ -20,7 +20,7 @@ from megatron import get_args ...@@ -20,7 +20,7 @@ from megatron import get_args
from megatron.model import import_layernorm from megatron.model import import_layernorm
from .grad_scaler import ConstantGradScaler, DynamicGradScaler from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import FP16OptimizerWithFP16Params, FP32Optimizer from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
def _get_params_for_weight_decay_optimization(modules): def _get_params_for_weight_decay_optimization(modules):
...@@ -28,7 +28,7 @@ def _get_params_for_weight_decay_optimization(modules): ...@@ -28,7 +28,7 @@ def _get_params_for_weight_decay_optimization(modules):
Layernorms and baises will have no weight decay but the rest will. Layernorms and baises will have no weight decay but the rest will.
""" """
args = get_args() args = get_args()
LayerNorm = import_layernorm(args.fp32_residual_connection) LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)
weight_decay_params = {'params': []} weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0} no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
...@@ -67,24 +67,45 @@ def get_megatron_optimizer(model): ...@@ -67,24 +67,45 @@ def get_megatron_optimizer(model):
momentum=args.sgd_momentum) momentum=args.sgd_momentum)
else: else:
raise Exception('{} optimizer is not supported.'.format( raise Exception('{} optimizer is not supported.'.format(
args.optimizer)) args.optimizer))
if args.fp16: # Determine whether the params have main-grad field.
params_have_main_grad = False
if args.DDP_impl == 'local':
params_have_main_grad = True
if args.fp16 or args.bf16:
# Grad scaler:
# if loss-scale is provided, instantiate the constant scaler.
# if we are using fp16 and loss-scale is not present, use a
# dynamic scaler.
# otherwise we are running in bf16 with no loss-scale so
# leave it as None.
grad_scaler = None
# Constant loss scale. # Constant loss scale.
if args.loss_scale: if args.loss_scale:
grad_scaler = ConstantGradScaler(args.loss_scale) grad_scaler = ConstantGradScaler(args.loss_scale)
# Dynamic loss scale. # Dynamic loss scale.
else: else:
grad_scaler = DynamicGradScaler( if args.fp16:
initial_scale=args.initial_loss_scale, grad_scaler = DynamicGradScaler(
min_scale=args.min_loss_scale, initial_scale=args.initial_loss_scale,
growth_factor=2.0, min_scale=args.min_loss_scale,
backoff_factor=0.5, growth_factor=2.0,
growth_interval=args.loss_scale_window, backoff_factor=0.5,
hysteresis=args.hysteresis) growth_interval=args.loss_scale_window,
hysteresis=args.hysteresis)
# Megatron optimizer. # Megatron optimizer.
return FP16OptimizerWithFP16Params(optimizer, grad_scaler, return Float16OptimizerWithFloat16Params(optimizer,
args.clip_grad, args.log_num_zeros_in_grad) args.clip_grad,
args.log_num_zeros_in_grad,
params_have_main_grad,
args.bf16,
grad_scaler)
# FP32. # FP32.
return FP32Optimizer(optimizer, args.clip_grad, args.log_num_zeros_in_grad) return FP32Optimizer(optimizer, args.clip_grad,
args.log_num_zeros_in_grad,
params_have_main_grad)
This diff is collapsed.
...@@ -37,9 +37,8 @@ from megatron import print_rank_0 ...@@ -37,9 +37,8 @@ from megatron import print_rank_0
from megatron import print_rank_last from megatron import print_rank_last
from megatron.checkpointing import load_checkpoint from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint from megatron.checkpointing import save_checkpoint
from megatron.model import FP16Module from megatron.model import Float16Module
from megatron.optimizer import get_megatron_optimizer from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard from megatron.initialize import write_args_to_tensorboard
from megatron.learning_rates import AnnealingLR from megatron.learning_rates import AnnealingLR
...@@ -54,6 +53,7 @@ from megatron.schedules import forward_backward_pipelining_with_interleaving ...@@ -54,6 +53,7 @@ from megatron.schedules import forward_backward_pipelining_with_interleaving
from megatron.utils import report_memory from megatron.utils import report_memory
def print_datetime(string): def print_datetime(string):
"""Note that this call will sync across all ranks.""" """Note that this call will sync across all ranks."""
torch.distributed.barrier() torch.distributed.barrier()
...@@ -222,8 +222,18 @@ def get_model(model_provider_func): ...@@ -222,8 +222,18 @@ def get_model(model_provider_func):
model_module.cuda(torch.cuda.current_device()) model_module.cuda(torch.cuda.current_device())
# Fp16 conversion. # Fp16 conversion.
if args.fp16: if args.fp16 or args.bf16:
model = [FP16Module(model_module) for model_module in model] model = [Float16Module(model_module, args) for model_module in model]
# For now, the layer norm does not support input float32 and outut bf16.
# For this, we move layernorm parameters to fp32 and cast output of the
# layernorm operation back to bf16.
if args.bf16 and args.fp32_residual_connection:
from megatron.model import import_layernorm
LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)
for model_ in model:
for module_ in model_.modules():
if isinstance(module_, LayerNorm):
module_.float()
if args.DDP_impl == 'torch': if args.DDP_impl == 'torch':
i = torch.cuda.current_device() i = torch.cuda.current_device()
...@@ -231,8 +241,12 @@ def get_model(model_provider_func): ...@@ -231,8 +241,12 @@ def get_model(model_provider_func):
process_group=mpu.get_data_parallel_group()) process_group=mpu.get_data_parallel_group())
for model_module in model] for model_module in model]
return model return model
if args.DDP_impl == 'local': if args.DDP_impl == 'local':
model = [LocalDDP(model_module) for model_module in model] model = [LocalDDP(model_module,
args.accumulate_allreduce_grads_in_fp32,
args.use_contiguous_buffers_in_ddp)
for model_module in model]
return model return model
raise NotImplementedError('Unknown DDP implementation specified: {}. ' raise NotImplementedError('Unknown DDP implementation specified: {}. '
...@@ -289,7 +303,7 @@ def setup_model_and_optimizer(model_provider_func): ...@@ -289,7 +303,7 @@ def setup_model_and_optimizer(model_provider_func):
model = get_model(model_provider_func) model = get_model(model_provider_func)
unwrapped_model = unwrap_model(model, unwrapped_model = unwrap_model(model,
(torchDDP, LocalDDP, FP16Module)) (torchDDP, LocalDDP, Float16Module))
optimizer = get_megatron_optimizer(unwrapped_model) optimizer = get_megatron_optimizer(unwrapped_model)
lr_scheduler = get_learning_rate_scheduler(optimizer) lr_scheduler = get_learning_rate_scheduler(optimizer)
...@@ -308,9 +322,7 @@ def setup_model_and_optimizer(model_provider_func): ...@@ -308,9 +322,7 @@ def setup_model_and_optimizer(model_provider_func):
args.iteration = 0 args.iteration = 0
# We only support local DDP with multiple micro-batches. # We only support local DDP with multiple micro-batches.
if len(model) > 1: if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
assert args.DDP_impl == 'local'
if mpu.get_pipeline_model_parallel_world_size() > 1:
assert args.DDP_impl == 'local' assert args.DDP_impl == 'local'
# get model without FP16 and/or TorchDDP wrappers # get model without FP16 and/or TorchDDP wrappers
...@@ -331,7 +343,11 @@ def train_step(forward_step_func, data_iterator, ...@@ -331,7 +343,11 @@ def train_step(forward_step_func, data_iterator,
timers = get_timers() timers = get_timers()
# Set grad to zero. # Set grad to zero.
optimizer.zero_grad() if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_ddp:
for partition in model:
partition.zero_grad_buffer()
else:
optimizer.zero_grad()
if mpu.get_pipeline_model_parallel_world_size() > 1: if mpu.get_pipeline_model_parallel_world_size() > 1:
if args.virtual_pipeline_model_parallel_size is not None: if args.virtual_pipeline_model_parallel_size is not None:
...@@ -351,8 +367,7 @@ def train_step(forward_step_func, data_iterator, ...@@ -351,8 +367,7 @@ def train_step(forward_step_func, data_iterator,
if args.DDP_impl == 'local': if args.DDP_impl == 'local':
timers('backward-params-all-reduce').start() timers('backward-params-all-reduce').start()
for model_module in model: for model_module in model:
model_module.allreduce_params(reduce_after=False, model_module.allreduce_gradients()
fp32_allreduce=args.fp32_allreduce)
timers('backward-params-all-reduce').stop() timers('backward-params-all-reduce').stop()
# All-reduce word_embeddings' grad across first and last stages to ensure # All-reduce word_embeddings' grad across first and last stages to ensure
...@@ -368,12 +383,15 @@ def train_step(forward_step_func, data_iterator, ...@@ -368,12 +383,15 @@ def train_step(forward_step_func, data_iterator,
elif mpu.is_pipeline_last_stage(ignore_virtual=True): elif mpu.is_pipeline_last_stage(ignore_virtual=True):
unwrapped_model = model[-1] unwrapped_model = model[-1]
unwrapped_model = unwrap_model( unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, FP16Module)) unwrapped_model, (torchDDP, LocalDDP, Float16Module))
if unwrapped_model.share_word_embeddings: if unwrapped_model.share_word_embeddings:
word_embeddings_weight = unwrapped_model.word_embeddings_weight() word_embeddings_weight = unwrapped_model.word_embeddings_weight()
torch.distributed.all_reduce(word_embeddings_weight.grad, if args.DDP_impl == 'local':
group=mpu.get_embedding_group()) grad = word_embeddings_weight.main_grad
else:
grad = word_embeddings_weight.grad
torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
timers('backward-embedding-all-reduce').stop() timers('backward-embedding-all-reduce').stop()
# Update parameters. # Update parameters.
......
...@@ -48,13 +48,20 @@ def unwrap_model(model, module_instances=(torchDDP)): ...@@ -48,13 +48,20 @@ def unwrap_model(model, module_instances=(torchDDP)):
def calc_params_l2_norm(model): def calc_params_l2_norm(model):
"""Calculate l2 norm of parameters """ """Calculate l2 norm of parameters """
args = get_args()
if not isinstance(model, list):
model = [model]
# Remove duplicate params. # Remove duplicate params.
params_data = [] params_data = []
for param in model.parameters(): for model_ in model:
is_not_shared = param_is_not_shared(param) for param in model_.parameters():
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) is_not_shared = param_is_not_shared(param)
if is_not_shared and is_not_tp_duplicate: is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
params_data.append(param.data) if is_not_shared and is_not_tp_duplicate:
if args.bf16:
params_data.append(param.data.float())
else:
params_data.append(param.data)
# Calculate norm # Calculate norm
dummy_overflow_buf = torch.cuda.IntTensor([0]) dummy_overflow_buf = torch.cuda.IntTensor([0])
norm, _ = multi_tensor_applier( norm, _ = multi_tensor_applier(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment