Commit dd96d402 authored by Vijay Korthikanti

bug fixes

parent 269f28f7
...
@@ -188,6 +188,7 @@ class ParallelAttention(MegatronModule):
         self.attention_type = attention_type
         self.attn_mask_type = attn_mask_type
         self.params_dtype = args.params_dtype
+        self.model_parallel_memory_opt = args.model_parallel_memory_opt

         projection_size = args.kv_channels * args.num_attention_heads
...
@@ -391,7 +392,11 @@ class ParallelAttention(MegatronModule):
         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
-        with mpu.get_cuda_rng_tracker().fork():
-            attention_probs = self.attention_dropout(attention_probs)
+        if not self.model_parallel_memory_opt:
+            with mpu.get_cuda_rng_tracker().fork():
+                attention_probs = self.attention_dropout(attention_probs)
+        else:
+            attention_probs = self.attention_dropout(attention_probs)

         # =========================
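For context on the branch above: dropout applied to tensor-parallel activations is normally run under a forked model-parallel RNG stream so that each rank draws an independent mask, and the per-dropout fork is skipped when model_parallel_memory_opt is set because the ParallelTransformer forward in the next hunk already wraps the whole layer stack in that same fork. A minimal single-process sketch of the fork-and-restore idea, using a simplified CPU-RNG stand-in rather than Megatron's mpu tracker:

import contextlib
import torch


class SimpleRngTracker:
    """Simplified stand-in for mpu.get_cuda_rng_tracker() (CPU RNG, single tracked state)."""

    def __init__(self, seed):
        # Capture a dedicated RNG state without disturbing the default stream.
        default_state = torch.random.get_rng_state()
        torch.manual_seed(seed)
        self._state = torch.random.get_rng_state()
        torch.random.set_rng_state(default_state)

    @contextlib.contextmanager
    def fork(self):
        # Swap the tracked state in, run the body, save it back, restore the default.
        default_state = torch.random.get_rng_state()
        torch.random.set_rng_state(self._state)
        try:
            yield
        finally:
            self._state = torch.random.get_rng_state()
            torch.random.set_rng_state(default_state)


tracker = SimpleRngTracker(seed=1234)     # in a real setup the seed differs per rank
dropout = torch.nn.Dropout(p=0.5)
attention_probs = torch.ones(4, 4)

with tracker.fork():
    # Mask drawn from the tracked ("model-parallel") stream; the default stream is untouched.
    attention_probs = dropout(attention_probs)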
...
@@ -865,32 +870,51 @@ class ParallelTransformer(MegatronModule):
         if self.model_parallel_memory_opt:
             encoder_output = mpu.scatter_to_sequence_parallel_region(encoder_output)

-        # Forward pass.
-        if self.activations_checkpoint_method is not None:
-            hidden_states = self._checkpointed_forward(hidden_states,
-                                                       attention_mask,
-                                                       encoder_output,
-                                                       enc_dec_attn_mask)
+        if self.model_parallel_memory_opt:
+            with mpu.get_cuda_rng_tracker().fork():
+                # Forward pass.
+                if self.activations_checkpoint_method is not None:
+                    hidden_states = self._checkpointed_forward(hidden_states,
+                                                               attention_mask,
+                                                               encoder_output,
+                                                               enc_dec_attn_mask)
+                else:
+                    for index in range(self.num_layers):
+                        layer = self._get_layer(index)
+                        hidden_states = layer(
+                            hidden_states,
+                            attention_mask,
+                            encoder_output=encoder_output,
+                            enc_dec_attn_mask=enc_dec_attn_mask,
+                            inference_params=inference_params)
         else:
-            for index in range(self.num_layers):
-                layer = self._get_layer(index)
-                hidden_states = layer(
-                    hidden_states,
-                    attention_mask,
-                    encoder_output=encoder_output,
-                    enc_dec_attn_mask=enc_dec_attn_mask,
-                    inference_params=inference_params)
+            # Forward pass.
+            if self.activations_checkpoint_method is not None:
+                hidden_states = self._checkpointed_forward(hidden_states,
+                                                           attention_mask,
+                                                           encoder_output,
+                                                           enc_dec_attn_mask)
+            else:
+                for index in range(self.num_layers):
+                    layer = self._get_layer(index)
+                    hidden_states = layer(
+                        hidden_states,
+                        attention_mask,
+                        encoder_output=encoder_output,
+                        enc_dec_attn_mask=enc_dec_attn_mask,
+                        inference_params=inference_params)

         # Final layer norm.
         if self.post_process:
             # Reverting data format change [s b h] --> [b s h].
             hidden_states = self.final_layernorm(hidden_states)
-            if self.layer_type==LayerType.encoder and \
-                    self.model_type==ModelType.encoder_and_decoder and \
-                    self.model_parallel_memory_opt:
+            if self.layer_type == LayerType.encoder and \
+                    self.model_type == ModelType.encoder_and_decoder and \
+                    self.model_parallel_memory_opt:
                 output = hidden_states
             else:
                 if self.model_parallel_memory_opt:
                     hidden_states = mpu.gather_from_sequence_parallel_region(hidden_states)
...
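The scatter_to_sequence_parallel_region / gather_from_sequence_parallel_region calls above partition activations along the first (sequence) dimension across the tensor-parallel group and later reassemble them. A single-process sketch of that round trip with hypothetical sizes, where torch.chunk and torch.cat stand in for the split and all-gather collectives:

import torch

# Hypothetical shapes: sequence 8, micro-batch 2, hidden 4, tensor-parallel size 4.
seq_len, batch, hidden, world_size = 8, 2, 4, 4
hidden_states = torch.randn(seq_len, batch, hidden)            # [s, b, h]

# scatter_to_sequence_parallel_region: each rank keeps s / world_size rows.
shards = torch.chunk(hidden_states, world_size, dim=0)
assert all(s.shape == (seq_len // world_size, batch, hidden) for s in shards)

# ... each rank would run its transformer layers on its own shard here ...

# gather_from_sequence_parallel_region: an all-gather restores the full sequence.
restored = torch.cat(shards, dim=0)
assert torch.equal(restored, hidden_states)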
...
@@ -215,7 +215,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
         ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
         ctx.async_grad_allreduce = async_grad_allreduce
         ctx.model_parallel_memory_opt = model_parallel_memory_opt

         if model_parallel_memory_opt:
             world_size = get_tensor_model_parallel_world_size()
             dim_size = list(input.size())
...
@@ -487,6 +487,8 @@ class RowParallelLinear(torch.nn.Module):
                 self.bias = Parameter(torch.empty(
                     self.output_size, device=torch.cuda.current_device(),
                     dtype=args.params_dtype))
+            setattr(self.bias, 'sequence_parallel', args.model_parallel_memory_opt)
+
             # Always initialize bias to zero.
             with torch.no_grad():
                 self.bias.zero_()
...
@@ -496,6 +498,7 @@ class RowParallelLinear(torch.nn.Module):
         self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

     def forward(self, input_):
         # Set up backprop all-reduce.
         if self.input_is_parallel:
...
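The setattr above tags the row-parallel bias with a sequence_parallel attribute. A hedged sketch of how such a tag is typically consumed, assuming it is meant to drive a gradient all-reduce across the tensor-parallel group (the helper name and loop below are illustrative, not Megatron's actual reduction code):

import torch
import torch.distributed as dist


def allreduce_sequence_parallel_grads(model, tp_group):
    """Sum gradients of parameters tagged 'sequence_parallel' across the TP group.

    Parameters that act on sequence-parallel activations (for example layer-norm
    weights or a RowParallelLinear bias) only see a shard of the sequence on each
    rank, so their gradients need to be summed over the tensor-parallel group
    before the optimizer step.
    """
    for param in model.parameters():
        if getattr(param, 'sequence_parallel', False) and param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=tp_group)

A trainer would call this after the backward pass and before the optimizer step, with tp_group being the tensor-model-parallel process group.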
@@ -67,7 +67,7 @@ def _split_along_first_dim(input_):
     rank = get_tensor_model_parallel_rank()
     dim_offset = rank * (local_dim_size)
-    output = input_[dim_offset:dim_offset+local_dim_size]
+    output = input_[dim_offset:dim_offset+local_dim_size].contiguous()

     return output
...
@@ -106,33 +106,27 @@ def _gather_along_first_dim(input_):
     dim_size[0] = dim_size[0] * world_size

     output = torch.empty(dim_size, dtype=input_.dtype,
-                         device=torch.cuda.current_device(),
-                         requires_grad=False)
-    torch.distributed._all_gather_base(output, input_,
+                         device=torch.cuda.current_device())
+    torch.distributed._all_gather_base(output, input_.contiguous(),
                                        group=get_tensor_model_parallel_group())

     return output


 def _reduce_scatter_along_first_dim(input_):
     """Reduce-scatter the input tensor across model parallel group."""
     world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
-    if get_tensor_model_parallel_world_size()==1:
+    if get_tensor_model_parallel_world_size() == 1:
         return input_

     dim_size = list(input_.size())
     assert dim_size[0] % world_size == 0
-    dim_size[0]= dim_size[0] // world_size
+    dim_size[0] = dim_size[0] // world_size

     output = torch.empty(dim_size, dtype=input_.dtype,
-                         device=torch.cuda.current_device(),
-                         requires_grad=False)
-    # reduce_scatter
-    torch.distributed._reduce_scatter_base(output, input_,
+                         device=torch.cuda.current_device())
+    torch.distributed._reduce_scatter_base(output, input_.contiguous(),
                                            group=get_tensor_model_parallel_group())

     return output
...
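In the two collectives above, the output buffer is pre-allocated from the input shape: the all-gather multiplies dim 0 by the tensor-parallel world size, the reduce-scatter requires dim 0 to be divisible by it and divides, and both are fed contiguous buffers (the added .contiguous() calls). A single-process sketch of those shape and reduction semantics, with plain torch ops standing in for torch.distributed._all_gather_base and _reduce_scatter_base:

import torch

world_size = 4                                   # tensor-model-parallel size (hypothetical)
local = torch.randn(6, 3)                        # this rank's shard

# _all_gather_base: output dim 0 = local dim 0 * world_size.
# Here every "rank" contributes the same tensor, so we simply repeat it.
gathered = torch.cat([local] * world_size, dim=0)
assert gathered.shape == (6 * world_size, 3)

# _reduce_scatter_base: per-rank inputs are summed element-wise, then each rank
# keeps a dim 0 // world_size slice of the reduced result.
full = torch.randn(8, 3)                         # per-rank input, dim 0 divisible by world_size
assert full.shape[0] % world_size == 0
reduced = world_size * full                      # sum of identical per-rank inputs
rank0_shard = torch.chunk(reduced, world_size, dim=0)[0]
assert rank0_shard.shape == (full.shape[0] // world_size, 3)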