Commit cfd2e216 authored by Vijay Korthikanti

address review comments

parent 13b3dca6
......@@ -103,12 +103,14 @@ def parse_args(extra_args_provider=None, defaults={},
assert args.model_parallel_size is None, '--model-parallel-size is no ' \
'longer valid, use --tensor-model-parallel-size instead'
del args.model_parallel_size
if args.checkpoint_activations:
args.activations_checkpoint_method = 'uniform'
args.checkpoint_granularity = 'full'
args.checkpoint_method = 'uniform'
if args.rank == 0:
print('--checkpoint-activations is no longer valid, '
'use --activation-checkpoint-method instead. '
'Defaulting to activation-checkpoint-method=uniform.')
'use --checkpoint-granularity and --checkpoint-method instead. '
'Defaulting to checkpoint-granularity=full and checkpoint-method=uniform.')
del args.checkpoint_activations
# Set input defaults.
......@@ -283,18 +285,26 @@ def parse_args(extra_args_provider=None, defaults={},
assert args.tensor_model_parallel_size > 1, 'can distribute ' \
'checkpointed activations only across tensor model ' \
'parallel groups'
assert args.activations_checkpoint_method is not None, \
assert args.checkpoint_granularity == 'full', \
'distributed checkpointed activations are only '\
'applicable to full checkpoint granularity'
assert args.checkpoint_method is not None, \
'for distributed checkpoint activations to work you '\
'need to use a activation-checkpoint method '
'need to use a checkpoint method '
assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 10, \
'distributed checkpoint activations are supported for pytorch ' \
'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)
# model parallel memory optmization
if args.model_parallel_memory_opt:
assert not args.async_tensor_model_parallel_allreduce
if args.checkpoint_granularity == 'selective':
assert args.checkpoint_method is None, \
'checkpoint method is not yet supported for ' \
'selective checkpointing granularity'
# disable async_tensor_model_parallel_allreduce when
# sequence parallelism is enabled
if args.sequence_parallel:
args.async_tensor_model_parallel_allreduce = False
_print_args(args)
return args
......@@ -476,30 +486,38 @@ def _add_training_args(parser):
' (1024 - 16) / 8 = 126 intervals will increase '
'the batch size linearly to 1024. In each interval '
'we will use approximately 300000 / 126 = 2380 samples.')
group.add_argument('--checkpoint-activations', action='store_true',
help='Checkpoint activation to allow for training '
'with larger models, sequences, and batch sizes.')
group.add_argument('--checkpoint-attention', action='store_true',
help='Checkpoint activation to allow for training '
'with larger models, sequences, and batch sizes.')
group.add_argument('--checkpoint-granularity', type=str, default=None,
choices=['full', 'selective'],
help='Checkpoint activations to allow for training '
'with larger models, sequences, and batch sizes. '
'It is supported at two granularities: 1) full: '
'the whole transformer layer is checkpointed and recomputed, '
'2) selective: only the core attention part of the transformer '
'layer is checkpointed and recomputed.')
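A hedged sketch of the two granularities, using `torch.utils.checkpoint` as a stand-in for Megatron's `mpu.checkpoint`; `ToyLayer` and its submodules are hypothetical placeholders, and a recent PyTorch with `use_reentrant` support is assumed.

```python
import torch
from torch.utils.checkpoint import checkpoint

class ToyLayer(torch.nn.Module):
    """Hypothetical stand-in for a transformer layer: core attention then MLP."""
    def __init__(self, hidden):
        super().__init__()
        self.attn = torch.nn.Linear(hidden, hidden)   # stands in for core attention
        self.mlp = torch.nn.Linear(hidden, hidden)

    def forward(self, x, granularity='full'):
        if granularity == 'selective':
            # selective: only the core-attention part is checkpointed and recomputed
            x = checkpoint(self.attn, x, use_reentrant=False)
            return self.mlp(x)
        # full: the whole layer is checkpointed and recomputed during backward
        return checkpoint(lambda y: self.mlp(self.attn(y)), x, use_reentrant=False)

x = torch.randn(4, 16, requires_grad=True)
ToyLayer(16)(x, granularity='selective').sum().backward()
```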
group.add_argument('--distribute-checkpointed-activations',
action='store_true',
help='If set, distribute checkpointed activations '
'across model parallel group.')
group.add_argument('--activations-checkpoint-method', type=str, default=None,
group.add_argument('--checkpoint-method', type=str, default=None,
choices=['uniform', 'block'],
help='1) uniform: uniformly divide the total number of '
'Transformer layers and checkpoint the input activation of '
'each divided chunk, '
'each divided chunk at the specified granularity, '
'2) checkpoint the input activations of only a set number of '
'individual Transformer layers per pipeline stage and do the '
'rest without any checkpointing'
'rest without any checkpointing at the specified granularity, '
'default) do not apply activation checkpointing to any layers')
group.add_argument('--activations-checkpoint-num-layers', type=int, default=1,
group.add_argument('--checkpoint-num-layers', type=int, default=1,
help='1) uniform: the number of Transformer layers in each '
'uniformly divided checkpoint unit, '
'2) block: the number of individual Transformer layers '
'to checkpoint within each pipeline stage.')
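To make the two methods concrete, a plain-Python sketch (hypothetical helper name and layer counts) of which layer ranges end up checkpointed for a given `--checkpoint-method` and `--checkpoint-num-layers`:

```python
def checkpointed_chunks(num_layers, method, num_layers_per_chunk):
    """Return the layer ranges whose inputs get checkpointed."""
    if method == 'uniform':
        # every chunk of `num_layers_per_chunk` consecutive layers is checkpointed
        return [(l, min(l + num_layers_per_chunk, num_layers))
                for l in range(0, num_layers, num_layers_per_chunk)]
    if method == 'block':
        # only the first `num_layers_per_chunk` layers of the stage are checkpointed
        return [(l, l + 1) for l in range(min(num_layers_per_chunk, num_layers))]
    return []  # default: no activation checkpointing

print(checkpointed_chunks(8, 'uniform', 2))  # [(0, 2), (2, 4), (4, 6), (6, 8)]
print(checkpointed_chunks(8, 'block', 3))    # [(0, 1), (1, 2), (2, 3)]
```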
# deprecated
group.add_argument('--checkpoint-activations', action='store_true',
help='Checkpoint activation to allow for training '
'with larger models, sequences, and batch sizes.')
group.add_argument('--train-iters', type=int, default=None,
help='Total number of iterations to train over all '
'training runs. Note that either train-iters or '
......@@ -548,8 +566,8 @@ def _add_training_args(parser):
'This kernel supports only a set of hidden sizes. Please '
'check persist_ln_hidden_sizes if your hidden '
'size is supported.')
group.add_argument('--model-parallel-memory-opt', action='store_true',
help='Enable model parallel memory optmization.')
group.add_argument('--sequence-parallel', action='store_true',
help='Enable sequence parallel optimization.')
group.add_argument('--no-gradient-accumulation-fusion',
action='store_false',
help='Disable fusing gradient accumulation to weight '
......
......@@ -110,8 +110,11 @@ def post_language_model_processing(lm_output, pooled_output,
binary_logits = binary_head(pooled_output)
if lm_labels is None:
return lm_logits, binary_logits
# [s b h] => [b s h]
return lm_logits.transpose(0,1).contiguous(), binary_logits
else:
# [b s] => [s b]
lm_labels = lm_labels.transpose(0,1).contiguous()
if fp16_lm_cross_entropy:
assert lm_logits.dtype == torch.half
lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
......
......@@ -291,7 +291,7 @@ class PretrainedBertModel(MegatronModule):
pool_mask = (input_ids == self.pad_id).unsqueeze(2)
# Taking the representation of the [CLS] token of BERT
pooled_output = lm_output[:, 0, :]
pooled_output = lm_output[0, :, :]
# Converting to float16 dtype
pooled_output = pooled_output.to(lm_output.dtype)
......
......@@ -32,15 +32,18 @@ def post_language_model_processing(lm_output, labels, logit_weights,
parallel_output,
fp16_lm_cross_entropy):
# Output.
# Output. Format [s b h]
output = parallel_lm_logits(
lm_output,
logit_weights,
parallel_output)
if labels is None:
return output
# [s b h] => [b s h]
return output.transpose(0,1).contiguous()
else:
# [b s] => [s b]
labels = labels.transpose(0,1).contiguous()
if fp16_lm_cross_entropy:
assert output.dtype == torch.half
loss = mpu.vocab_parallel_cross_entropy(output, labels)
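A small shape check of the layout convention above, with plain `F.cross_entropy` standing in for `mpu.vocab_parallel_cross_entropy` (which expects logits in `[s, b, v]` and labels in `[s, b]`); all sizes are hypothetical.

```python
import torch
import torch.nn.functional as F

s, b, v = 5, 2, 11                      # sequence length, micro batch, vocab (hypothetical)
output = torch.randn(s, b, v)           # parallel_lm_logits keeps the [s, b, h]-style layout
labels = torch.randint(0, v, (b, s))    # the data pipeline provides labels as [b, s]

labels_sb = labels.transpose(0, 1).contiguous()               # [b, s] => [s, b]
loss = F.cross_entropy(output.reshape(-1, v), labels_sb.reshape(-1))

# when no labels are given, logits are handed back to the caller in [b, s, v]
logits_bs = output.transpose(0, 1).contiguous()
assert logits_bs.shape == (b, s, v)
```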
......
......@@ -33,11 +33,11 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
args = get_args()
# Parallel logits.
if args.async_tensor_model_parallel_allreduce or\
args.model_parallel_memory_opt:
args.sequence_parallel:
input_parallel = input_
model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \
model_parallel and not args.model_parallel_memory_opt
model_parallel and not args.sequence_parallel
else:
input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
async_grad_allreduce = False
......@@ -46,7 +46,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
logits_parallel = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply(
input_parallel, word_embeddings_weight, bias,
args.gradient_accumulation_fusion,
async_grad_allreduce, args.model_parallel_memory_opt)
async_grad_allreduce, args.sequence_parallel)
# Gather if needed.
if parallel_output:
......@@ -107,9 +107,9 @@ class Pooler(MegatronModule):
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
def forward(self, hidden_states, sequence_index=0):
# hidden_states: [b, s, h]
# hidden_states: [s, b, h]
# sequence_index: index of the token to pool.
pooled = hidden_states[:, sequence_index, :]
pooled = hidden_states[sequence_index, :, :]
pooled = self.dense(pooled)
pooled = torch.tanh(pooled)
return pooled
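A quick illustration (hypothetical sizes) of why the pooling index moves to the leading dimension once hidden states are kept in `[s, b, h]`:

```python
import torch

s, b, h = 7, 3, 16
hidden_states = torch.randn(s, b, h)     # [s, b, h] layout used throughout the model

# in the old [b, s, h] layout the token index sat in dim 1 (hidden_states[:, 0, :]);
# with [s, b, h] the token index is the leading dimension instead
pooled = hidden_states[0, :, :]          # representation of token 0 for every sample
assert pooled.shape == (b, h)
```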
......@@ -171,7 +171,7 @@ class Embedding(MegatronModule):
self.tokentype_embeddings = None
self.fp32_residual_connection = args.fp32_residual_connection
self.model_parallel_memory_opt = args.model_parallel_memory_opt
self.sequence_parallel = args.sequence_parallel
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
......@@ -214,18 +214,17 @@ class Embedding(MegatronModule):
assert self.tokentype_embeddings is None
# Data format change to avoid explicit transposes : [b s h] --> [s b h].
embeddings = embeddings.transpose(0, 1).contiguous()
# If the input flag for fp32 residual connection is set, convert to float.
if self.fp32_residual_connection:
embeddings = embeddings.transpose(0, 1).contiguous().float()
# Otherwise, leave it as is.
else:
embeddings = embeddings.transpose(0, 1).contiguous()
embeddings = embeddings.float()
if self.model_parallel_memory_opt:
if self.sequence_parallel:
embeddings = mpu.scatter_to_sequence_parallel_region(embeddings)
# Dropout.
if self.model_parallel_memory_opt:
if self.sequence_parallel:
with mpu.get_cuda_rng_tracker().fork():
embeddings = self.embedding_dropout(embeddings)
else:
......
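A single-process sketch of what `scatter_to_sequence_parallel_region` does to the embeddings, with `torch.chunk` standing in for the distributed scatter; the tensor-parallel sizes are hypothetical, and the forked dropout RNG is only mimicked with a plain `torch.nn.Dropout`.

```python
import torch

s, b, h = 8, 2, 4
tp_world_size, tp_rank = 2, 0               # hypothetical tensor-parallel group

embeddings = torch.randn(s, b, h)           # [s, b, h] after the transpose above

# scatter_to_sequence_parallel_region splits along the sequence dimension and
# keeps only this rank's shard; torch.chunk stands in for the collective here
local_shard = torch.chunk(embeddings, tp_world_size, dim=0)[tp_rank]
assert local_shard.shape == (s // tp_world_size, b, h)

# dropout then runs under a forked RNG (mpu.get_cuda_rng_tracker().fork() in
# Megatron) so each rank draws a different mask for its sequence shard
local_shard = torch.nn.Dropout(0.1)(local_shard)
```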
......@@ -157,8 +157,11 @@ class T5Model(MegatronModule):
self.word_embeddings_weight())
if lm_labels is None:
return lm_logits
# [s b h] => [b s h]
return lm_logits.transpose(0,1).contiguous()
else:
# [b s] => [s b]
lm_labels = lm_labels.transpose(0,1).contiguous()
if self.fp16_lm_cross_entropy:
assert lm_logits.dtype == torch.half
lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
......
......@@ -15,6 +15,7 @@
"""Transformer."""
import math
import contextlib
import torch
import torch.nn.functional as F
......@@ -27,7 +28,6 @@ from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
_MATMUL_INPUT = None
""" We use the following notation throughout this file:
h: hidden size
......@@ -167,6 +167,8 @@ class SwitchMLP(MegatronModule):
class CoreAttention(MegatronModule):
matmul_input = None
def __init__(self, layer_number,
attn_mask_type=AttnMaskType.padding):
super(CoreAttention, self).__init__()
......@@ -180,7 +182,7 @@ class CoreAttention(MegatronModule):
self.attention_softmax_in_fp32 = True
self.layer_number = max(1, layer_number)
self.attn_mask_type = attn_mask_type
self.model_parallel_memory_opt = args.model_parallel_memory_opt
self.sequence_parallel = args.sequence_parallel
projection_size = args.kv_channels * args.num_attention_heads
......@@ -193,15 +195,6 @@ class CoreAttention(MegatronModule):
self.num_attention_heads_per_partition = mpu.divide(
args.num_attention_heads, world_size)
global _MATMUL_INPUT
if _MATMUL_INPUT is None:
_MATMUL_INPUT = torch.empty(
args.micro_batch_size * self.num_attention_heads_per_partition,
args.seq_length,
args.seq_length,
dtype=torch.bfloat16,
device=torch.cuda.current_device())
coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
if self.apply_query_key_layer_scaling:
......@@ -220,7 +213,7 @@ class CoreAttention(MegatronModule):
# different outputs on different number of parallel partitions but
# on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(args.attention_dropout)
def forward(self, query_layer, key_layer,
value_layer, attention_mask):
......@@ -241,20 +234,18 @@ class CoreAttention(MegatronModule):
key_layer = key_layer.view(output_size[3],
output_size[0] * output_size[1], -1)
# preallocting result tensor: [b * np, sq, sk]
#matmul_result = torch.empty(
# output_size[0]*output_size[1],
# output_size[2],
# output_size[3],
# dtype=query_layer.dtype,
# device=torch.cuda.current_device())
global _MATMUL_INPUT
matmul_input = _MATMUL_INPUT
# preallocating input tensor: [b * np, sq, sk]
if CoreAttention.matmul_input is None:
CoreAttention.matmul_input = torch.empty(
output_size[0]*output_size[1],
output_size[2],
output_size[3],
dtype=query_layer.dtype,
device=torch.cuda.current_device())
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_input,
CoreAttention.matmul_input,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0, alpha=(1.0/self.norm_factor))
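A hedged sketch of the `baddbmm` call with a preallocated buffer (hypothetical sizes): the buffer is passed as the `input` argument, and with `beta=0.0` its contents never influence the result, so only `alpha * (q @ k^T)` is computed.

```python
import torch

b_np, sq, sk, hn = 4, 16, 16, 8             # (batch * heads), query len, key len, head dim
norm_factor = hn ** 0.5

# allocated once and reused across forward passes; beta=0.0 ignores its contents
matmul_input = torch.empty(b_np, sq, sk)

q = torch.randn(b_np, sq, hn)               # [b * np, sq, hn]
k = torch.randn(b_np, sk, hn)               # [b * np, sk, hn]

scores = torch.baddbmm(matmul_input, q, k.transpose(1, 2),
                       beta=0.0, alpha=1.0 / norm_factor)   # [b * np, sq, sk]
assert scores.shape == (b_np, sq, sk)
```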
......@@ -273,7 +264,7 @@ class CoreAttention(MegatronModule):
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
if not self.model_parallel_memory_opt:
if not self.sequence_parallel:
with mpu.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
else:
......@@ -334,8 +325,6 @@ class ParallelAttention(MegatronModule):
self.attention_type = attention_type
self.attn_mask_type = attn_mask_type
self.params_dtype = args.params_dtype
self.checkpoint_attention = args.checkpoint_attention
#assert args.activations_checkpoint_method is None
projection_size = args.kv_channels * args.num_attention_heads
......@@ -369,6 +358,7 @@ class ParallelAttention(MegatronModule):
self.core_attention = CoreAttention(self.layer_number,
self.attn_mask_type)
self.checkpoint_core_attention = args.checkpoint_granularity == 'selective'
# Output.
self.dense = mpu.RowParallelLinear(
......@@ -491,7 +481,7 @@ class ParallelAttention(MegatronModule):
# core attention computation
# ==================================
if self.checkpoint_attention:
if self.checkpoint_core_attention:
context_layer = self._checkpointed_attention_forward(
query_layer, key_layer, value_layer, attention_mask)
else:
......@@ -564,7 +554,7 @@ class ParallelTransformerLayer(MegatronModule):
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm,
sequence_parallel=args.model_parallel_memory_opt)
sequence_parallel=args.sequence_parallel)
# Self attention.
self.self_attention = ParallelAttention(
......@@ -582,7 +572,7 @@ class ParallelTransformerLayer(MegatronModule):
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm,
sequence_parallel=args.model_parallel_memory_opt)
sequence_parallel=args.sequence_parallel)
if self.layer_type == LayerType.decoder:
self.inter_attention = ParallelAttention(
......@@ -595,7 +585,7 @@ class ParallelTransformerLayer(MegatronModule):
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm,
sequence_parallel=args.model_parallel_memory_opt)
sequence_parallel=args.sequence_parallel)
# MLP
if args.num_experts is not None:
......@@ -747,12 +737,13 @@ class ParallelTransformer(MegatronModule):
self.drop_path_rate = drop_path_rate
# Store activation checkpointing flag.
self.activations_checkpoint_method = args.activations_checkpoint_method
self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
self.checkpoint_granularity = args.checkpoint_granularity
self.checkpoint_method = args.checkpoint_method
self.checkpoint_num_layers = args.checkpoint_num_layers
self.distribute_checkpointed_activations = \
args.distribute_checkpointed_activations and not args.model_parallel_memory_opt
args.distribute_checkpointed_activations and not args.sequence_parallel
self.model_parallel_memory_opt = args.model_parallel_memory_opt
self.sequence_parallel = args.sequence_parallel
# Number of layers.
self.num_layers = mpu.get_num_layers(
......@@ -822,7 +813,7 @@ class ParallelTransformer(MegatronModule):
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm,
sequence_parallel=args.model_parallel_memory_opt)
sequence_parallel=args.sequence_parallel)
def _get_layer(self, layer_number):
return self.layers[layer_number]
......@@ -842,24 +833,24 @@ class ParallelTransformer(MegatronModule):
return x_
return custom_forward
if self.activations_checkpoint_method == 'uniform':
if self.checkpoint_method == 'uniform':
# Uniformly divide the total number of Transformer layers and checkpoint
# the input activation of each divided chunk.
# A method to further reduce memory usage by reducing the number of checkpoints.
l = 0
while l < self.num_layers:
hidden_states = mpu.checkpoint(
custom(l, l + self.activations_checkpoint_num_layers),
custom(l, l + self.checkpoint_num_layers),
self.distribute_checkpointed_activations,
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
l += self.activations_checkpoint_num_layers
l += self.checkpoint_num_layers
elif self.activations_checkpoint_method == 'block':
elif self.checkpoint_method == 'block':
# Checkpoint the input activation of only a set number of individual
# Transformer layers and skip the rest.
# A method to fully use the device memory by removing redundant re-computation.
for l in range(self.num_layers):
if l < self.activations_checkpoint_num_layers:
if l < self.checkpoint_num_layers:
hidden_states = mpu.checkpoint(
custom(l, l + 1),
self.distribute_checkpointed_activations,
......@@ -887,7 +878,7 @@ class ParallelTransformer(MegatronModule):
inference_params=None):
# Checks.
if inference_params:
assert self.activations_checkpoint_method is None, \
assert self.checkpoint_granularity is None, \
'inference does not work with activation checkpointing'
if not self.pre_process:
......@@ -915,28 +906,14 @@ class ParallelTransformer(MegatronModule):
keep_graph=True,
)
if self.model_parallel_memory_opt:
with mpu.get_cuda_rng_tracker().fork():
# Forward pass.
if self.activations_checkpoint_method is not None:
hidden_states = self._checkpointed_forward(hidden_states,
attention_mask,
encoder_output,
enc_dec_attn_mask)
else:
total = 0
for index in range(self.num_layers):
layer = self._get_layer(index)
hidden_states = layer(
hidden_states,
attention_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
inference_params=inference_params)
if self.sequence_parallel:
rng_context = mpu.get_cuda_rng_tracker().fork()
else:
rng_context = contextlib.nullcontext()
with rng_context:
# Forward pass.
if self.activations_checkpoint_method is not None:
if self.checkpoint_granularity == 'full':
hidden_states = self._checkpointed_forward(hidden_states,
attention_mask,
encoder_output,
......
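The conditional RNG context used above, as a standalone sketch: `contextlib.nullcontext()` is instantiated so both branches yield a context-manager instance. `ToyRngTracker` is a hypothetical stand-in for `mpu.get_cuda_rng_tracker()`.

```python
import contextlib
import torch

class ToyRngTracker:
    """Hypothetical stand-in for mpu.get_cuda_rng_tracker()."""
    @contextlib.contextmanager
    def fork(self):
        state = torch.random.get_rng_state()
        try:
            yield
        finally:
            torch.random.set_rng_state(state)

def run(x, sequence_parallel):
    tracker = ToyRngTracker()
    rng_context = tracker.fork() if sequence_parallel else contextlib.nullcontext()
    with rng_context:
        return torch.nn.functional.dropout(x, p=0.1, training=True)

out = run(torch.randn(4, 4), sequence_parallel=True)
```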
......@@ -45,9 +45,6 @@ _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
'partition_dim': -1,
'partition_stride': 1}
_TOTAL_INPUT = None
_SUB_GRAD_INPUT = None
def param_is_not_tensor_parallel_duplicate(param):
return (hasattr(param, 'tensor_model_parallel') and
param.tensor_model_parallel) or (
......@@ -208,28 +205,32 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
Linear layer execution with asynchronous communication and gradient accumulation
fusion in backprop.
"""
all_gather_buffer = None
@staticmethod
def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
async_grad_allreduce, model_parallel_memory_opt):
async_grad_allreduce, sequence_parallel):
ctx.save_for_backward(input, weight)
ctx.use_bias = bias is not None
ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
ctx.async_grad_allreduce = async_grad_allreduce
ctx.model_parallel_memory_opt = model_parallel_memory_opt
ctx.sequence_parallel = sequence_parallel
if model_parallel_memory_opt:
if sequence_parallel:
world_size = get_tensor_model_parallel_world_size()
dim_size = list(input.size())
dim_size[0] = dim_size[0] * world_size
#total_input = torch.empty(dim_size, dtype=input.dtype,
# device=torch.cuda.current_device(),
# requires_grad=False)
global _TOTAL_INPUT
total_input = _TOTAL_INPUT
torch.distributed._all_gather_base(total_input, input,
group=get_tensor_model_parallel_group())
if LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer is None:
LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer = \
torch.empty(dim_size, dtype=input.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
torch.distributed._all_gather_base(
LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer,
input,
group=get_tensor_model_parallel_group())
total_input = LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer
else:
total_input = input
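A single-process sketch of the shape bookkeeping in the sequence-parallel forward path: the lazily allocated buffer covers the full sequence (dim 0 grows by the tensor-parallel world size), and `torch.cat` stands in for `torch.distributed._all_gather_base`; sizes are hypothetical.

```python
import torch

world_size = 2
s_local, b, h = 4, 2, 8                       # per-rank sequence shard (hypothetical)
inputs = [torch.randn(s_local, b, h) for _ in range(world_size)]  # one shard per rank

# the all-gather buffer spans the full sequence: dim 0 times the world size
dim_size = list(inputs[0].size())
dim_size[0] *= world_size
all_gather_buffer = torch.empty(dim_size, dtype=inputs[0].dtype, requires_grad=False)

# _all_gather_base fills the buffer with every rank's shard; torch.cat stands
# in for the collective in this single-process sketch
all_gather_buffer.copy_(torch.cat(inputs, dim=0))
total_input = all_gather_buffer               # [s_local * world_size, b, h]
assert total_input.shape == (s_local * world_size, b, h)
```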
......@@ -244,27 +245,25 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias
if ctx.model_parallel_memory_opt:
if ctx.sequence_parallel:
world_size = get_tensor_model_parallel_world_size()
dim_size = list(input.size())
dim_size[0] = dim_size[0] * world_size
#total_input = torch.empty(dim_size, dtype=input.dtype,
# device=torch.cuda.current_device(),
# requires_grad=False)
global _TOTAL_INPUT
total_input = _TOTAL_INPUT
handle = torch.distributed._all_gather_base(
LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer,
input,
group=get_tensor_model_parallel_group(), async_op=True)
handle = torch.distributed._all_gather_base(total_input, input,
group=get_tensor_model_parallel_group(), async_op=True)
# Delay the start of input gradient computation shortly (3us) to have
# gather scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
total_input = LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer
else:
total_input = input
grad_input = grad_output.matmul(weight)
if ctx.model_parallel_memory_opt:
if ctx.sequence_parallel:
handle.wait()
# Convert the tensor shapes to 2D for execution compatibility
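A hedged sketch of the overlap pattern in the backward pass: launch the gather asynchronously, run independent compute, then wait. It assumes `torch.distributed.init_process_group()` has been called and CUDA tensors; the helper name and shapes are hypothetical.

```python
import torch
import torch.distributed as dist

def overlapped_backward_piece(grad_output, weight, shard, group):
    """Hedged sketch; assumes an initialized tensor-parallel process group."""
    world_size = dist.get_world_size(group)
    gathered = torch.empty(shard.size(0) * world_size, *shard.size()[1:],
                           dtype=shard.dtype, device=shard.device)
    # launch the gather asynchronously so it runs concurrently with the matmul
    handle = dist._all_gather_base(gathered, shard, group=group, async_op=True)
    # tiny dummy op, mirroring the trick above that lets the collective get
    # scheduled before the large matmul claims the GPU
    _ = torch.empty(1, device=grad_output.device) + 1
    grad_input = grad_output.matmul(weight)   # independent of the gather
    handle.wait()                             # `gathered` is only read after this
    return grad_input, gathered
```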
......@@ -281,7 +280,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
if ctx.model_parallel_memory_opt:
if ctx.sequence_parallel:
assert not ctx.async_grad_allreduce
dim_size = list(input.size())
sub_grad_input = torch.empty(dim_size, dtype=input.dtype,
......@@ -303,7 +302,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.model_parallel_memory_opt:
if ctx.sequence_parallel:
handle.wait()
return sub_grad_input, grad_weight, grad_bias, None, None, None
......@@ -390,34 +389,28 @@ class ColumnParallelLinear(torch.nn.Module):
self.async_tensor_model_parallel_allreduce = (
args.async_tensor_model_parallel_allreduce and
world_size > 1)
self.model_parallel_memory_opt = (
args.model_parallel_memory_opt and
self.sequence_parallel = (
args.sequence_parallel and
world_size > 1)
assert not self.async_tensor_model_parallel_allreduce or \
not self.model_parallel_memory_opt
not self.sequence_parallel
self.gradient_accumulation_fusion = args.gradient_accumulation_fusion
global _TOTAL_INPUT
if _TOTAL_INPUT is None:
_TOTAL_INPUT = torch.empty((args.seq_length, args.micro_batch_size, args.hidden_size), dtype=torch.bfloat16,
device=torch.cuda.current_device(),
requires_grad=False)
def forward(self, input_):
bias = self.bias if not self.skip_bias_add else None
if self.async_tensor_model_parallel_allreduce or \
self.model_parallel_memory_opt:
self.sequence_parallel:
input_parallel = input_
else:
input_parallel = copy_to_tensor_model_parallel_region(input_)
# Matrix multiply.
output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
input_parallel, self.weight, bias, self.gradient_accumulation_fusion,
self.async_tensor_model_parallel_allreduce, self.model_parallel_memory_opt)
self.async_tensor_model_parallel_allreduce, self.sequence_parallel)
if self.gather_output:
# All-gather across the partitions.
assert not self.model_parallel_memory_opt
assert not self.sequence_parallel
output = gather_from_tensor_model_parallel_region(output_parallel)
else:
output = output_parallel
......@@ -498,14 +491,14 @@ class RowParallelLinear(torch.nn.Module):
self.bias = Parameter(torch.empty(
self.output_size, device=torch.cuda.current_device(),
dtype=args.params_dtype))
setattr(self.bias, 'sequence_parallel', args.model_parallel_memory_opt)
setattr(self.bias, 'sequence_parallel', args.sequence_parallel)
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter('bias', None)
self.model_parallel_memory_opt = args.model_parallel_memory_opt
self.sequence_parallel = args.sequence_parallel
self.gradient_accumulation_fusion = args.gradient_accumulation_fusion
......@@ -515,14 +508,14 @@ class RowParallelLinear(torch.nn.Module):
if self.input_is_parallel:
input_parallel = input_
else:
assert not self.model_parallel_memory_opt
assert not self.sequence_parallel
input_parallel = scatter_to_tensor_model_parallel_region(input_)
# Matrix multiply.
output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
input_parallel, self.weight, None,
self.gradient_accumulation_fusion, None, None)
# All-reduce across all the partitions.
if self.model_parallel_memory_opt:
if self.sequence_parallel:
output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
else:
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
......
......@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from apex.optimizers import FusedAdam as Adam
from apex.optimizers import FusedSGD as SGD
......@@ -90,6 +91,18 @@ def get_megatron_optimizer(model,
weight_decay=args.weight_decay,
betas=(args.adam_beta1, args.adam_beta2),
eps=args.adam_eps)
# preallocating state tensors to avoid fragmentation
for param_group in optimizer.param_groups:
for i, param in enumerate(param_group['params']):
if param.requires_grad:
state = optimizer.state[param]
if len(state) == 0:
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(param.data, dtype=torch.float)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(param.data, dtype=torch.float)
elif args.optimizer == 'sgd':
optimizer = SGD(param_groups,
lr=args.lr,
......
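The state preallocation above, condensed into a hedged sketch; it assumes apex and a CUDA device are available, and the toy model is hypothetical.

```python
import torch
from apex.optimizers import FusedAdam as Adam   # assumes apex is installed, as in the diff

model = torch.nn.Linear(16, 16).cuda()
optimizer = Adam(model.parameters(), lr=1e-3, weight_decay=0.01,
                 betas=(0.9, 0.999), eps=1e-8)

# claim the moving-average buffers up front instead of letting the first step()
# allocate them lazily, which can fragment a nearly full device memory pool
for param_group in optimizer.param_groups:
    for param in param_group['params']:
        if param.requires_grad and len(optimizer.state[param]) == 0:
            # exponential moving averages of the gradient and squared gradient
            optimizer.state[param]['exp_avg'] = torch.zeros_like(param.data, dtype=torch.float)
            optimizer.state[param]['exp_avg_sq'] = torch.zeros_like(param.data, dtype=torch.float)
```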
......@@ -264,14 +264,6 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
if param in self.optimizer.state:
self.optimizer.state[main_param] \
= self.optimizer.state.pop(param)
#state = self.optimizer.state[main_param]
#if len(state) == 0:
# # Exponential moving average of gradient values
# state['exp_avg'] = torch.zeros_like(main_param.data)
# # Exponential moving average of squared gradient values
# state['exp_avg_sq'] = torch.zeros_like(main_param.data)
# fp32 params.
elif param.type() == 'torch.cuda.FloatTensor':
fp32_params_this_group.append(param)
......@@ -289,10 +281,6 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
fp32_from_float16_params_this_group)
self.fp32_from_fp32_groups.append(fp32_params_this_group)
# Leverage state_dict() and load_state_dict() to
# recast preexisting per-param state tensors
# self.optimizer.load_state_dict(self.optimizer.state_dict())
def zero_grad(self, set_to_none=True):
"""We only need to zero the model related parameters, i.e.,
......
......@@ -62,7 +62,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
override_scatter_gather_tensors_in_pipeline = False
if args.scatter_gather_tensors_in_pipeline and \
not args.model_parallel_memory_opt:
not args.sequence_parallel:
tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
if tensor_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0:
tensor_chunk_shape = tensor_chunk_shape // \
......@@ -95,7 +95,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
# Split tensor into smaller chunks if using scatter-gather optimization.
if not override_scatter_gather_tensors_in_pipeline and \
args.scatter_gather_tensors_in_pipeline and \
not args.model_parallel_memory_opt:
not args.sequence_parallel:
if tensor_send_next is not None:
tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
......@@ -141,7 +141,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
# If using scatter-gather optimization, gather smaller chunks.
if not override_scatter_gather_tensors_in_pipeline and \
args.scatter_gather_tensors_in_pipeline and \
not args.model_parallel_memory_opt:
not args.sequence_parallel:
if recv_prev:
tensor_recv_prev = mpu.gather_split_1d_tensor(
tensor_recv_prev).view(tensor_shape).requires_grad_()
......
......@@ -279,7 +279,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()
args = get_args()
if args.model_parallel_memory_opt:
if args.sequence_parallel:
seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
else:
seq_length = args.seq_length
......@@ -519,13 +519,13 @@ def get_tensor_shapes(rank, model_type):
args = get_args()
tensor_shapes = []
if args.model_parallel_memory_opt:
if args.sequence_parallel:
seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
else:
seq_length = args.seq_length
if model_type == ModelType.encoder_and_decoder:
if args.model_parallel_memory_opt:
if args.sequence_parallel:
decoder_seq_length = args.decoder_seq_length // mpu.get_tensor_model_parallel_world_size()
else:
decoder_seq_length = args.decoder_seq_length
......
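A tiny sketch (hypothetical helper and numbers) of the shape bookkeeping above: with sequence parallelism, each pipeline stage exchanges only its sequence shard.

```python
def p2p_tensor_shape(seq_length, micro_batch_size, hidden_size,
                     sequence_parallel, tp_world_size):
    """Shape of activations exchanged between pipeline stages."""
    if sequence_parallel:
        # each tensor-parallel rank only holds its slice of the sequence
        seq_length //= tp_world_size
    return (seq_length, micro_batch_size, hidden_size)

print(p2p_tensor_shape(2048, 4, 1024, sequence_parallel=True, tp_world_size=8))
# (256, 4, 1024)
```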
......@@ -422,7 +422,7 @@ def train_step(forward_step_func, data_iterator,
# All-reduce layernorm parameters across model parallel nodes
# when sequence parallelism is used
if mpu.get_tensor_model_parallel_world_size() > 1 and \
args.model_parallel_memory_opt:
args.sequence_parallel:
grads = []
for model_module in model:
unwrapped_model = unwrap_model(
......@@ -432,8 +432,6 @@ def train_step(forward_step_func, data_iterator,
grad = param.main_grad if args.DDP_impl == 'local' else param.grad
grads.append(grad.data)
coalesced = _flatten_dense_tensors(grads)
#TODO VIJAY
#coalesced /= mpu.get_tensor_model_parallel_world_size()
torch.distributed.all_reduce(
coalesced, group=mpu.get_tensor_model_parallel_group())
for buf, synced in zip(grads, _unflatten_dense_tensors(
......
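The flatten / all-reduce / unflatten round trip above as a hedged sketch; the function name is hypothetical and the `all_reduce` call assumes an initialized process group.

```python
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def allreduce_layernorm_grads(grads, group=None):
    """Hedged sketch: sum the sequence-parallel layernorm gradients across the
    tensor-parallel group; assumes dist.init_process_group() has been called."""
    coalesced = _flatten_dense_tensors(grads)          # one flat buffer, one collective
    dist.all_reduce(coalesced, group=group)
    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
        buf.copy_(synced)                              # write the reduced values back
```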
......@@ -59,7 +59,7 @@ def get_batch(data_iterator):
# Unpack.
tokens_ = data_b['text'].long()
labels = tokens_[:, 1:].transpose(0,1).contiguous()
labels = tokens_[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()
# Get the masks and position ids.
......