Commit 6abf39be authored by Deepak Narayanan

Only transpose hidden_states when necessary

parent 57c3b364
@@ -552,7 +552,7 @@ class ParallelTransformer(MegatronModule):
     def forward(self, hidden_states, attention_mask, layer_past=None,
                 get_key_value=False):

-        # Checks
+        # Checks.
         if layer_past is not None:
             assert get_key_value, \
                 'for not None values in layer_past, ' \
@@ -562,7 +562,8 @@ class ParallelTransformer(MegatronModule):
                 'get_key_value does not work with ' \
                 'activation checkpointing'

-        # data format change to avoid explicit transposes: [b s h] --> [s b h]
-        hidden_states = hidden_states.transpose(0, 1).contiguous()
+        if mpu.is_pipeline_first_stage():
+            # Data format change to avoid explicit transposes: [b s h] --> [s b h].
+            hidden_states = hidden_states.transpose(0, 1).contiguous()

         if self.checkpoint_activations:
@@ -584,11 +585,10 @@ class ParallelTransformer(MegatronModule):
                     hidden_states, present = hidden_states
                     presents.append(present)

-        # reverting data format change [s b h] --> [b s h]
-        hidden_states = hidden_states.transpose(0, 1).contiguous()
-
         # Final layer norm.
         if mpu.is_pipeline_last_stage():
+            # Reverting data format change [s b h] --> [b s h].
+            hidden_states = hidden_states.transpose(0, 1).contiguous()
             output = self.final_layernorm(hidden_states)
         else:
             output = hidden_states
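With this change, the [b s h] --> [s b h] transpose happens only on the first pipeline stage and the reverse transpose only on the last stage; intermediate stages pass activations along in [s b h] untouched. A minimal sketch of that pattern (illustrative only: the plain is_first_stage / is_last_stage flags stand in for mpu.is_pipeline_first_stage() / mpu.is_pipeline_last_stage(), and run_stage() is not a Megatron function):

# Sketch only: run_stage() and the boolean flags are illustrative, not Megatron APIs.
import torch


def run_stage(hidden_states, layers, is_first_stage, is_last_stage):
    # Only the first stage receives activations from the embedding in
    # [batch, seq, hidden]; later stages already get [seq, batch, hidden]
    # from the previous stage, so they skip the transpose.
    if is_first_stage:
        hidden_states = hidden_states.transpose(0, 1).contiguous()

    for layer in layers:
        hidden_states = layer(hidden_states)

    # Only the last stage converts back to [batch, seq, hidden] before the
    # output is consumed; intermediate stages forward [seq, batch, hidden] as-is.
    if is_last_stage:
        hidden_states = hidden_states.transpose(0, 1).contiguous()
    return hidden_states


# Single-stage "pipeline": input and output are both [batch, seq, hidden].
x = torch.randn(4, 16, 32)
y = run_stage(x, [torch.nn.Identity()], is_first_stage=True, is_last_stage=True)
assert y.shape == (4, 16, 32)

Keeping the inter-stage format fixed at [s b h] is also what lets the point-to-point receive buffers in the next file be allocated in that layout.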
@@ -245,7 +245,7 @@ def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward)
     # if needed.
     tensor_recv_prev = None
     tensor_recv_next = None
-    tensor_shape = (args.batch_size, args.seq_length, args.hidden_size)
+    tensor_shape = (args.seq_length, args.batch_size, args.hidden_size)
     if recv_forward:
         tensor_recv_prev = torch.empty(tensor_shape,
                                        requires_grad=True,
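Since intermediate stages now exchange activations in [s b h], the receive buffer allocated in communicate() must use the same layout, which is what the tensor_shape swap above provides. A minimal sketch of that shape bookkeeping (illustrative only: the args namespace below is a stand-in for Megatron's parsed arguments, not its argument parser):

# Sketch only: `args` here is a hand-built namespace, not Megatron's parsed args.
import torch
from types import SimpleNamespace

args = SimpleNamespace(seq_length=16, batch_size=4, hidden_size=32)

# Buffers for stage-to-stage activations are laid out [seq, batch, hidden],
# matching the untransposed format the transformer stages now exchange.
tensor_shape = (args.seq_length, args.batch_size, args.hidden_size)
tensor_recv_prev = torch.empty(tensor_shape, requires_grad=True)

# Whatever the previous stage sends (already [s, b, h]) must match this shape
# exactly for the point-to-point receive to land in the buffer.
sent = torch.randn(args.seq_length, args.batch_size, args.hidden_size)
assert sent.shape == tensor_recv_prev.shape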