Commit 538f0e05 authored by Mohammad Shoeybi's avatar Mohammad Shoeybi
Browse files

Merge branch 'vijay/ordering_perf_fix' into 'main'

reordering perf fix

See merge request ADLR/megatron-lm!154
parents 42d21122 01dffdb5
...@@ -101,7 +101,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler): ...@@ -101,7 +101,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
# Arguments, iteration, and model. # Arguments, iteration, and model.
state_dict = {} state_dict = {}
state_dict['args'] = args state_dict['args'] = args
state_dict['checkpoint_version'] = 1.0 state_dict['checkpoint_version'] = 2.0
state_dict['iteration'] = iteration state_dict['iteration'] = iteration
state_dict['model'] = model.state_dict_for_save_checkpoint() state_dict['model'] = model.state_dict_for_save_checkpoint()
......
...@@ -172,17 +172,30 @@ class ParallelSelfAttention(MegatronModule): ...@@ -172,17 +172,30 @@ class ParallelSelfAttention(MegatronModule):
init_method=output_layer_init_method, init_method=output_layer_init_method,
skip_bias_add=True) skip_bias_add=True)
def _transpose_last_dim(self, mixed_layer): def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
"""[s, b, 3 * hp] -->(view) [s, b, 3, hp] -->(tranpose)
[s, b, hp, 3] -->(view) [s, b, 3 * hp] """
input_shape = mixed_layer.size(); input_shape = mixed_layer.size();
last_dim = input_shape[-1] if num_splits_first:
assert last_dim % 3 == 0, "expected QKV dimension" """[s, b, num_splits * np * hn]
last_dim_split = last_dim // 3 -->(view) [s, b, num_splits, np, hn]
-->(tranpose) [s, b, np, num_splits, hn]
-->(view) [s, b, np * num_splits * hn] """
intermediate_shape = input_shape[:-1] +\ intermediate_shape = input_shape[:-1] +\
(3, last_dim_split) (num_splits, self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
mixed_layer = mixed_layer.view(*intermediate_shape)
mixed_layer = mixed_layer.transpose(-2, -3).contiguous()
else:
"""[s, b, np * hn * num_splits]
-->(view) [s, b, np, hn, num_splits]
-->(tranpose) [s, b, np, num_splits, hn]
-->(view) [s, b, np * num_splits * hn] """
intermediate_shape = input_shape[:-1] +\
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head, num_splits)
mixed_layer = mixed_layer.view(*intermediate_shape) mixed_layer = mixed_layer.view(*intermediate_shape)
mixed_layer = mixed_layer.transpose(-1, -2).contiguous() mixed_layer = mixed_layer.transpose(-1, -2).contiguous()
mixed_layer = mixed_layer.view(*input_shape) mixed_layer = mixed_layer.view(*input_shape)
...@@ -197,25 +210,28 @@ class ParallelSelfAttention(MegatronModule): ...@@ -197,25 +210,28 @@ class ParallelSelfAttention(MegatronModule):
# Query, Key, and Value # Query, Key, and Value
# ===================== # =====================
# Attention heads [sq, b, hp] --> [sq, b, hp * 3] # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states) mixed_x_layer, _ = self.query_key_value(hidden_states)
checkpoint_version = get_checkpoint_version() checkpoint_version = get_checkpoint_version()
if checkpoint_version is not None and \ if checkpoint_version is not None:
checkpoint_version == 0: if checkpoint_version == 0:
# [sq, b, 3 * hp] --> [sq, b, hp * 3] # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
mixed_x_layer = self._transpose_last_dim(mixed_x_layer) mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)
elif checkpoint_version == 1.0:
# [sq, b, hp * 3] --> [sq, b, np, hn, 3] # [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)]
mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False)
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + \ new_tensor_shape = mixed_x_layer.size()[:-1] + \
(self.num_attention_heads_per_partition, (self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head, 3) 3 * self.hidden_size_per_attention_head)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, hn, 3] --> 3 [sq, b, np, hn] # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
query_layer = mixed_x_layer[:,:,:,:,0] (query_layer,
key_layer = mixed_x_layer[:,:,:,:,1] key_layer,
value_layer = mixed_x_layer[:,:,:,:,2] value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
# ================================== # ==================================
# Adjust key and value for inference # Adjust key and value for inference
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment