Commit eec218d8 authored by Vijay Korthikanti

sequence parallelism for embedding dropout and last linear layer + memory optimizations

parent 53718d4c
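The core idea behind the commit: activations in [s, b, h] layout are scattered along the sequence dimension across the tensor-model-parallel ranks (guarded by the args.model_parallel_memory_opt flag), so each rank stores and processes only seq_length / tp_world_size of the sequence through the embedding dropout and the last linear layer. Below is a minimal single-process sketch of that layout change; the helper name, shapes, and world size are illustrative only, while the actual code relies on mpu.scatter_to_sequence_parallel_region and mpu.gather_from_sequence_parallel_region over the tensor-model-parallel group.

import torch

def scatter_along_sequence(x, tp_world_size, rank):
    # x: [s, b, h]; each rank keeps one contiguous slice of the sequence.
    assert x.size(0) % tp_world_size == 0
    return x.chunk(tp_world_size, dim=0)[rank].contiguous()

s, b, h, tp = 8, 2, 4, 2
full = torch.randn(s, b, h)
dropout = torch.nn.Dropout(p=0.1)
local = scatter_along_sequence(full, tp_world_size=tp, rank=0)  # [s/tp, b, h]
local = dropout(local)   # dropout touches only this rank's shard
print(local.shape)       # torch.Size([4, 2, 4])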
......@@ -37,7 +37,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
input_parallel = input_
model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \
model_parallel
model_parallel and not args.model_parallel_memory_opt
else:
input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
async_grad_allreduce = False
......@@ -46,7 +46,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
logits_parallel = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply(
input_parallel, word_embeddings_weight, bias,
args.gradient_accumulation_fusion,
async_grad_allreduce, None)
async_grad_allreduce, args.model_parallel_memory_opt)
# Gather if needed.
if parallel_output:
......@@ -170,6 +170,8 @@ class Embedding(MegatronModule):
else:
self.tokentype_embeddings = None
self.fp32_residual_connection = args.fp32_residual_connection
self.model_parallel_memory_opt = args.model_parallel_memory_opt
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
......@@ -211,8 +213,23 @@ class Embedding(MegatronModule):
else:
assert self.tokentype_embeddings is None
# Data format change to avoid explicit transposes: [b s h] --> [s b h].
# If the input flag for fp32 residual connection is set, convert for float.
if self.fp32_residual_connection:
embeddings = embeddings.transpose(0, 1).contiguous().float()
# Otherwise, leave it as is.
else:
embeddings = embeddings.transpose(0, 1).contiguous()
if self.model_parallel_memory_opt:
embeddings = mpu.scatter_to_sequence_parallel_region(embeddings)
# Dropout.
embeddings = self.embedding_dropout(embeddings)
if self.model_parallel_memory_opt:
with mpu.get_cuda_rng_tracker().fork():
embeddings = self.embedding_dropout(embeddings)
else:
embeddings = self.embedding_dropout(embeddings)
return embeddings
......
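Since each tensor-parallel rank now drops out a different slice of the sequence, the embedding dropout is run under mpu.get_cuda_rng_tracker().fork() so that ranks draw their masks from separate RNG streams. A CPU-only sketch of that effect, with torch.random.fork_rng standing in for the CUDA RNG tracker and the seeds chosen arbitrarily:

import torch
import torch.nn.functional as F

def dropout_on_shard(shard, p, rank_seed):
    with torch.random.fork_rng(devices=[]):   # CPU-only RNG fork for this sketch
        torch.manual_seed(rank_seed)          # per-rank stream, as the tracker would supply
        return F.dropout(shard, p=p, training=True)

shard = torch.ones(4, 2, 4)                   # one rank's [s/tp, b, h] slice
out0 = dropout_on_shard(shard, p=0.5, rank_seed=1)
out1 = dropout_on_shard(shard, p=0.5, rank_seed=2)
print(torch.equal(out0, out1))                # almost surely False: different masks per rank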
......@@ -18,7 +18,7 @@ import math
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron import get_timers, get_args, print_rank_last, print_rank_0
from megatron import mpu
from .module import MegatronModule
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
......@@ -27,6 +27,8 @@ from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
_MATMUL_INPUT = None
""" We use the following notation throughout this file:
h: hidden size
n: number of attention heads
......@@ -42,7 +44,6 @@ from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
hyperparameters: transformer hyperparameters
"""
class DropPath(MegatronModule):
"""Drop paths (Stochastic Depth) per sample
(when applied in main path of residual blocks).
......@@ -189,7 +190,18 @@ class CoreAttention(MegatronModule):
world_size)
self.hidden_size_per_attention_head = mpu.divide(
projection_size, args.num_attention_heads)
self.num_attention_heads_per_partition = mpu.divide(
args.num_attention_heads, world_size)
global _MATMUL_INPUT
if _MATMUL_INPUT is None:
_MATMUL_INPUT = torch.empty(
args.micro_batch_size * self.num_attention_heads_per_partition,
args.seq_length,
args.seq_length,
dtype=torch.bfloat16,
device=torch.cuda.current_device())
coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
if self.apply_query_key_layer_scaling:
......@@ -230,16 +242,19 @@ class CoreAttention(MegatronModule):
output_size[0] * output_size[1], -1)
# preallocating result tensor: [b * np, sq, sk]
matmul_result = torch.empty(
output_size[0]*output_size[1],
output_size[2],
output_size[3],
dtype=query_layer.dtype,
device=torch.cuda.current_device())
#matmul_result = torch.empty(
# output_size[0]*output_size[1],
# output_size[2],
# output_size[3],
# dtype=query_layer.dtype,
# device=torch.cuda.current_device())
global _MATMUL_INPUT
matmul_input = _MATMUL_INPUT
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_result,
matmul_input,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0, alpha=(1.0/self.norm_factor))
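The hunk above swaps the per-call torch.empty for a module-level _MATMUL_INPUT buffer (sized [micro_batch_size * np, s, s] in bfloat16) that is handed to torch.baddbmm as its input argument; with beta=0.0 the buffer's contents never affect the scores. A small sketch of that call pattern with made-up fp32 shapes:

import math
import torch

bnp, sq, sk, hn = 6, 16, 16, 8                 # [b * np, sq, sk] scores, head dim hn
scores_buffer = torch.empty(bnp, sq, sk)       # allocated once, passed in on every call

def raw_attention_scores(q, k, buf, norm_factor):
    # q: [b * np, sq, hn], k: [b * np, sk, hn]; with beta=0.0 the buffer's
    # existing contents never influence the result.
    return torch.baddbmm(buf, q, k.transpose(1, 2),
                         beta=0.0, alpha=1.0 / norm_factor)

q = torch.randn(bnp, sq, hn)
k = torch.randn(bnp, sk, hn)
scores = raw_attention_scores(q, k, scores_buffer, math.sqrt(hn))
print(scores.shape)                            # torch.Size([6, 16, 16])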
......@@ -838,6 +853,7 @@ class ParallelTransformer(MegatronModule):
self.distribute_checkpointed_activations,
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
l += self.activations_checkpoint_num_layers
elif self.activations_checkpoint_method == 'block':
# Checkpoint the input activation of only a set number of individual
# Transformer layers and skip the rest.
......@@ -869,25 +885,12 @@ class ParallelTransformer(MegatronModule):
def forward(self, hidden_states, attention_mask,
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None):
# Checks.
if inference_params:
assert self.activations_checkpoint_method is None, \
'inference does not work with activation checkpointing'
if self.pre_process:
# Data format change to avoid explicit transposes: [b s h] --> [s b h].
# If the input flag for fp32 residual connection is set, convert for float.
if self.fp32_residual_connection:
hidden_states = hidden_states.transpose(0, 1).contiguous().float()
# Otherwise, leave it as is.
else:
hidden_states = hidden_states.transpose(0, 1).contiguous()
if self.model_parallel_memory_opt:
hidden_states = mpu.scatter_to_sequence_parallel_region(hidden_states)
else:
if not self.pre_process:
# See set_input_tensor()
hidden_states = self.input_tensor
......@@ -908,17 +911,10 @@ class ParallelTransformer(MegatronModule):
# is called here to be future-proof and corner-case-proof.
hidden_states = mpu.make_viewless_tensor(
hidden_states,
requires_grad = True,
keep_graph = True,
requires_grad=True,
keep_graph=True,
)
# Transpose encoder output.
if encoder_output is not None and \
not self.model_parallel_memory_opt:
encoder_output = encoder_output.transpose(0, 1).contiguous()
if self.model_parallel_memory_opt:
encoder_output = mpu.scatter_to_sequence_parallel_region(encoder_output)
if self.model_parallel_memory_opt:
with mpu.get_cuda_rng_tracker().fork():
# Forward pass.
......@@ -928,6 +924,7 @@ class ParallelTransformer(MegatronModule):
encoder_output,
enc_dec_attn_mask)
else:
total = 0
for index in range(self.num_layers):
layer = self._get_layer(index)
hidden_states = layer(
......@@ -936,6 +933,7 @@ class ParallelTransformer(MegatronModule):
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
inference_params=inference_params)
else:
# Forward pass.
if self.activations_checkpoint_method is not None:
......@@ -955,20 +953,6 @@ class ParallelTransformer(MegatronModule):
# Final layer norm.
if self.post_process:
# Reverting data format change [s b h] --> [b s h].
hidden_states = self.final_layernorm(hidden_states)
if self.layer_type == LayerType.encoder and \
self.model_type == ModelType.encoder_and_decoder and \
self.model_parallel_memory_opt:
output = hidden_states
else:
if self.model_parallel_memory_opt:
hidden_states = mpu.gather_from_sequence_parallel_region(hidden_states)
output = hidden_states.transpose(0, 1).contiguous()
else:
output = hidden_states
return output
return hidden_states
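The branch removed here relied on mpu.gather_from_sequence_parallel_region to reassemble the full sequence before the [s b h] --> [b s h] transpose; with the sequence-parallel LM head, the shard is now returned as-is instead. A single-process stand-in for the gather semantics (the real version is an all-gather over the tensor-model-parallel group):

import torch

def gather_from_sequence_parallel(shards):
    # shards: each rank's [s/tp, b, h] slice, in rank order.
    return torch.cat(shards, dim=0)            # -> [s, b, h]

tp = 2
full = torch.arange(8 * 2 * 4, dtype=torch.float32).view(8, 2, 4)
shards = list(full.chunk(tp, dim=0))
print(torch.equal(gather_from_sequence_parallel(shards), full))  # True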
......@@ -41,11 +41,12 @@ from .utils import split_tensor_along_last_dim
from .utils import VocabUtility
from megatron import get_args
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
'partition_dim': -1,
'partition_stride': 1}
_TOTAL_INPUT = None
_SUB_GRAD_INPUT = None
def param_is_not_tensor_parallel_duplicate(param):
return (hasattr(param, 'tensor_model_parallel') and
......@@ -221,9 +222,11 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
dim_size = list(input.size())
dim_size[0] = dim_size[0] * world_size
total_input = torch.empty(dim_size, dtype=input.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
#total_input = torch.empty(dim_size, dtype=input.dtype,
# device=torch.cuda.current_device(),
# requires_grad=False)
global _TOTAL_INPUT
total_input = _TOTAL_INPUT
torch.distributed._all_gather_base(total_input, input,
group=get_tensor_model_parallel_group())
......@@ -246,9 +249,12 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
dim_size = list(input.size())
dim_size[0] = dim_size[0] * world_size
total_input = torch.empty(dim_size, dtype=input.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
#total_input = torch.empty(dim_size, dtype=input.dtype,
# device=torch.cuda.current_device(),
# requires_grad=False)
global _TOTAL_INPUT
total_input = _TOTAL_INPUT
handle = torch.distributed._all_gather_base(total_input, input,
group=get_tensor_model_parallel_group(), async_op=True)
# Delay the start of input gradient computation shortly (3us) to have
......@@ -279,8 +285,8 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
assert not ctx.async_grad_allreduce
dim_size = list(input.size())
sub_grad_input = torch.empty(dim_size, dtype=input.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
device=torch.cuda.current_device(),
requires_grad=False)
# reduce_scatter
handle = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input,
group=get_tensor_model_parallel_group(),
......@@ -390,6 +396,11 @@ class ColumnParallelLinear(torch.nn.Module):
assert not self.async_tensor_model_parallel_allreduce or \
not self.model_parallel_memory_opt
self.gradient_accumulation_fusion = args.gradient_accumulation_fusion
global _TOTAL_INPUT
if _TOTAL_INPUT is None:
_TOTAL_INPUT = torch.empty((args.seq_length, args.micro_batch_size, args.hidden_size), dtype=torch.bfloat16,
device=torch.cuda.current_device(),
requires_grad=False)
def forward(self, input_):
......
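With model_parallel_memory_opt, the column-parallel linear all-gathers its [s/tp, b, h] input along the sequence dimension in the forward pass (into the preallocated _TOTAL_INPUT above) and reduce-scatters the input gradient in the backward pass. A single-process sketch of those shapes, with torch.cat/chunk standing in for _all_gather_base/_reduce_scatter_base and all sizes assumed:

import torch
import torch.nn.functional as F

def sp_linear_forward(input_shards, weight):
    # input_shards: per-rank [s/tp, b, h] activations; weight: [out, h].
    total_input = torch.cat(input_shards, dim=0)          # stands in for the all-gather
    return F.linear(total_input, weight)                  # [s, b, out]

def reduce_scatter_grad_input(grad_input, tp_world_size):
    # Each rank keeps only its own sequence slice of the input gradient; across
    # real ranks the slices would additionally be summed by the reduce-scatter.
    return list(grad_input.chunk(tp_world_size, dim=0))

tp, s, b, h, out = 2, 8, 2, 16, 32
shards = [torch.randn(s // tp, b, h) for _ in range(tp)]
weight = torch.randn(out, h)
print(sp_linear_forward(shards, weight).shape)                       # torch.Size([8, 2, 32])
print(reduce_scatter_grad_input(torch.randn(s, b, h), tp)[0].shape)  # torch.Size([4, 2, 16])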
......@@ -264,6 +264,13 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
if param in self.optimizer.state:
self.optimizer.state[main_param] \
= self.optimizer.state.pop(param)
#state = self.optimizer.state[main_param]
#if len(state) == 0:
# # Exponential moving average of gradient values
# state['exp_avg'] = torch.zeros_like(main_param.data)
# # Exponential moving average of squared gradient values
# state['exp_avg_sq'] = torch.zeros_like(main_param.data)
# fp32 params.
elif param.type() == 'torch.cuda.FloatTensor':
......@@ -284,8 +291,8 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
# Leverage state_dict() and load_state_dict() to
# recast preexisting per-param state tensors
self.optimizer.load_state_dict(self.optimizer.state_dict())
# self.optimizer.load_state_dict(self.optimizer.state_dict())
def zero_grad(self, set_to_none=True):
"""We only need to zero the model related parameters, i.e.,
......
......@@ -517,30 +517,22 @@ def get_tensor_shapes(rank, model_type):
    if args.model_parallel_memory_opt:
        seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
-        if model_type == ModelType.encoder_and_decoder:
-            decoder_seq_length = args.decoder_seq_length // mpu.get_tensor_model_parallel_world_size()
-            if mpu.is_pipeline_stage_before_split(rank):
-                tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size))
-            else:
-                tensor_shapes.append((decoder_seq_length, args.micro_batch_size, args.hidden_size))
-                tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size))
-        else:
-            tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size))
-        return tensor_shapes
-    if model_type == ModelType.encoder_and_decoder:
-        if mpu.is_pipeline_stage_before_split(rank):
-            # If next rank is after split, then need transpose for encoder_hidden_state.
-            if mpu.is_pipeline_stage_before_split(rank+1):
-                tensor_shapes.append((args.seq_length, args.micro_batch_size, args.hidden_size))
-            else:
-                tensor_shapes.append((args.micro_batch_size, args.seq_length, args.hidden_size))
-        else:
-            tensor_shapes.append((args.decoder_seq_length, args.micro_batch_size, args.hidden_size))
-            tensor_shapes.append((args.micro_batch_size, args.seq_length, args.hidden_size))
-    else:
-        tensor_shapes.append((args.seq_length, args.micro_batch_size, args.hidden_size))
-    return tensor_shapes
+    else:
+        seq_length = args.seq_length
+    if model_type == ModelType.encoder_and_decoder:
+        if args.model_parallel_memory_opt:
+            decoder_seq_length = args.decoder_seq_length // mpu.get_tensor_model_parallel_world_size()
+        else:
+            decoder_seq_length = args.decoder_seq_length
+        if mpu.is_pipeline_stage_before_split(rank):
+            tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size))
+        else:
+            tensor_shapes.append((decoder_seq_length, args.micro_batch_size, args.hidden_size))
+            tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size))
+    else:
+        tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size))
+    return tensor_shapes
......
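The net effect of the reworked get_tensor_shapes is that the tensors exchanged between pipeline stages stay in [s, b, h] order but shrink by the tensor-parallel world size along the sequence dimension when the memory optimization is on. A quick numeric check under assumed hyperparameters (seq_length=2048, decoder_seq_length=512, micro_batch_size=4, hidden_size=1024, tp=8):

def shapes(seq_length, decoder_seq_length, micro_batch_size, hidden_size,
           tp_world_size, memory_opt, encoder_and_decoder, before_split):
    # Mirrors the reworked logic above for a single rank's send/recv shapes.
    if memory_opt:
        seq_length //= tp_world_size
        decoder_seq_length //= tp_world_size
    if encoder_and_decoder:
        if before_split:
            return [(seq_length, micro_batch_size, hidden_size)]
        return [(decoder_seq_length, micro_batch_size, hidden_size),
                (seq_length, micro_batch_size, hidden_size)]
    return [(seq_length, micro_batch_size, hidden_size)]

print(shapes(2048, 512, 4, 1024, 8, True, False, True))   # [(256, 4, 1024)]
print(shapes(2048, 512, 4, 1024, 8, True, True, False))   # [(64, 4, 1024), (256, 4, 1024)]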
......@@ -432,7 +432,8 @@ def train_step(forward_step_func, data_iterator,
grad = param.main_grad if args.DDP_impl == 'local' else param.grad
grads.append(grad.data)
coalesced = _flatten_dense_tensors(grads)
coalesced /= mpu.get_tensor_model_parallel_world_size()
#TODO VIJAY
#coalesced /= mpu.get_tensor_model_parallel_world_size()
torch.distributed.all_reduce(
coalesced, group=mpu.get_tensor_model_parallel_group())
for buf, synced in zip(grads, _unflatten_dense_tensors(
......
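This hunk keeps the all-reduce of the grads that are duplicated across tensor-parallel ranks (e.g. layernorm parameters) but comments out the division by the world size, so they are summed rather than averaged. A single-process stand-in, with torch.stack(...).sum replacing torch.distributed.all_reduce and made-up gradient values:

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

# Two simulated "ranks", each holding copies of the same duplicated grads.
rank_grads = [[torch.full((3,), 1.0), torch.full((2,), 2.0)],
              [torch.full((3,), 3.0), torch.full((2,), 4.0)]]

coalesced = [_flatten_dense_tensors(g) for g in rank_grads]
summed = torch.stack(coalesced).sum(dim=0)      # stand-in for the all_reduce across the TP group
for grads in rank_grads:
    for buf, synced in zip(grads, _unflatten_dense_tensors(summed, grads)):
        buf.copy_(synced)
print(rank_grads[0][0])                          # tensor([4., 4., 4.]) -- summed, not averaged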
......@@ -59,7 +59,7 @@ def get_batch(data_iterator):
# Unpack.
tokens_ = data_b['text'].long()
labels = tokens_[:, 1:].contiguous()
labels = tokens_[:, 1:].transpose(0,1).contiguous()
tokens = tokens_[:, :-1].contiguous()
# Get the masks and position ids.
......
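Shifting the tokens by one and transposing puts the labels in the sequence-first [s, b] layout, presumably to line up with the sequence-first (and, under the memory optimization, sequence-sharded) logits. A tiny shape check with assumed batch and sequence sizes:

import torch

b, s = 2, 6
tokens_ = torch.arange(b * (s + 1)).view(b, s + 1)       # [b, s+1]
labels = tokens_[:, 1:].transpose(0, 1).contiguous()     # [s, b]
tokens = tokens_[:, :-1].contiguous()                    # [b, s]
print(labels.shape, tokens.shape)  # torch.Size([6, 2]) torch.Size([2, 6])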