Commit eec218d8 authored by Vijay Korthikanti

sequence parallelism for embedding dropout and last linear layer + memory optimizations

parent 53718d4c
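Reviewer note: "sequence parallelism" here means that activations which are replicated across tensor-parallel ranks (the embedding dropout at the start, and the final LayerNorm plus output linear at the end) are instead split along the sequence dimension, so each rank keeps only seq_length / tp_world_size rows of the [s, b, h] activation. Below is a rough, forward-only sketch of the scatter/gather pair; the helper names are made up, and the real mpu implementation wraps these in autograd functions so the backward pass swaps gather and scatter.

```python
import torch
import torch.distributed as dist

def scatter_to_sequence_parallel(x: torch.Tensor, group=None) -> torch.Tensor:
    """Keep only this rank's slice of the sequence dimension (dim 0).
    x is [s, b, h], replicated on every tensor-parallel rank; the result
    is [s / world_size, b, h]."""
    world_size = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)
    assert x.size(0) % world_size == 0, "seq_length must divide by the TP world size"
    chunk = x.size(0) // world_size
    return x[rank * chunk:(rank + 1) * chunk].contiguous()

def gather_from_sequence_parallel(x: torch.Tensor, group=None) -> torch.Tensor:
    """All-gather the local [s / world_size, b, h] slices back to [s, b, h]."""
    world_size = dist.get_world_size(group=group)
    out = torch.empty((x.size(0) * world_size,) + tuple(x.shape[1:]),
                      dtype=x.dtype, device=x.device)
    # Same collective the commit uses inside the linear layer.
    torch.distributed._all_gather_base(out, x.contiguous(), group=group)
    return out
```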
@@ -37,7 +37,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
         input_parallel = input_
         model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
         async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \
-            model_parallel
+            model_parallel and not args.model_parallel_memory_opt
     else:
         input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
         async_grad_allreduce = False
@@ -46,7 +46,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     logits_parallel = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply(
         input_parallel, word_embeddings_weight, bias,
         args.gradient_accumulation_fusion,
-        async_grad_allreduce, None)
+        async_grad_allreduce, args.model_parallel_memory_opt)
     # Gather if needed.
     if parallel_output:
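The flag forwarded into LinearWithGradAccumulationAndAsyncCommunication tells the fused linear op that its input is sequence-parallel and must be all-gathered back to full sequence length before the logits matmul; that is also why the async tensor-parallel all-reduce is disabled when the flag is set (the two schemes are alternatives, not complements). A hedged sketch of what that forward path amounts to, with made-up helper names:

```python
import torch
import torch.distributed as dist

def lm_logits_sequence_parallel(input_parallel, word_embeddings_weight, bias=None, group=None):
    """Sketch only: input_parallel is [s/p, b, h]; word_embeddings_weight is the
    vocab-parallel matrix [vocab/p, h]. Gather the sequence shards, then compute
    the partial (vocab-parallel) logits [s, b, vocab/p]."""
    world_size = dist.get_world_size(group=group)
    total_input = torch.empty((input_parallel.size(0) * world_size,) + tuple(input_parallel.shape[1:]),
                              dtype=input_parallel.dtype, device=input_parallel.device)
    torch.distributed._all_gather_base(total_input, input_parallel.contiguous(), group=group)
    logits_parallel = torch.matmul(total_input, word_embeddings_weight.t())
    if bias is not None:
        logits_parallel = logits_parallel + bias
    return logits_parallel
```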
@@ -170,6 +170,8 @@ class Embedding(MegatronModule):
         else:
             self.tokentype_embeddings = None
 
+        self.fp32_residual_connection = args.fp32_residual_connection
+        self.model_parallel_memory_opt = args.model_parallel_memory_opt
         # Embeddings dropout
         self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
@@ -211,7 +213,22 @@ class Embedding(MegatronModule):
         else:
             assert self.tokentype_embeddings is None
 
+        # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
+        # If the input flag for fp32 residual connection is set, convert for float.
+        if self.fp32_residual_connection:
+            embeddings = embeddings.transpose(0, 1).contiguous().float()
+        # Otherwise, leave it as is.
+        else:
+            embeddings = embeddings.transpose(0, 1).contiguous()
+
+        if self.model_parallel_memory_opt:
+            embeddings = mpu.scatter_to_sequence_parallel_region(embeddings)
+
         # Dropout.
-        embeddings = self.embedding_dropout(embeddings)
+        if self.model_parallel_memory_opt:
+            with mpu.get_cuda_rng_tracker().fork():
+                embeddings = self.embedding_dropout(embeddings)
+        else:
+            embeddings = self.embedding_dropout(embeddings)
 
         return embeddings
...
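Dropout on the embedding now runs on the local sequence shard inside the tensor-parallel RNG fork, so each rank draws an independent dropout mask for its own slice of tokens while parameters stay in the shared RNG stream. A toy stand-in for that behaviour (this is not how mpu.get_cuda_rng_tracker works internally, just the observable effect):

```python
import torch

def sequence_parallel_dropout(embeddings_shard, p, rank, base_seed=1234):
    """Toy sketch: give each tensor-parallel rank its own RNG stream so the
    dropout mask differs per [s/p, b, h] shard. Seed scheme is an assumption."""
    gen = torch.Generator(device=embeddings_shard.device)
    gen.manual_seed(base_seed + rank)                      # per-rank stream
    keep = (torch.rand(embeddings_shard.shape, generator=gen,
                       device=embeddings_shard.device) >= p)
    return embeddings_shard * keep / (1.0 - p)             # inverted dropout scaling
```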
@@ -18,7 +18,7 @@ import math
 import torch
 import torch.nn.functional as F
 
-from megatron import get_args
+from megatron import get_timers, get_args, print_rank_last, print_rank_0
 from megatron import mpu
 from .module import MegatronModule
 from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
@@ -27,6 +27,8 @@ from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
 
+_MATMUL_INPUT = None
+
 """ We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
@@ -42,7 +44,6 @@ from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
     hyperparameters: transformer hyperparameters
 """
 
 class DropPath(MegatronModule):
     """Drop paths (Stochastic Depth) per sample
     (when applied in main path of residual blocks).
@@ -189,6 +190,17 @@ class CoreAttention(MegatronModule):
             world_size)
         self.hidden_size_per_attention_head = mpu.divide(
             projection_size, args.num_attention_heads)
+        self.num_attention_heads_per_partition = mpu.divide(
+            args.num_attention_heads, world_size)
+
+        global _MATMUL_INPUT
+        if _MATMUL_INPUT is None:
+            _MATMUL_INPUT = torch.empty(
+                args.micro_batch_size * self.num_attention_heads_per_partition,
+                args.seq_length,
+                args.seq_length,
+                dtype=torch.bfloat16,
+                device=torch.cuda.current_device())
 
         coeff = None
         self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
@@ -230,16 +242,19 @@ class CoreAttention(MegatronModule):
             output_size[0] * output_size[1], -1)
 
         # preallocting result tensor: [b * np, sq, sk]
-        matmul_result = torch.empty(
-            output_size[0]*output_size[1],
-            output_size[2],
-            output_size[3],
-            dtype=query_layer.dtype,
-            device=torch.cuda.current_device())
+        #matmul_result = torch.empty(
+        #    output_size[0]*output_size[1],
+        #    output_size[2],
+        #    output_size[3],
+        #    dtype=query_layer.dtype,
+        #    device=torch.cuda.current_device())
+        global _MATMUL_INPUT
+        matmul_input = _MATMUL_INPUT
 
         # Raw attention scores. [b * np, sq, sk]
         matmul_result = torch.baddbmm(
-            matmul_result,
+            matmul_input,
             query_layer.transpose(0, 1),   # [b * np, sq, hn]
             key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
             beta=0.0, alpha=(1.0/self.norm_factor))
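The per-call torch.empty for the attention-score tensor is replaced by a module-level _MATMUL_INPUT buffer allocated once from args (bfloat16, fixed micro-batch size and sequence length) and passed as the additive input of torch.baddbmm; with beta=0.0 its contents are ignored, so only its shape and dtype matter. A self-contained sketch of the same caching pattern, sizing the buffer from the incoming tensors instead of from args:

```python
import torch

_SCORES_BUF = None  # cached [b * np, sq, sk] tensor, reused across forward calls

def attention_scores(query_layer, key_layer, norm_factor):
    """query_layer: [sq, b * np, hn]; key_layer: [sk, b * np, hn]."""
    global _SCORES_BUF
    b_np, sq, sk = query_layer.size(1), query_layer.size(0), key_layer.size(0)
    if (_SCORES_BUF is None or _SCORES_BUF.shape != (b_np, sq, sk)
            or _SCORES_BUF.dtype != query_layer.dtype):
        _SCORES_BUF = torch.empty(b_np, sq, sk,
                                  dtype=query_layer.dtype,
                                  device=query_layer.device)
    # beta=0.0: the buffer is only the additive term and its values are ignored,
    # so one cached tensor can serve every call with this shape.
    return torch.baddbmm(_SCORES_BUF,
                         query_layer.transpose(0, 1),                 # [b*np, sq, hn]
                         key_layer.transpose(0, 1).transpose(1, 2),   # [b*np, hn, sk]
                         beta=0.0, alpha=1.0 / norm_factor)
```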
@@ -838,6 +853,7 @@ class ParallelTransformer(MegatronModule):
                     self.distribute_checkpointed_activations,
                     hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 l += self.activations_checkpoint_num_layers
 
         elif self.activations_checkpoint_method == 'block':
             # Checkpoint the input activation of only a set number of individual
             # Transformer layers and skip the rest.
@@ -869,25 +885,12 @@ class ParallelTransformer(MegatronModule):
     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, enc_dec_attn_mask=None,
                 inference_params=None):
 
         # Checks.
         if inference_params:
             assert self.activations_checkpoint_method is None, \
                 'inference does not work with activation checkpointing'
 
-        if self.pre_process:
-            # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
-            # If the input flag for fp32 residual connection is set, convert for float.
-            if self.fp32_residual_connection:
-                hidden_states = hidden_states.transpose(0, 1).contiguous().float()
-            # Otherwise, leave it as is.
-            else:
-                hidden_states = hidden_states.transpose(0, 1).contiguous()
-            if self.model_parallel_memory_opt:
-                hidden_states = mpu.scatter_to_sequence_parallel_region(hidden_states)
-        else:
+        if not self.pre_process:
             # See set_input_tensor()
             hidden_states = self.input_tensor
@@ -908,17 +911,10 @@ class ParallelTransformer(MegatronModule):
         # is called here to be future-proof and corner-case-proof.
         hidden_states = mpu.make_viewless_tensor(
             hidden_states,
-            requires_grad = True,
-            keep_graph = True,
+            requires_grad=True,
+            keep_graph=True,
         )
 
-        # Transpose encoder output.
-        if encoder_output is not None and \
-                not self.model_parallel_memory_opt:
-            encoder_output = encoder_output.transpose(0, 1).contiguous()
-
-        if self.model_parallel_memory_opt:
-            encoder_output = mpu.scatter_to_sequence_parallel_region(encoder_output)
-
         if self.model_parallel_memory_opt:
             with mpu.get_cuda_rng_tracker().fork():
                 # Forward pass.
@@ -928,6 +924,7 @@ class ParallelTransformer(MegatronModule):
                         encoder_output,
                         enc_dec_attn_mask)
                 else:
+                    total = 0
                     for index in range(self.num_layers):
                         layer = self._get_layer(index)
                         hidden_states = layer(
@@ -936,6 +933,7 @@ class ParallelTransformer(MegatronModule):
                             encoder_output=encoder_output,
                             enc_dec_attn_mask=enc_dec_attn_mask,
                             inference_params=inference_params)
         else:
             # Forward pass.
             if self.activations_checkpoint_method is not None:
@@ -955,20 +953,6 @@ class ParallelTransformer(MegatronModule):
         # Final layer norm.
         if self.post_process:
-            # Reverting data format change [s b h] --> [b s h].
             hidden_states = self.final_layernorm(hidden_states)
-            if self.layer_type == LayerType.encoder and \
-               self.model_type == ModelType.encoder_and_decoder and \
-               self.model_parallel_memory_opt:
-                output = hidden_states
-            else:
-                if self.model_parallel_memory_opt:
-                    hidden_states = mpu.gather_from_sequence_parallel_region(hidden_states)
-                output = hidden_states.transpose(0, 1).contiguous()
-        else:
-            output = hidden_states
 
-        return output
+        return hidden_states
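With the scatter moved into the Embedding and the gather folded into the output linear layer, the post-process path no longer transposes back to [b, s, h] or gathers the sequence; the final LayerNorm is applied directly to the [s/p, b, h] shard, which is valid because LayerNorm normalizes only over the hidden dimension and therefore commutes with splitting along the sequence dimension. A quick numeric check with made-up sizes:

```python
import torch

torch.manual_seed(0)
s, b, h, tp = 8, 2, 16, 4
x = torch.randn(s, b, h)
ln = torch.nn.LayerNorm(h)

# Normalizing each [s/tp, b, h] shard independently matches normalizing the
# full [s, b, h] tensor, since statistics are computed per position over h.
full = ln(x)
sharded = torch.cat([ln(chunk) for chunk in x.chunk(tp, dim=0)], dim=0)
assert torch.allclose(full, sharded, atol=1e-6)
```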
@@ -41,11 +41,12 @@ from .utils import split_tensor_along_last_dim
 from .utils import VocabUtility
 from megatron import get_args
 
 _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
                                       'partition_dim': -1,
                                       'partition_stride': 1}
 
+_TOTAL_INPUT = None
+_SUB_GRAD_INPUT = None
 
 def param_is_not_tensor_parallel_duplicate(param):
     return (hasattr(param, 'tensor_model_parallel') and
@@ -221,9 +222,11 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
             dim_size = list(input.size())
             dim_size[0] = dim_size[0] * world_size
 
-            total_input = torch.empty(dim_size, dtype=input.dtype,
-                                      device=torch.cuda.current_device(),
-                                      requires_grad=False)
+            #total_input = torch.empty(dim_size, dtype=input.dtype,
+            #                          device=torch.cuda.current_device(),
+            #                          requires_grad=False)
+            global _TOTAL_INPUT
+            total_input = _TOTAL_INPUT
             torch.distributed._all_gather_base(total_input, input,
                 group=get_tensor_model_parallel_group())
@@ -246,9 +249,12 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
             dim_size = list(input.size())
             dim_size[0] = dim_size[0] * world_size
 
-            total_input = torch.empty(dim_size, dtype=input.dtype,
-                                      device=torch.cuda.current_device(),
-                                      requires_grad=False)
+            #total_input = torch.empty(dim_size, dtype=input.dtype,
+            #                          device=torch.cuda.current_device(),
+            #                          requires_grad=False)
+            global _TOTAL_INPUT
+            total_input = _TOTAL_INPUT
             handle = torch.distributed._all_gather_base(total_input, input,
                 group=get_tensor_model_parallel_group(), async_op=True)
 
             # Delay the start of intput gradient computation shortly (3us) to have
@@ -390,6 +396,11 @@ class ColumnParallelLinear(torch.nn.Module):
         assert not self.async_tensor_model_parallel_allreduce or \
             not self.model_parallel_memory_opt
         self.gradient_accumulation_fusion = args.gradient_accumulation_fusion
+        global _TOTAL_INPUT
+        if _TOTAL_INPUT is None:
+            _TOTAL_INPUT = torch.empty((args.seq_length, args.micro_batch_size, args.hidden_size), dtype=torch.bfloat16,
+                                       device=torch.cuda.current_device(),
+                                       requires_grad=False)
 
     def forward(self, input_):
......
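Both the forward and backward all-gathers in LinearWithGradAccumulationAndAsyncCommunication now reuse one preallocated _TOTAL_INPUT buffer of shape (seq_length, micro_batch_size, hidden_size), hard-coded to bfloat16, instead of calling torch.empty on every invocation. A hedged sketch of the same caching idea that derives shape and dtype from the incoming shard (the shared-buffer approach assumes the gathered activation is never needed live in forward and backward at the same time):

```python
import torch
import torch.distributed as dist

_GATHER_BUF = None  # cached full-sequence buffer shared by forward and backward

def all_gather_sequence(input_parallel, group=None, async_op=False):
    """All-gather a [s/p, b, h] shard into a cached [s, b, h] buffer.
    Returns (buffer, handle); handle is None unless async_op=True."""
    global _GATHER_BUF
    world_size = dist.get_world_size(group=group)
    full_shape = (input_parallel.size(0) * world_size,) + tuple(input_parallel.shape[1:])
    if (_GATHER_BUF is None or _GATHER_BUF.shape != full_shape
            or _GATHER_BUF.dtype != input_parallel.dtype):
        _GATHER_BUF = torch.empty(full_shape,
                                  dtype=input_parallel.dtype,
                                  device=input_parallel.device,
                                  requires_grad=False)
    handle = torch.distributed._all_gather_base(_GATHER_BUF, input_parallel.contiguous(),
                                                group=group, async_op=async_op)
    return _GATHER_BUF, handle
```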
@@ -265,6 +265,13 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
                         self.optimizer.state[main_param] \
                             = self.optimizer.state.pop(param)
 
+                        #state = self.optimizer.state[main_param]
+                        #if len(state) == 0:
+                        #    # Exponential moving average of gradient values
+                        #    state['exp_avg'] = torch.zeros_like(main_param.data)
+                        #    # Exponential moving average of squared gradient values
+                        #    state['exp_avg_sq'] = torch.zeros_like(main_param.data)
+
                 # fp32 params.
                 elif param.type() == 'torch.cuda.FloatTensor':
                     fp32_params_this_group.append(param)
@@ -284,7 +291,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
         # Leverage state_dict() and load_state_dict() to
         # recast preexisting per-param state tensors
-        self.optimizer.load_state_dict(self.optimizer.state_dict())
+        # self.optimizer.load_state_dict(self.optimizer.state_dict())
 
     def zero_grad(self, set_to_none=True):
......
@@ -517,8 +517,15 @@ def get_tensor_shapes(rank, model_type):
     if args.model_parallel_memory_opt:
         seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
+    else:
+        seq_length = args.seq_length
+
     if model_type == ModelType.encoder_and_decoder:
-        decoder_seq_length = args.decoder_seq_length // mpu.get_tensor_model_parallel_world_size()
+        if args.model_parallel_memory_opt:
+            decoder_seq_length = args.decoder_seq_length // mpu.get_tensor_model_parallel_world_size()
+        else:
+            decoder_seq_length = args.decoder_seq_length
+
         if mpu.is_pipeline_stage_before_split(rank):
             tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size))
         else:
@@ -526,21 +533,6 @@ def get_tensor_shapes(rank, model_type):
             tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size))
         else:
             tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size))
+        return tensor_shapes
 
-    if model_type == ModelType.encoder_and_decoder:
-        if mpu.is_pipeline_stage_before_split(rank):
-            # If next rank is after split, then need transpose for encoder_hidden_state.
-            if mpu.is_pipeline_stage_before_split(rank+1):
-                tensor_shapes.append((args.seq_length, args.micro_batch_size, args.hidden_size))
-            else:
-                tensor_shapes.append((args.micro_batch_size, args.seq_length, args.hidden_size))
-        else:
-            tensor_shapes.append((args.decoder_seq_length, args.micro_batch_size, args.hidden_size))
-            tensor_shapes.append((args.micro_batch_size, args.seq_length, args.hidden_size))
-    else:
-        tensor_shapes.append((args.seq_length, args.micro_batch_size, args.hidden_size))
     return tensor_shapes
......
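Because pipeline-parallel point-to-point communication now carries sequence-sharded tensors, get_tensor_shapes divides seq_length (and decoder_seq_length) by the tensor-parallel world size, and the old branch that transposed encoder_hidden_state to [b, s, h] across the pipeline split disappears. A tiny worked example with made-up sizes:

```python
# Toy illustration of the communicated shape under sequence parallelism
# (the numbers are made up, not taken from the commit).
seq_length, micro_batch_size, hidden_size = 2048, 4, 3072
tp_world_size = 4
sequence_parallel = True

seq = seq_length // tp_world_size if sequence_parallel else seq_length
tensor_shape = (seq, micro_batch_size, hidden_size)
print(tensor_shape)  # (512, 4, 3072)
```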
@@ -432,7 +432,8 @@ def train_step(forward_step_func, data_iterator,
             grad = param.main_grad if args.DDP_impl == 'local' else param.grad
             grads.append(grad.data)
         coalesced = _flatten_dense_tensors(grads)
-        coalesced /= mpu.get_tensor_model_parallel_world_size()
+        #TODO VIJAY
+        #coalesced /= mpu.get_tensor_model_parallel_world_size()
         torch.distributed.all_reduce(
             coalesced, group=mpu.get_tensor_model_parallel_group())
         for buf, synced in zip(grads, _unflatten_dense_tensors(
......
@@ -59,7 +59,7 @@ def get_batch(data_iterator):
     # Unpack.
     tokens_ = data_b['text'].long()
-    labels = tokens_[:, 1:].contiguous()
+    labels = tokens_[:, 1:].transpose(0,1).contiguous()
     tokens = tokens_[:, :-1].contiguous()
 
     # Get the masks and postition ids.
......
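Since the model output now stays sequence-first ([s, b, ...]) all the way to the loss, get_batch transposes the labels to [s, b] so they line up with the logits. A toy shape check with made-up sizes:

```python
import torch

# Hypothetical shapes only: a [b, s+1] token batch yields [s, b] labels.
b, s = 4, 16
tokens_ = torch.randint(0, 1000, (b, s + 1))
labels = tokens_[:, 1:].transpose(0, 1).contiguous()   # [s, b], matches [s, b, vocab] logits
tokens = tokens_[:, :-1].contiguous()                  # [b, s], still batch-first for the embedding
print(labels.shape, tokens.shape)                      # torch.Size([16, 4]) torch.Size([4, 16])
```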