Commit ee327acd authored by Vijay Korthikanti

reordering perf fix

parent 42d21122
@@ -172,19 +172,19 @@ class ParallelSelfAttention(MegatronModule):
             init_method=output_layer_init_method,
             skip_bias_add=True)
 
-    def _transpose_last_dim(self, mixed_layer):
-        """[s, b, 3 * hp] -->(view) [s, b, 3, hp] -->(transpose)
-        [s, b, hp, 3] -->(view) [s, b, 3 * hp] """
+    def _transpose_last_dim(self, mixed_layer, num_splits):
+        """[s, b, num_splits * np * hn]
+        -->(view) [s, b, num_splits, np, hn]
+        -->(transpose) [s, b, np, num_splits, hn]
+        -->(view) [s, b, np * num_splits * hn] """
         input_shape = mixed_layer.size()
-        last_dim = input_shape[-1]
-        assert last_dim % 3 == 0, "expected QKV dimension"
-        last_dim_split = last_dim // 3
         intermediate_shape = input_shape[:-1] +\
-            (3, last_dim_split)
+            (num_splits, self.num_attention_heads_per_partition,
+             self.hidden_size_per_attention_head)
         mixed_layer = mixed_layer.view(*intermediate_shape)
-        mixed_layer = mixed_layer.transpose(-1, -2).contiguous()
+        mixed_layer = mixed_layer.transpose(-2, -3).contiguous()
         mixed_layer = mixed_layer.view(*input_shape)
         return mixed_layer
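For context, here is a minimal standalone sketch of the reordering that the updated _transpose_last_dim performs. This is not the Megatron-LM code itself: the transpose_last_dim name and the s, b, np, hn sizes are illustrative. It converts a checkpoint-version-0 layout [s, b, (num_splits * np * hn)] into the new [s, b, (np * num_splits * hn)] layout:

    import torch

    # Illustrative sizes: s=2, b=1, np=4 heads, hn=8 per head, 3 splits (q, k, v).
    s, b, np_, hn, num_splits = 2, 1, 4, 8, 3

    def transpose_last_dim(mixed_layer, num_splits, num_heads, head_dim):
        input_shape = mixed_layer.size()
        # [s, b, num_splits * np * hn] --> [s, b, num_splits, np, hn]
        mixed_layer = mixed_layer.view(*input_shape[:-1],
                                       num_splits, num_heads, head_dim)
        # Swap the splits axis and the heads axis: [s, b, np, num_splits, hn]
        mixed_layer = mixed_layer.transpose(-2, -3).contiguous()
        # Flatten back to the original shape: [s, b, np * num_splits * hn]
        return mixed_layer.view(*input_shape)

    x = torch.arange(s * b * num_splits * np_ * hn, dtype=torch.float32)
    x = x.view(s, b, num_splits * np_ * hn)
    y = transpose_last_dim(x, num_splits, np_, hn)
    # Element (split=1, head=0, pos=0) moves from offset np*hn to offset hn.
    assert y[0, 0, hn] == x[0, 0, np_ * hn]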
@@ -197,25 +197,25 @@ class ParallelSelfAttention(MegatronModule):
         # Query, Key, and Value
         # =====================
 
-        # Attention heads [sq, b, hp] --> [sq, b, hp * 3]
+        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
         mixed_x_layer, _ = self.query_key_value(hidden_states)
 
         checkpoint_version = get_checkpoint_version()
         if checkpoint_version is not None and \
                 checkpoint_version == 0:
-            # [sq, b, 3 * hp] --> [sq, b, hp * 3]
-            mixed_x_layer = self._transpose_last_dim(mixed_x_layer)
+            # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
+            mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3)
 
-        # [sq, b, hp * 3] --> [sq, b, np, hn, 3]
+        # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
         new_tensor_shape = mixed_x_layer.size()[:-1] + \
             (self.num_attention_heads_per_partition,
-             self.hidden_size_per_attention_head, 3)
+             3 * self.hidden_size_per_attention_head)
         mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
 
-        # [sq, b, np, hn, 3] --> 3 [sq, b, np, hn]
-        query_layer = mixed_x_layer[:,:,:,:,0]
-        key_layer = mixed_x_layer[:,:,:,:,1]
-        value_layer = mixed_x_layer[:,:,:,:,2]
+        # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
+        (query_layer,
+         key_layer,
+         value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
 
         # ==================================
         # Adjust key and value for inference
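The perf motivation: the old indexing mixed_x_layer[:,:,:,:,i] interleaves q, k, and v at per-element granularity, so each slice has stride 3 in its last dimension, whereas the new [sq, b, np, 3 * hn] view keeps each head's hn values contiguous and splits into unit-stride chunks. A sketch with illustrative sizes, under the assumption that mpu.split_tensor_along_last_dim chunks the last dimension the way torch.split does:

    import torch

    # Illustrative sizes; this is a sketch, not the Megatron-LM module itself.
    sq, b, np_, hn = 2, 1, 4, 8
    mixed_x_layer = torch.randn(sq, b, np_, 3 * hn)

    # New path: three [sq, b, np, hn] chunks, unit stride in the last dim.
    query_layer, key_layer, value_layer = torch.split(mixed_x_layer, hn, dim=-1)
    assert query_layer.shape == (sq, b, np_, hn)
    assert query_layer.stride()[-1] == 1

    # Old access pattern for comparison: slicing a [sq, b, np, hn, 3] view
    # leaves q, k, v interleaved element by element (stride 3 in the last dim).
    old_style = mixed_x_layer.view(sq, b, np_, hn, 3)[:, :, :, :, 0]
    assert old_style.stride()[-1] == 3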