Commit 6abf39be authored by Deepak Narayanan

Only transpose hidden_states when necessary

parent 57c3b364
@@ -552,7 +552,7 @@ class ParallelTransformer(MegatronModule):
     def forward(self, hidden_states, attention_mask, layer_past=None,
                 get_key_value=False):

-        # Checks
+        # Checks.
         if layer_past is not None:
             assert get_key_value, \
                 'for not None values in layer_past, ' \
@@ -562,8 +562,9 @@ class ParallelTransformer(MegatronModule):
                 'get_key_value does not work with ' \
                 'activation checkpointing'

-        # data format change to avoid explicit tranposes : [b s h] --> [s b h]
-        hidden_states = hidden_states.transpose(0, 1).contiguous()
+        if mpu.is_pipeline_first_stage():
+            # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
+            hidden_states = hidden_states.transpose(0, 1).contiguous()

         if self.checkpoint_activations:
             hidden_states = self._checkpointed_forward(hidden_states,
@@ -584,11 +585,10 @@ class ParallelTransformer(MegatronModule):
                     hidden_states, present = hidden_states
                     presents.append(present)

-        # reverting data format change [s b h] --> [b s h]
-        hidden_states = hidden_states.transpose(0, 1).contiguous()
-
         # Final layer norm.
         if mpu.is_pipeline_last_stage():
+            # Reverting data format change [s b h] --> [b s h].
+            hidden_states = hidden_states.transpose(0, 1).contiguous()
             output = self.final_layernorm(hidden_states)
         else:
             output = hidden_states
...
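For readers skimming the diff, here is a self-contained sketch of the control flow this hunk introduces: the [b, s, h] → [s, b, h] transpose now runs only on the first pipeline stage, the reverse transpose only on the last stage, and intermediate stages pass activations through in [s, b, h]. The helper name, stage flags, and sizes below are illustrative placeholders, not the Megatron `mpu` API.

```python
import torch

def forward_layout_sketch(hidden_states, is_first_stage, is_last_stage, transformer_body):
    # Hypothetical sketch of the commit's control flow; names are illustrative.
    if is_first_stage:
        # [b, s, h] -> [s, b, h], so the transformer layers avoid per-layer transposes.
        hidden_states = hidden_states.transpose(0, 1).contiguous()

    hidden_states = transformer_body(hidden_states)  # stays in [s, b, h] between stages

    if is_last_stage:
        # Revert to [b, s, h] only where the final layer norm / downstream head expects it.
        hidden_states = hidden_states.transpose(0, 1).contiguous()
    return hidden_states

# Example: a middle pipeline stage receives and sends [s, b, h] without any transpose.
b, s, h = 4, 16, 32
x = torch.randn(s, b, h)  # activation as received from the previous stage
y = forward_layout_sketch(x, is_first_stage=False, is_last_stage=False,
                          transformer_body=lambda t: t)
assert y.shape == (s, b, h)
```

The second hunk below keeps the point-to-point receive buffers consistent with that layout.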
@@ -245,7 +245,7 @@ def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward)
     # if needed.
     tensor_recv_prev = None
     tensor_recv_next = None
-    tensor_shape = (args.batch_size, args.seq_length, args.hidden_size)
+    tensor_shape = (args.seq_length, args.batch_size, args.hidden_size)
     if recv_forward:
         tensor_recv_prev = torch.empty(tensor_shape,
                                        requires_grad=True,
...
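Why the one-line shape change matters: since hidden_states is no longer transposed back to [b, s, h] before being handed to the next stage, the buffers that `communicate` allocates for receives must use the same [seq_length, batch_size, hidden_size] ordering. A minimal consistency sketch follows, with made-up sizes and no actual point-to-point transfer; it is not the real `communicate` implementation.

```python
import torch

# Illustrative sizes; in Megatron these come from the parsed args.
seq_length, batch_size, hidden_size = 16, 4, 32

sent = torch.randn(seq_length, batch_size, hidden_size)   # stage N output, [s, b, h]
tensor_recv_prev = torch.empty((seq_length, batch_size, hidden_size),
                               requires_grad=True)         # stage N+1 receive buffer

# With the old (batch_size, seq_length, hidden_size) ordering this check would
# fail whenever seq_length != batch_size.
assert tensor_recv_prev.shape == sent.shape
```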