"git@developer.sourcefind.cn:OpenDAS/fastmoe.git" did not exist on "3a41edb8c29eae02fd286302ab85d809c73fff18"
Commit 0bc75448 authored by Vijay Korthikanti

address review comments

parent 9b0083ea
@@ -172,14 +172,14 @@ class ParallelSelfAttention(MegatronModule):
             init_method=output_layer_init_method,
             skip_bias_add=True)
 
-    def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_index):
-        """[s, b, num_splits * np * hn]
-        -->(view) [s, b, num_splits, np, hn]
-        -->(transpose) [s, b, np, num_splits, hn]
-        -->(view) [s, b, np * num_splits * hn] """
+    def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
         input_shape = mixed_layer.size()
-        if num_splits_index == 0:
+        if num_splits_first:
+            """[s, b, num_splits * np * hn]
+            -->(view) [s, b, num_splits, np, hn]
+            -->(transpose) [s, b, np, num_splits, hn]
+            -->(view) [s, b, np * num_splits * hn] """
             intermediate_shape = input_shape[:-1] +\
                 (num_splits, self.num_attention_heads_per_partition,
                  self.hidden_size_per_attention_head)
@@ -187,7 +187,11 @@ class ParallelSelfAttention(MegatronModule):
             mixed_layer = mixed_layer.view(*intermediate_shape)
             mixed_layer = mixed_layer.transpose(-2, -3).contiguous()
         else:
-            assert num_splits_index == 2
+            """[s, b, np * hn * num_splits]
+            -->(view) [s, b, np, hn, num_splits]
+            -->(transpose) [s, b, np, num_splits, hn]
+            -->(view) [s, b, np * num_splits * hn] """
             intermediate_shape = input_shape[:-1] +\
                 (self.num_attention_heads_per_partition,
                  self.hidden_size_per_attention_head, num_splits)
@@ -213,10 +217,10 @@ class ParallelSelfAttention(MegatronModule):
         if checkpoint_version is not None:
             if checkpoint_version == 0:
                 # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
-                mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, 0)
+                mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)
             elif checkpoint_version == 1:
                 # [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)]
-                mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, 2)
+                mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False)
             # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
             new_tensor_shape = mixed_x_layer.size()[:-1] + \
...
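
For context, a minimal standalone sketch of the layout conversion this commit renames: it re-implements both branches of _transpose_last_dim outside the class (np_ and hn are passed in explicitly instead of read from self, and the sizes are made-up illustrative values, not from the commit), then checks that both legacy checkpoint layouts land in the current [s, b, np * 3 * hn] layout. The transpose(-1, -2) in the second branch is an assumption inferred from the docstring, since that line falls outside the visible hunks.

import torch

def transpose_last_dim(mixed_layer, num_splits, num_splits_first, np_, hn):
    # Standalone version of ParallelSelfAttention._transpose_last_dim;
    # np_ and hn stand in for self.num_attention_heads_per_partition and
    # self.hidden_size_per_attention_head.
    input_shape = mixed_layer.size()
    if num_splits_first:
        # [s, b, num_splits * np * hn] --(view)--> [s, b, num_splits, np, hn]
        mixed_layer = mixed_layer.view(*input_shape[:-1], num_splits, np_, hn)
        # --(transpose)--> [s, b, np, num_splits, hn]
        mixed_layer = mixed_layer.transpose(-2, -3).contiguous()
    else:
        # [s, b, np * hn * num_splits] --(view)--> [s, b, np, hn, num_splits]
        mixed_layer = mixed_layer.view(*input_shape[:-1], np_, hn, num_splits)
        # --(transpose)--> [s, b, np, num_splits, hn]
        # (assumed: this line is not shown in the diff, but the docstring
        # implies swapping the last two dimensions)
        mixed_layer = mixed_layer.transpose(-1, -2).contiguous()
    # --(view)--> [s, b, np * num_splits * hn]
    return mixed_layer.view(*input_shape)

# Illustrative sizes (hypothetical, chosen for the demo only).
s, b, np_, hn, num_splits = 4, 2, 8, 16, 3

# Reference tensor already in the current layout, unflattened: [s, b, np, 3, hn].
ref = torch.arange(s * b * np_ * num_splits * hn, dtype=torch.float32)
ref = ref.view(s, b, np_, num_splits, hn)

v0 = ref.permute(0, 1, 3, 2, 4).reshape(s, b, -1)  # checkpoint_version 0: [s, b, 3 * np * hn]
v1 = ref.permute(0, 1, 2, 4, 3).reshape(s, b, -1)  # checkpoint_version 1: [s, b, np * hn * 3]
current = ref.reshape(s, b, -1)                    # current: [s, b, np * 3 * hn]

assert torch.equal(transpose_last_dim(v0, num_splits, True, np_, hn), current)
assert torch.equal(transpose_last_dim(v1, num_splits, False, np_, hn), current)

The round-trip checks make the rationale for the rename concrete: the old 0/2 "split index" argument only encoded which side of the packed last dimension holds the QKV split, so a boolean num_splits_first expresses the same two cases more directly.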