Commit dd96d402 authored by Vijay Korthikanti

bug fixes

parent 269f28f7
@@ -188,6 +188,7 @@ class ParallelAttention(MegatronModule):
         self.attention_type = attention_type
         self.attn_mask_type = attn_mask_type
         self.params_dtype = args.params_dtype
+        self.model_parallel_memory_opt = args.model_parallel_memory_opt

         projection_size = args.kv_channels * args.num_attention_heads
@@ -391,7 +392,11 @@ class ParallelAttention(MegatronModule):
         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
-        with mpu.get_cuda_rng_tracker().fork():
-            attention_probs = self.attention_dropout(attention_probs)
+        if not self.model_parallel_memory_opt:
+            with mpu.get_cuda_rng_tracker().fork():
+                attention_probs = self.attention_dropout(attention_probs)
+        else:
+            attention_probs = self.attention_dropout(attention_probs)

         # =========================
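Review note: when the model-parallel memory optimization (sequence parallelism) is enabled, the whole transformer forward pass now runs inside `mpu.get_cuda_rng_tracker().fork()` (see the `ParallelTransformer` hunk below), so `ParallelAttention` no longer forks a second time around its attention dropout. A minimal sketch of that conditional-fork pattern; the names here are illustrative and not part of this repository:

```python
# Sketch only: `rng_tracker` stands in for Megatron's CUDA RNG states tracker and
# `memory_opt` for args.model_parallel_memory_opt. When the caller has already
# forked the tracker, attention dropout simply uses the ambient RNG state.
from contextlib import nullcontext

def attention_dropout_with_optional_fork(attention_probs, dropout, rng_tracker,
                                         memory_opt):
    rng_context = nullcontext() if memory_opt else rng_tracker.fork()
    with rng_context:
        return dropout(attention_probs)
```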
@@ -865,32 +870,51 @@ class ParallelTransformer(MegatronModule):
         if self.model_parallel_memory_opt:
             encoder_output = mpu.scatter_to_sequence_parallel_region(encoder_output)

-        # Forward pass.
-        if self.activations_checkpoint_method is not None:
-            hidden_states = self._checkpointed_forward(hidden_states,
-                                                       attention_mask,
-                                                       encoder_output,
-                                                       enc_dec_attn_mask)
-        else:
-            for index in range(self.num_layers):
-                layer = self._get_layer(index)
-                hidden_states = layer(
-                    hidden_states,
-                    attention_mask,
-                    encoder_output=encoder_output,
-                    enc_dec_attn_mask=enc_dec_attn_mask,
-                    inference_params=inference_params)
+        if self.model_parallel_memory_opt:
+            with mpu.get_cuda_rng_tracker().fork():
+                # Forward pass.
+                if self.activations_checkpoint_method is not None:
+                    hidden_states = self._checkpointed_forward(hidden_states,
+                                                               attention_mask,
+                                                               encoder_output,
+                                                               enc_dec_attn_mask)
+                else:
+                    for index in range(self.num_layers):
+                        layer = self._get_layer(index)
+                        hidden_states = layer(
+                            hidden_states,
+                            attention_mask,
+                            encoder_output=encoder_output,
+                            enc_dec_attn_mask=enc_dec_attn_mask,
+                            inference_params=inference_params)
+        else:
+            # Forward pass.
+            if self.activations_checkpoint_method is not None:
+                hidden_states = self._checkpointed_forward(hidden_states,
+                                                           attention_mask,
+                                                           encoder_output,
+                                                           enc_dec_attn_mask)
+            else:
+                for index in range(self.num_layers):
+                    layer = self._get_layer(index)
+                    hidden_states = layer(
+                        hidden_states,
+                        attention_mask,
+                        encoder_output=encoder_output,
+                        enc_dec_attn_mask=enc_dec_attn_mask,
+                        inference_params=inference_params)

         # Final layer norm.
         if self.post_process:
             # Reverting data format change [s b h] --> [b s h].
             hidden_states = self.final_layernorm(hidden_states)

-            if self.layer_type==LayerType.encoder and \
-               self.model_type==ModelType.encoder_and_decoder and \
+            if self.layer_type == LayerType.encoder and \
+               self.model_type == ModelType.encoder_and_decoder and \
                self.model_parallel_memory_opt:
                 output = hidden_states
             else:
                 if self.model_parallel_memory_opt:
                     hidden_states = mpu.gather_from_sequence_parallel_region(hidden_states)
......
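Review note: under the memory optimization the activations stay scattered along the sequence dimension between layers, and the whole layer stack runs inside the RNG tracker fork so each tensor-parallel rank draws its own dropout masks for its own sequence chunk; the gather back to the full sequence is skipped only for an encoder feeding a sequence-parallel decoder. A toy, single-process illustration of the data layout (illustrative names, not this repository's code):

```python
import torch

def simulate_sequence_parallel(hidden_states, tp_world_size):
    # hidden_states: [s, b, h]; the sequence dim is split across tensor-parallel ranks.
    s, b, h = hidden_states.shape
    assert s % tp_world_size == 0
    # "scatter": each rank would keep exactly one of these chunks of size s/tp.
    chunks = list(hidden_states.chunk(tp_world_size, dim=0))
    # ... each rank runs its transformer layers (and its own dropout RNG) on its chunk ...
    # "gather": an all-gather along the sequence dim restores the full activation.
    return torch.cat(chunks, dim=0)

x = torch.randn(8, 2, 16)  # [s, b, h]
assert torch.equal(simulate_sequence_parallel(x, tp_world_size=4), x)
```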
@@ -215,7 +215,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
         ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
         ctx.async_grad_allreduce = async_grad_allreduce
         ctx.model_parallel_memory_opt = model_parallel_memory_opt

         if model_parallel_memory_opt:
             world_size = get_tensor_model_parallel_world_size()
             dim_size = list(input.size())
@@ -487,6 +487,8 @@ class RowParallelLinear(torch.nn.Module):
                 self.bias = Parameter(torch.empty(
                     self.output_size, device=torch.cuda.current_device(),
                     dtype=args.params_dtype))
+            setattr(self.bias, 'sequence_parallel', args.model_parallel_memory_opt)
+
             # Always initialize bias to zero.
             with torch.no_grad():
                 self.bias.zero_()
@@ -496,6 +498,7 @@ class RowParallelLinear(torch.nn.Module):
         self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

     def forward(self, input_):
         # Set up backprop all-reduce.
         if self.input_is_parallel:
......
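Review note: the `sequence_parallel` attribute attached to the bias above matters at the optimizer step. With the memory optimization on, RowParallelLinear produces its output with a reduce-scatter instead of an all-reduce, so each tensor-parallel rank only accumulates the bias gradient for its own sequence chunk, and those partial gradients must be summed across the group. A hedged sketch of that handling (not this repository's code; the helper name is illustrative):

```python
import torch
import torch.distributed as dist

def allreduce_sequence_parallel_grads(model, tensor_model_parallel_group):
    """Sum gradients of params tagged `sequence_parallel` across the tensor-parallel group."""
    for param in model.parameters():
        if getattr(param, 'sequence_parallel', False) and param.grad is not None:
            dist.all_reduce(param.grad, group=tensor_model_parallel_group)
```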
@@ -67,7 +67,7 @@ def _split_along_first_dim(input_):
     rank = get_tensor_model_parallel_rank()
     dim_offset = rank * (local_dim_size)

-    output = input_[dim_offset:dim_offset+local_dim_size]
+    output = input_[dim_offset:dim_offset+local_dim_size].contiguous()

     return output
@@ -106,33 +106,27 @@ def _gather_along_first_dim(input_):
     dim_size[0] = dim_size[0] * world_size

     output = torch.empty(dim_size, dtype=input_.dtype,
                          device=torch.cuda.current_device(),
                          requires_grad=False)
-    torch.distributed._all_gather_base(output, input_,
-                                       device=torch.cuda.current_device())
+    torch.distributed._all_gather_base(output, input_.contiguous(),
+                                       group=get_tensor_model_parallel_group())

     return output


 def _reduce_scatter_along_first_dim(input_):
     """Reduce-scatter the input tensor across model parallel group."""
     world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
-    if get_tensor_model_parallel_world_size()==1:
+    if get_tensor_model_parallel_world_size() == 1:
         return input_

     dim_size = list(input_.size())
     assert dim_size[0] % world_size == 0
-    dim_size[0]= dim_size[0] // world_size
+    dim_size[0] = dim_size[0] // world_size

     output = torch.empty(dim_size, dtype=input_.dtype,
                          device=torch.cuda.current_device(),
                          requires_grad=False)
-    # reduce_scatter
-    torch.distributed._reduce_scatter_base(output, input_,
-                                           device=torch.cuda.current_device())
+    torch.distributed._reduce_scatter_base(output, input_.contiguous(),
+                                           group=get_tensor_model_parallel_group())

     return output
......
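Review note: the substantive fix in this last hunk is that `_all_gather_base` and `_reduce_scatter_base` were being called with a `device=` keyword, which these collectives do not accept, and without the tensor-model-parallel process group; they now receive a contiguous input and `group=get_tensor_model_parallel_group()`. Below is a hedged sketch of how the corrected pair is typically wired into autograd for sequence parallelism; this is my illustration, not this repository's code, and the class and argument names are assumptions:

```python
import torch
import torch.distributed as dist

class GatherFromSequenceParallelRegion(torch.autograd.Function):
    """All-gather along the sequence (first) dim in forward, reduce-scatter in backward."""

    @staticmethod
    def forward(ctx, input_, group):
        ctx.group = group
        world_size = dist.get_world_size(group=group)
        dim_size = list(input_.size())
        dim_size[0] *= world_size
        output = torch.empty(dim_size, dtype=input_.dtype,
                             device=torch.cuda.current_device())
        dist._all_gather_base(output, input_.contiguous(), group=group)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(group=ctx.group)
        dim_size = list(grad_output.size())
        dim_size[0] //= world_size
        grad_input = torch.empty(dim_size, dtype=grad_output.dtype,
                                 device=torch.cuda.current_device())
        # Conjugate collective: sum gradients across ranks, keep this rank's sequence chunk.
        dist._reduce_scatter_base(grad_input, grad_output.contiguous(),
                                  group=ctx.group)
        return grad_input, None

# Usage on each rank (process group assumed initialized):
#   full_seq = GatherFromSequenceParallelRegion.apply(seq_chunk, tp_group)
```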