"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "de71fa59f56c89499fa2404a50c672809a9d3b0c"
Commit cfd2e216 authored by Vijay Korthikanti

address review comments

parent 13b3dca6
@@ -103,12 +103,14 @@ def parse_args(extra_args_provider=None, defaults={},
     assert args.model_parallel_size is None, '--model-parallel-size is no ' \
         'longer valid, use --tensor-model-parallel-size instead'
     del args.model_parallel_size
     if args.checkpoint_activations:
-        args.activations_checkpoint_method = 'uniform'
+        args.checkpoint_granularity = 'full'
+        args.checkpoint_method = 'uniform'
         if args.rank == 0:
             print('--checkpoint-activations is no longer valid, '
-                  'use --activation-checkpoint-method instead. '
-                  'Defaulting to activation-checkpoint-method=uniform.')
+                  'use --checkpoint-granularity and --checkpoint-method instead. '
+                  'Defaulting to checkpoint-granularity=full and checkpoint-method=uniform.')
     del args.checkpoint_activations

     # Set input defaults.
@@ -283,18 +285,26 @@ def parse_args(extra_args_provider=None, defaults={},
         assert args.tensor_model_parallel_size > 1, 'can distribute ' \
             'checkpointed activations only across tensor model ' \
             'parallel groups'
-        assert args.activations_checkpoint_method is not None, \
+        assert args.checkpoint_granularity == 'full', \
+            'distributed checkpoint activations is only ' \
+            'applicable to full checkpoint granularity'
+        assert args.checkpoint_method is not None, \
             'for distributed checkpoint activations to work you '\
-            'need to use a activation-checkpoint method '
+            'need to use a checkpoint method '
         assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 10, \
             'distributed checkpoint activations are supported for pytorch ' \
             'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
             'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)

+    if args.checkpoint_granularity == 'selective':
+        assert args.checkpoint_method is None, \
+            'checkpoint method is not yet supported for ' \
+            'selective checkpointing granularity'
+
-    # model parallel memory optmization
-    if args.model_parallel_memory_opt:
-        assert not args.async_tensor_model_parallel_allreduce
+    # disable async_tensor_model_parallel_allreduce when
+    # sequence parallelism is enabled
+    if args.sequence_parallel:
+        args.async_tensor_model_parallel_allreduce = False

     _print_args(args)
     return args
@@ -476,30 +486,38 @@ def _add_training_args(parser):
                        ' (1024 - 16) / 8 = 126 intervals will increase'
                        'the batch size linearly to 1024. In each interval'
                        'we will use approximately 300000 / 126 = 2380 samples.')
-    group.add_argument('--checkpoint-activations', action='store_true',
-                       help='Checkpoint activation to allow for training '
-                       'with larger models, sequences, and batch sizes.')
-    group.add_argument('--checkpoint-attention', action='store_true',
-                       help='Checkpoint activation to allow for training '
-                       'with larger models, sequences, and batch sizes.')
+    group.add_argument('--checkpoint-granularity', type=str, default=None,
+                       choices=['full', 'selective'],
+                       help='Checkpoint activations to allow for training '
+                       'with larger models, sequences, and batch sizes. '
+                       'It is supported at two granularities: 1) full: '
+                       'the whole transformer layer is checkpointed and recomputed, '
+                       '2) selective: only the core attention part of the transformer '
+                       'layer is checkpointed and recomputed.')
     group.add_argument('--distribute-checkpointed-activations',
                        action='store_true',
                        help='If set, distribute checkpointed activations '
                        'across model parallel group.')
-    group.add_argument('--activations-checkpoint-method', type=str, default=None,
+    group.add_argument('--checkpoint-method', type=str, default=None,
                        choices=['uniform', 'block'],
                        help='1) uniform: uniformly divide the total number of '
                        'Transformer layers and checkpoint the input activation of '
-                       'each divided chunk, '
+                       'each divided chunk at the specified granularity, '
                        '2) checkpoint the input activations of only a set number of '
                        'individual Transformer layers per pipeline stage and do the '
-                       'rest without any checkpointing'
+                       'rest without any checkpointing at the specified granularity, '
                        'default) do not apply activations checkpoint to any layers')
-    group.add_argument('--activations-checkpoint-num-layers', type=int, default=1,
+    group.add_argument('--checkpoint-num-layers', type=int, default=1,
                        help='1) uniform: the number of Transformer layers in each '
                        'uniformly divided checkpoint unit, '
                        '2) block: the number of individual Transformer layers '
                        'to checkpoint within each pipeline stage.')
+    # deprecated
+    group.add_argument('--checkpoint-activations', action='store_true',
+                       help='Checkpoint activation to allow for training '
+                       'with larger models, sequences, and batch sizes.')
     group.add_argument('--train-iters', type=int, default=None,
                        help='Total number of iterations to train over all '
                        'training runs. Note that either train-iters or '
@@ -548,8 +566,8 @@ def _add_training_args(parser):
                        'This kernel supports only a set of hidden sizes. Please '
                        'check persist_ln_hidden_sizes if your hidden '
                        'size is supported.')
-    group.add_argument('--model-parallel-memory-opt', action='store_true',
-                       help='Enable model parallel memory optmization.')
+    group.add_argument('--sequence-parallel', action='store_true',
+                       help='Enable sequence parallel optimization.')
     group.add_argument('--no-gradient-accumulation-fusion',
                        action='store_false',
                        help='Disable fusing gradient accumulation to weight '
......
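
Reviewer note (not part of the patch): a minimal standalone sketch, using plain `argparse` rather than Megatron's `parse_args`, of how the renamed options relate and how the deprecated `--checkpoint-activations` flag maps onto them. All values are illustrative.

```python
# Standalone illustration of the new checkpointing options and the
# backward-compatibility shim for the deprecated flag.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint-granularity', type=str, default=None,
                    choices=['full', 'selective'])
parser.add_argument('--checkpoint-method', type=str, default=None,
                    choices=['uniform', 'block'])
parser.add_argument('--checkpoint-num-layers', type=int, default=1)
parser.add_argument('--sequence-parallel', action='store_true')
parser.add_argument('--checkpoint-activations', action='store_true')  # deprecated

args = parser.parse_args(['--checkpoint-activations'])

# The deprecated flag becomes full granularity with the uniform method,
# mirroring the compatibility shim in parse_args above.
if args.checkpoint_activations:
    args.checkpoint_granularity = 'full'
    args.checkpoint_method = 'uniform'

assert args.checkpoint_granularity == 'full'
assert args.checkpoint_method == 'uniform'
```

The shim keeps old launch scripts working while steering new ones toward the explicit granularity/method pair.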
@@ -110,8 +110,11 @@ def post_language_model_processing(lm_output, pooled_output,
         binary_logits = binary_head(pooled_output)

     if lm_labels is None:
-        return lm_logits, binary_logits
+        # [s b h] => [b s h]
+        return lm_logits.transpose(0,1).contiguous(), binary_logits
     else:
+        # [b s] => [s b]
+        lm_labels = lm_labels.transpose(0,1).contiguous()
         if fp16_lm_cross_entropy:
             assert lm_logits.dtype == torch.half
             lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
......
@@ -291,7 +291,7 @@ class PretrainedBertModel(MegatronModule):
         pool_mask = (input_ids == self.pad_id).unsqueeze(2)

         # Taking the representation of the [CLS] token of BERT
-        pooled_output = lm_output[:, 0, :]
+        pooled_output = lm_output[0, :, :]

         # Converting to float16 dtype
         pooled_output = pooled_output.to(lm_output.dtype)
......
@@ -32,15 +32,18 @@ def post_language_model_processing(lm_output, labels, logit_weights,
                                    parallel_output,
                                    fp16_lm_cross_entropy):

-    # Output.
+    # Output. Format [s b h]
     output = parallel_lm_logits(
         lm_output,
         logit_weights,
         parallel_output)

     if labels is None:
-        return output
+        # [s b h] => [b s h]
+        return output.transpose(0,1).contiguous()
     else:
+        # [b s] => [s b]
+        labels = labels.transpose(0,1).contiguous()
         if fp16_lm_cross_entropy:
             assert output.dtype == torch.half
             loss = mpu.vocab_parallel_cross_entropy(output, labels)
......
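
Reviewer note (illustration only): with this change the language-model output stays in `[s, b, h]` and only the labels are transposed before the loss. The sketch below uses plain `torch.nn.functional.cross_entropy` as a single-GPU stand-in for `mpu.vocab_parallel_cross_entropy`; shapes are made up.

```python
# Toy check of the new layouts in post_language_model_processing.
import torch
import torch.nn.functional as F

s, b, v = 8, 2, 100            # sequence length, micro-batch size, vocab size
logits = torch.randn(s, b, v)  # language-model output stays in [s, b, vocab]
labels = torch.randint(0, v, (b, s))  # the data pipeline still yields [b, s]

# [b, s] => [s, b], matching the transpose added in the else branch above.
labels_sb = labels.transpose(0, 1).contiguous()
loss = F.cross_entropy(logits.reshape(s * b, v), labels_sb.reshape(s * b))

# The inference path instead transposes the logits back to [b, s, vocab].
logits_bs = logits.transpose(0, 1).contiguous()
assert logits_bs.shape == (b, s, v)
```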
@@ -33,11 +33,11 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     args = get_args()
     # Parallel logits.
     if args.async_tensor_model_parallel_allreduce or\
-       args.model_parallel_memory_opt:
+       args.sequence_parallel:
         input_parallel = input_
         model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
         async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \
-            model_parallel and not args.model_parallel_memory_opt
+            model_parallel and not args.sequence_parallel
     else:
         input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
         async_grad_allreduce = False
@@ -46,7 +46,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     logits_parallel = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply(
         input_parallel, word_embeddings_weight, bias,
         args.gradient_accumulation_fusion,
-        async_grad_allreduce, args.model_parallel_memory_opt)
+        async_grad_allreduce, args.sequence_parallel)

     # Gather if needed.
     if parallel_output:
@@ -107,9 +107,9 @@ class Pooler(MegatronModule):
         self.dense = get_linear_layer(hidden_size, hidden_size, init_method)

     def forward(self, hidden_states, sequence_index=0):
-        # hidden_states: [b, s, h]
+        # hidden_states: [s, b, h]
         # sequence_index: index of the token to pool.
-        pooled = hidden_states[:, sequence_index, :]
+        pooled = hidden_states[sequence_index, :, :]
         pooled = self.dense(pooled)
         pooled = torch.tanh(pooled)
         return pooled
@@ -171,7 +171,7 @@ class Embedding(MegatronModule):
             self.tokentype_embeddings = None

         self.fp32_residual_connection = args.fp32_residual_connection
-        self.model_parallel_memory_opt = args.model_parallel_memory_opt
+        self.sequence_parallel = args.sequence_parallel

         # Embeddings dropout
         self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
@@ -214,18 +214,17 @@ class Embedding(MegatronModule):
             assert self.tokentype_embeddings is None

         # Data format change to avoid explicit transposes : [b s h] --> [s b h].
+        embeddings = embeddings.transpose(0, 1).contiguous()
+
         # If the input flag for fp32 residual connection is set, convert for float.
         if self.fp32_residual_connection:
-            embeddings = embeddings.transpose(0, 1).contiguous().float()
-        # Otherwise, leave it as is.
-        else:
-            embeddings = embeddings.transpose(0, 1).contiguous()
+            embeddings = embeddings.float()

-        if self.model_parallel_memory_opt:
+        if self.sequence_parallel:
             embeddings = mpu.scatter_to_sequence_parallel_region(embeddings)

         # Dropout.
-        if self.model_parallel_memory_opt:
+        if self.sequence_parallel:
             with mpu.get_cuda_rng_tracker().fork():
                 embeddings = self.embedding_dropout(embeddings)
         else:
......
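
Reviewer note (illustration only): the `[b, s, h] -> [s, b, h]` transpose now happens once in `Embedding.forward`, and sequence parallelism then splits the sequence dimension across tensor-parallel ranks. In the single-process sketch below, `torch.chunk` stands in for `mpu.scatter_to_sequence_parallel_region`, and the pooler indexing shows the new layout.

```python
# Single-process sketch of the layout change in Embedding.forward and Pooler.forward.
import torch

s, b, h, tp = 8, 2, 16, 4
embeddings = torch.randn(b, s, h)

# [b, s, h] --> [s, b, h]: done once, right after the embedding lookup.
embeddings = embeddings.transpose(0, 1).contiguous()

# Sequence parallelism gives each of the tp ranks a contiguous s/tp slice.
local_chunk = torch.chunk(embeddings, tp, dim=0)[0]
assert local_chunk.shape == (s // tp, b, h)

# With [s, b, h] the pooler indexes the sequence dimension first.
sequence_index = 0
pooled = embeddings[sequence_index, :, :]   # was hidden_states[:, sequence_index, :]
assert pooled.shape == (b, h)
```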
@@ -157,8 +157,11 @@ class T5Model(MegatronModule):
                 self.word_embeddings_weight())

             if lm_labels is None:
-                return lm_logits
+                # [s b h] => [b s h]
+                return lm_logits.transpose(0,1).contiguous()
             else:
+                # [b s] => [s b]
+                lm_labels = lm_labels.transpose(0,1).contiguous()
                 if self.fp16_lm_cross_entropy:
                     assert lm_logits.dtype == torch.half
                     lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
......
@@ -15,6 +15,7 @@
 """Transformer."""
 import math
+import contextlib
 import torch
 import torch.nn.functional as F
@@ -27,7 +28,6 @@ from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

-_MATMUL_INPUT = None

 """ We use the following notation throughout this file:
     h: hidden size
@@ -167,6 +167,8 @@ class SwitchMLP(MegatronModule):

 class CoreAttention(MegatronModule):

+    matmul_input = None
+
     def __init__(self, layer_number,
                  attn_mask_type=AttnMaskType.padding):
         super(CoreAttention, self).__init__()
@@ -180,7 +182,7 @@ class CoreAttention(MegatronModule):
         self.attention_softmax_in_fp32 = True
         self.layer_number = max(1, layer_number)
         self.attn_mask_type = attn_mask_type
-        self.model_parallel_memory_opt = args.model_parallel_memory_opt
+        self.sequence_parallel = args.sequence_parallel

         projection_size = args.kv_channels * args.num_attention_heads
@@ -193,15 +195,6 @@ class CoreAttention(MegatronModule):
         self.num_attention_heads_per_partition = mpu.divide(
             args.num_attention_heads, world_size)

-        global _MATMUL_INPUT
-        if _MATMUL_INPUT is None:
-            _MATMUL_INPUT = torch.empty(
-                args.micro_batch_size * self.num_attention_heads_per_partition,
-                args.seq_length,
-                args.seq_length,
-                dtype=torch.bfloat16,
-                device=torch.cuda.current_device())
-
         coeff = None
         self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
         if self.apply_query_key_layer_scaling:
@@ -220,7 +213,7 @@ class CoreAttention(MegatronModule):
         # different outputs on different number of parallel partitions but
         # on average it should not be partition dependent.
         self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

     def forward(self, query_layer, key_layer,
                 value_layer, attention_mask):
@@ -241,20 +234,18 @@ class CoreAttention(MegatronModule):
         key_layer = key_layer.view(output_size[3],
                                    output_size[0] * output_size[1], -1)

-        # preallocting result tensor: [b * np, sq, sk]
-        #matmul_result = torch.empty(
-        #    output_size[0]*output_size[1],
-        #    output_size[2],
-        #    output_size[3],
-        #    dtype=query_layer.dtype,
-        #    device=torch.cuda.current_device())
-
-        global _MATMUL_INPUT
-        matmul_input = _MATMUL_INPUT
+        # preallocating input tensor: [b * np, sq, sk]
+        if CoreAttention.matmul_input is None:
+            CoreAttention.matmul_input = torch.empty(
+                output_size[0]*output_size[1],
+                output_size[2],
+                output_size[3],
+                dtype=query_layer.dtype,
+                device=torch.cuda.current_device())

         # Raw attention scores. [b * np, sq, sk]
         matmul_result = torch.baddbmm(
-            matmul_input,
+            CoreAttention.matmul_input,
             query_layer.transpose(0, 1),   # [b * np, sq, hn]
             key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
             beta=0.0, alpha=(1.0/self.norm_factor))
@@ -273,7 +264,7 @@ class CoreAttention(MegatronModule):
         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.

-        if not self.model_parallel_memory_opt:
+        if not self.sequence_parallel:
             with mpu.get_cuda_rng_tracker().fork():
                 attention_probs = self.attention_dropout(attention_probs)
         else:
@@ -334,8 +325,6 @@ class ParallelAttention(MegatronModule):
         self.attention_type = attention_type
         self.attn_mask_type = attn_mask_type
         self.params_dtype = args.params_dtype
-        self.checkpoint_attention = args.checkpoint_attention
-        #assert args.activations_checkpoint_method is None

         projection_size = args.kv_channels * args.num_attention_heads
@@ -369,6 +358,7 @@ class ParallelAttention(MegatronModule):
         self.core_attention = CoreAttention(self.layer_number,
                                             self.attn_mask_type)
+        self.checkpoint_core_attention = args.checkpoint_granularity == 'selective'

         # Output.
         self.dense = mpu.RowParallelLinear(
@@ -491,7 +481,7 @@ class ParallelAttention(MegatronModule):
         # core attention computation
         # ==================================

-        if self.checkpoint_attention:
+        if self.checkpoint_core_attention:
             context_layer = self._checkpointed_attention_forward(
                 query_layer, key_layer, value_layer, attention_mask)
         else:
@@ -564,7 +554,7 @@ class ParallelTransformerLayer(MegatronModule):
             args.hidden_size,
             eps=args.layernorm_epsilon,
             no_persist_layer_norm=args.no_persist_layer_norm,
-            sequence_parallel=args.model_parallel_memory_opt)
+            sequence_parallel=args.sequence_parallel)

         # Self attention.
         self.self_attention = ParallelAttention(
@@ -582,7 +572,7 @@ class ParallelTransformerLayer(MegatronModule):
             args.hidden_size,
             eps=args.layernorm_epsilon,
             no_persist_layer_norm=args.no_persist_layer_norm,
-            sequence_parallel=args.model_parallel_memory_opt)
+            sequence_parallel=args.sequence_parallel)

         if self.layer_type == LayerType.decoder:
             self.inter_attention = ParallelAttention(
@@ -595,7 +585,7 @@ class ParallelTransformerLayer(MegatronModule):
                 args.hidden_size,
                 eps=args.layernorm_epsilon,
                 no_persist_layer_norm=args.no_persist_layer_norm,
-                sequence_parallel=args.model_parallel_memory_opt)
+                sequence_parallel=args.sequence_parallel)

         # MLP
         if args.num_experts is not None:
@@ -747,12 +737,13 @@ class ParallelTransformer(MegatronModule):
         self.drop_path_rate = drop_path_rate

         # Store activation checkpointing flags.
-        self.activations_checkpoint_method = args.activations_checkpoint_method
-        self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
+        self.checkpoint_granularity = args.checkpoint_granularity
+        self.checkpoint_method = args.checkpoint_method
+        self.checkpoint_num_layers = args.checkpoint_num_layers
         self.distribute_checkpointed_activations = \
-            args.distribute_checkpointed_activations and not args.model_parallel_memory_opt
-        self.model_parallel_memory_opt = args.model_parallel_memory_opt
+            args.distribute_checkpointed_activations and not args.sequence_parallel
+        self.sequence_parallel = args.sequence_parallel

         # Number of layers.
         self.num_layers = mpu.get_num_layers(
@@ -822,7 +813,7 @@ class ParallelTransformer(MegatronModule):
                 args.hidden_size,
                 eps=args.layernorm_epsilon,
                 no_persist_layer_norm=args.no_persist_layer_norm,
-                sequence_parallel=args.model_parallel_memory_opt)
+                sequence_parallel=args.sequence_parallel)

     def _get_layer(self, layer_number):
         return self.layers[layer_number]
@@ -842,24 +833,24 @@ class ParallelTransformer(MegatronModule):
                 return x_
             return custom_forward

-        if self.activations_checkpoint_method == 'uniform':
+        if self.checkpoint_method == 'uniform':
             # Uniformly divide the total number of Transformer layers and checkpoint
             # the input activation of each divided chunk.
             # A method to further reduce memory usage reducing checkpoints.
             l = 0
             while l < self.num_layers:
                 hidden_states = mpu.checkpoint(
-                    custom(l, l + self.activations_checkpoint_num_layers),
+                    custom(l, l + self.checkpoint_num_layers),
                     self.distribute_checkpointed_activations,
                     hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
-                l += self.activations_checkpoint_num_layers
-        elif self.activations_checkpoint_method == 'block':
+                l += self.checkpoint_num_layers
+        elif self.checkpoint_method == 'block':
             # Checkpoint the input activation of only a set number of individual
             # Transformer layers and skip the rest.
             # A method fully use the device memory removing redundant re-computation.
             for l in range(self.num_layers):
-                if l < self.activations_checkpoint_num_layers:
+                if l < self.checkpoint_num_layers:
                     hidden_states = mpu.checkpoint(
                         custom(l, l + 1),
                         self.distribute_checkpointed_activations,
@@ -887,7 +878,7 @@ class ParallelTransformer(MegatronModule):
                 inference_params=None):

         # Checks.
         if inference_params:
-            assert self.activations_checkpoint_method is None, \
+            assert self.checkpoint_granularity is None, \
                 'inference does not work with activation checkpointing'

         if not self.pre_process:
@@ -915,28 +906,14 @@ class ParallelTransformer(MegatronModule):
                     keep_graph=True,
                 )

-        if self.model_parallel_memory_opt:
-            with mpu.get_cuda_rng_tracker().fork():
-                # Forward pass.
-                if self.activations_checkpoint_method is not None:
-                    hidden_states = self._checkpointed_forward(hidden_states,
-                                                               attention_mask,
-                                                               encoder_output,
-                                                               enc_dec_attn_mask)
-                else:
-                    total = 0
-                    for index in range(self.num_layers):
-                        layer = self._get_layer(index)
-                        hidden_states = layer(
-                            hidden_states,
-                            attention_mask,
-                            encoder_output=encoder_output,
-                            enc_dec_attn_mask=enc_dec_attn_mask,
-                            inference_params=inference_params)
+        if self.sequence_parallel:
+            rng_context = mpu.get_cuda_rng_tracker().fork()
         else:
+            rng_context = contextlib.nullcontext()
+
+        with rng_context:
             # Forward pass.
-            if self.activations_checkpoint_method is not None:
+            if self.checkpoint_granularity == 'full':
                 hidden_states = self._checkpointed_forward(hidden_states,
                                                            attention_mask,
                                                            encoder_output,
......
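
Reviewer note (illustration only): a toy contrast of the two granularities using `torch.utils.checkpoint` directly; the real path goes through `mpu.checkpoint`, which also tracks the parallel RNG state. `ToyLayer` and all sizes are made up for the example.

```python
# 'full' recomputes the whole layer in backward; 'selective' recomputes only the
# core attention and keeps everything else stored.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyLayer(nn.Module):
    def __init__(self, h):
        super().__init__()
        self.attn = nn.MultiheadAttention(h, num_heads=2)
        self.mlp = nn.Sequential(nn.Linear(h, 4 * h), nn.GELU(), nn.Linear(4 * h, h))

    def forward(self, x, selective=False):
        if selective:
            # selective: only the core attention is checkpointed.
            attn_out = checkpoint(lambda q: self.attn(q, q, q)[0], x)
        else:
            attn_out = self.attn(x, x, x)[0]
        return x + attn_out + self.mlp(x + attn_out)


layer = ToyLayer(h=16)
x = torch.randn(8, 2, 16, requires_grad=True)   # [s, b, h]

y_full = checkpoint(layer, x)        # full: whole layer recomputed in backward
y_sel = layer(x, selective=True)     # selective: only attention recomputed

(y_full.sum() + y_sel.sum()).backward()
```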
@@ -45,9 +45,6 @@ _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
                                       'partition_dim': -1,
                                       'partition_stride': 1}

-_TOTAL_INPUT = None
-_SUB_GRAD_INPUT = None
-

 def param_is_not_tensor_parallel_duplicate(param):
     return (hasattr(param, 'tensor_model_parallel') and
             param.tensor_model_parallel) or (
@@ -208,28 +205,32 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
    Linear layer execution with asynchronous communication and gradient accumulation
    fusion in backprop.
    """

+    all_gather_buffer = None
+
     @staticmethod
     def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
-                async_grad_allreduce, model_parallel_memory_opt):
+                async_grad_allreduce, sequence_parallel):
         ctx.save_for_backward(input, weight)
         ctx.use_bias = bias is not None
         ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
         ctx.async_grad_allreduce = async_grad_allreduce
-        ctx.model_parallel_memory_opt = model_parallel_memory_opt
+        ctx.sequence_parallel = sequence_parallel

-        if model_parallel_memory_opt:
+        if sequence_parallel:
             world_size = get_tensor_model_parallel_world_size()
             dim_size = list(input.size())
             dim_size[0] = dim_size[0] * world_size

-            #total_input = torch.empty(dim_size, dtype=input.dtype,
-            #                          device=torch.cuda.current_device(),
-            #                          requires_grad=False)
-            global _TOTAL_INPUT
-            total_input = _TOTAL_INPUT
-            torch.distributed._all_gather_base(total_input, input,
-                group=get_tensor_model_parallel_group())
+            if LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer is None:
+                LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer = \
+                    torch.empty(dim_size, dtype=input.dtype,
+                                device=torch.cuda.current_device(),
+                                requires_grad=False)
+            torch.distributed._all_gather_base(
+                LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer,
+                input,
+                group=get_tensor_model_parallel_group())
+            total_input = LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer
         else:
             total_input = input
@@ -244,27 +245,25 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
         input, weight = ctx.saved_tensors
         use_bias = ctx.use_bias

-        if ctx.model_parallel_memory_opt:
+        if ctx.sequence_parallel:
             world_size = get_tensor_model_parallel_world_size()
             dim_size = list(input.size())
             dim_size[0] = dim_size[0] * world_size

-            #total_input = torch.empty(dim_size, dtype=input.dtype,
-            #                          device=torch.cuda.current_device(),
-            #                          requires_grad=False)
-            global _TOTAL_INPUT
-            total_input = _TOTAL_INPUT
-            handle = torch.distributed._all_gather_base(total_input, input,
-                group=get_tensor_model_parallel_group(), async_op=True)
+            handle = torch.distributed._all_gather_base(
+                LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer,
+                input,
+                group=get_tensor_model_parallel_group(), async_op=True)

             # Delay the start of input gradient computation shortly (3us) to have
             # gather scheduled first and have GPU resources allocated
             _ = torch.empty(1, device=grad_output.device) + 1
+            total_input = LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer
         else:
             total_input = input

         grad_input = grad_output.matmul(weight)

-        if ctx.model_parallel_memory_opt:
+        if ctx.sequence_parallel:
             handle.wait()

         # Convert the tensor shapes to 2D for execution compatibility
@@ -281,7 +280,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
         # all-reduce scheduled first and have GPU resources allocated
         _ = torch.empty(1, device=grad_output.device) + 1

-        if ctx.model_parallel_memory_opt:
+        if ctx.sequence_parallel:
             assert not ctx.async_grad_allreduce
             dim_size = list(input.size())
             sub_grad_input = torch.empty(dim_size, dtype=input.dtype,
@@ -303,7 +302,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
         grad_weight = grad_output.t().matmul(total_input)
         grad_bias = grad_output.sum(dim=0) if use_bias else None

-        if ctx.model_parallel_memory_opt:
+        if ctx.sequence_parallel:
             handle.wait()
             return sub_grad_input, grad_weight, grad_bias, None, None, None
@@ -390,34 +389,28 @@ class ColumnParallelLinear(torch.nn.Module):
         self.async_tensor_model_parallel_allreduce = (
             args.async_tensor_model_parallel_allreduce and
             world_size > 1)
-        self.model_parallel_memory_opt = (
-            args.model_parallel_memory_opt and
+        self.sequence_parallel = (
+            args.sequence_parallel and
             world_size > 1)
         assert not self.async_tensor_model_parallel_allreduce or \
-            not self.model_parallel_memory_opt
+            not self.sequence_parallel
         self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

-        global _TOTAL_INPUT
-        if _TOTAL_INPUT is None:
-            _TOTAL_INPUT = torch.empty((args.seq_length, args.micro_batch_size, args.hidden_size), dtype=torch.bfloat16,
-                                       device=torch.cuda.current_device(),
-                                       requires_grad=False)
-
     def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None

         if self.async_tensor_model_parallel_allreduce or \
-                self.model_parallel_memory_opt:
+                self.sequence_parallel:
             input_parallel = input_
         else:
             input_parallel = copy_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
         output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
             input_parallel, self.weight, bias, self.gradient_accumulation_fusion,
-            self.async_tensor_model_parallel_allreduce, self.model_parallel_memory_opt)
+            self.async_tensor_model_parallel_allreduce, self.sequence_parallel)
         if self.gather_output:
             # All-gather across the partitions.
-            assert not self.model_parallel_memory_opt
+            assert not self.sequence_parallel
             output = gather_from_tensor_model_parallel_region(output_parallel)
         else:
             output = output_parallel
@@ -498,14 +491,14 @@ class RowParallelLinear(torch.nn.Module):
             self.bias = Parameter(torch.empty(
                 self.output_size, device=torch.cuda.current_device(),
                 dtype=args.params_dtype))
-            setattr(self.bias, 'sequence_parallel', args.model_parallel_memory_opt)
+            setattr(self.bias, 'sequence_parallel', args.sequence_parallel)

             # Always initialize bias to zero.
             with torch.no_grad():
                 self.bias.zero_()
         else:
             self.register_parameter('bias', None)

-        self.model_parallel_memory_opt = args.model_parallel_memory_opt
+        self.sequence_parallel = args.sequence_parallel
         self.gradient_accumulation_fusion = args.gradient_accumulation_fusion
@@ -515,14 +508,14 @@ class RowParallelLinear(torch.nn.Module):
         if self.input_is_parallel:
             input_parallel = input_
         else:
-            assert not self.model_parallel_memory_opt
+            assert not self.sequence_parallel
             input_parallel = scatter_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
         output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
             input_parallel, self.weight, None,
             self.gradient_accumulation_fusion, None, None)

         # All-reduce across all the partitions.
-        if self.model_parallel_memory_opt:
+        if self.sequence_parallel:
             output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
         else:
             output_ = reduce_from_tensor_model_parallel_region(output_parallel)
......
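
Reviewer note (illustration only): the collective pattern behind the sequence-parallel `ColumnParallelLinear` is an all-gather of the sequence-sharded input before the GEMM and a reduce-scatter of the input gradient afterwards. The sketch simulates all tensor-parallel ranks in one process, with `torch.cat` and `chunk`+`sum` standing in for `_all_gather_base` and the reduce-scatter collective.

```python
# All ranks simulated locally; collectives replaced by plain tensor ops.
import torch

tp, s, b, h, out = 2, 8, 2, 16, 24
full_weight = torch.randn(out, h)
weight_shards = torch.chunk(full_weight, tp, dim=0)             # column-parallel split
input_shards = [torch.randn(s // tp, b, h) for _ in range(tp)]  # sequence split

# Forward, as seen by each rank: all-gather the sequence-sharded input,
# then GEMM against the local weight shard.
total_input = torch.cat(input_shards, dim=0)                    # [s, b, h]
outputs = [total_input.matmul(w.t()) for w in weight_shards]    # [s, b, out/tp] each

# Backward, per rank: partial input gradient through the local weight shard ...
grad_outputs = [torch.ones_like(o) for o in outputs]
partial_grads = [g.matmul(w) for g, w in zip(grad_outputs, weight_shards)]

# ... then a reduce-scatter sums the partials and keeps only this rank's
# sequence slice (rank 0 shown).
grad_total_input = torch.stack(partial_grads).sum(dim=0)        # reduce
grad_input_rank0 = torch.chunk(grad_total_input, tp, dim=0)[0]  # scatter
assert grad_input_rank0.shape == input_shards[0].shape
```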
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import torch
 from apex.optimizers import FusedAdam as Adam
 from apex.optimizers import FusedSGD as SGD
@@ -90,6 +91,18 @@ def get_megatron_optimizer(model,
                          weight_decay=args.weight_decay,
                          betas=(args.adam_beta1, args.adam_beta2),
                          eps=args.adam_eps)
+
+        # preallocating state tensors to avoid fragmentation
+        for param_group in optimizer.param_groups:
+            for i, param in enumerate(param_group['params']):
+                if param.requires_grad:
+                    state = optimizer.state[param]
+                    if len(state) == 0:
+                        # Exponential moving average of gradient values
+                        state['exp_avg'] = torch.zeros_like(param.data, dtype=torch.float)
+                        # Exponential moving average of squared gradient values
+                        state['exp_avg_sq'] = torch.zeros_like(param.data, dtype=torch.float)
+
     elif args.optimizer == 'sgd':
         optimizer = SGD(param_groups,
                         lr=args.lr,
......
@@ -264,14 +264,6 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
                     if param in self.optimizer.state:
                         self.optimizer.state[main_param] \
                             = self.optimizer.state.pop(param)

-                    #state = self.optimizer.state[main_param]
-                    #if len(state) == 0:
-                    #    # Exponential moving average of gradient values
-                    #    state['exp_avg'] = torch.zeros_like(main_param.data)
-                    #    # Exponential moving average of squared gradient values
-                    #    state['exp_avg_sq'] = torch.zeros_like(main_param.data)
-
                 # fp32 params.
                 elif param.type() == 'torch.cuda.FloatTensor':
                     fp32_params_this_group.append(param)
@@ -289,10 +281,6 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
                 fp32_from_float16_params_this_group)
             self.fp32_from_fp32_groups.append(fp32_params_this_group)

-        # Leverage state_dict() and load_state_dict() to
-        # recast preexisting per-param state tensors
-        # self.optimizer.load_state_dict(self.optimizer.state_dict())
-
     def zero_grad(self, set_to_none=True):
         """We only need to zero the model related parameters, i.e.,
......
@@ -62,7 +62,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     override_scatter_gather_tensors_in_pipeline = False
     if args.scatter_gather_tensors_in_pipeline and \
-            not args.model_parallel_memory_opt:
+            not args.sequence_parallel:
         tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
         if tensor_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0:
             tensor_chunk_shape = tensor_chunk_shape // \
@@ -95,7 +95,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     # Split tensor into smaller chunks if using scatter-gather optimization.
     if not override_scatter_gather_tensors_in_pipeline and \
             args.scatter_gather_tensors_in_pipeline and \
-            not args.model_parallel_memory_opt:
+            not args.sequence_parallel:
         if tensor_send_next is not None:
             tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
@@ -141,7 +141,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     # If using scatter-gather optimization, gather smaller chunks.
     if not override_scatter_gather_tensors_in_pipeline and \
             args.scatter_gather_tensors_in_pipeline and \
-            not args.model_parallel_memory_opt:
+            not args.sequence_parallel:
         if recv_prev:
             tensor_recv_prev = mpu.gather_split_1d_tensor(
                 tensor_recv_prev).view(tensor_shape).requires_grad_()
......
@@ -279,7 +279,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
     pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()

     args = get_args()
-    if args.model_parallel_memory_opt:
+    if args.sequence_parallel:
         seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
     else:
         seq_length = args.seq_length
@@ -519,13 +519,13 @@ def get_tensor_shapes(rank, model_type):
     args = get_args()
     tensor_shapes = []

-    if args.model_parallel_memory_opt:
+    if args.sequence_parallel:
         seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
     else:
         seq_length = args.seq_length

     if model_type == ModelType.encoder_and_decoder:
-        if args.model_parallel_memory_opt:
+        if args.sequence_parallel:
             decoder_seq_length = args.decoder_seq_length // mpu.get_tensor_model_parallel_world_size()
         else:
             decoder_seq_length = args.decoder_seq_length
......
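
Reviewer note: quick arithmetic behind the sequence-parallel branch of `get_tensor_shapes`; the tensors exchanged between pipeline stages carry only `seq_length // tp` rows. All numbers below are illustrative.

```python
# Illustrative pipeline-communication shapes with and without sequence parallelism.
seq_length = 2048
decoder_seq_length = 512
micro_batch_size = 4
hidden_size = 4096
tp = 8   # tensor-model-parallel world size

def pipeline_tensor_shape(seq_len, sequence_parallel):
    if sequence_parallel:
        seq_len = seq_len // tp
    return (seq_len, micro_batch_size, hidden_size)

assert pipeline_tensor_shape(seq_length, sequence_parallel=False) == (2048, 4, 4096)
assert pipeline_tensor_shape(seq_length, sequence_parallel=True) == (256, 4, 4096)
assert pipeline_tensor_shape(decoder_seq_length, sequence_parallel=True) == (64, 4, 4096)
```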
@@ -422,7 +422,7 @@ def train_step(forward_step_func, data_iterator,
     # All-reduce layernorm parameters across model parallel nodes
     # when sequence parallelism is used
     if mpu.get_tensor_model_parallel_world_size() > 1 and \
-            args.model_parallel_memory_opt:
+            args.sequence_parallel:
         grads = []
         for model_module in model:
             unwrapped_model = unwrap_model(
@@ -432,8 +432,6 @@ def train_step(forward_step_func, data_iterator,
                     grad = param.main_grad if args.DDP_impl == 'local' else param.grad
                     grads.append(grad.data)
         coalesced = _flatten_dense_tensors(grads)
-        #TODO VIJAY
-        #coalesced /= mpu.get_tensor_model_parallel_world_size()
         torch.distributed.all_reduce(
             coalesced, group=mpu.get_tensor_model_parallel_group())
         for buf, synced in zip(grads, _unflatten_dense_tensors(
......
@@ -59,7 +59,7 @@ def get_batch(data_iterator):
     # Unpack.
     tokens_ = data_b['text'].long()
-    labels = tokens_[:, 1:].transpose(0,1).contiguous()
+    labels = tokens_[:, 1:].contiguous()
     tokens = tokens_[:, :-1].contiguous()

     # Get the masks and position ids.
......
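
Reviewer note (illustration only): `get_batch` now leaves both `tokens` and `labels` in `[b, s]`; the `[b, s] -> [s, b]` transpose happens later, in `post_language_model_processing`. A tiny shape check:

```python
# Toy view of get_batch: next-token labels are the input shifted by one column.
import torch

b, s = 2, 6
tokens_ = torch.arange(b * (s + 1)).reshape(b, s + 1)   # stand-in for data_b['text']
labels = tokens_[:, 1:].contiguous()    # [b, s]
tokens = tokens_[:, :-1].contiguous()   # [b, s]
assert torch.equal(tokens[:, 1:], labels[:, :-1])
```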