OpenDAS / Megatron-LM

Commit 0bc75448
authored Oct 12, 2020 by Vijay Korthikanti
parent 9b0083ea

address review comments

Showing 1 changed file with 14 additions and 10 deletions.

megatron/model/transformer.py  (+14, -10)
...
@@ -172,14 +172,14 @@ class ParallelSelfAttention(MegatronModule):
             init_method=output_layer_init_method,
             skip_bias_add=True)
 
-    def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_index):
-        """[s, b, num_splits * np * hn]
-        -->(view) [s, b, num_splits, np, hn]
-        -->(tranpose) [s, b, np, num_splits, hn]
-        -->(view) [s, b, np * num_splits * hn] """
+    def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
         input_shape = mixed_layer.size();
-        if num_splits_index == 0:
+        if num_splits_first:
+            """[s, b, num_splits * np * hn]
+            -->(view) [s, b, num_splits, np, hn]
+            -->(tranpose) [s, b, np, num_splits, hn]
+            -->(view) [s, b, np * num_splits * hn] """
             intermediate_shape = input_shape[:-1] +\
                 (num_splits, self.num_attention_heads_per_partition,
                  self.hidden_size_per_attention_head)
...
@@ -187,7 +187,11 @@ class ParallelSelfAttention(MegatronModule):
             mixed_layer = mixed_layer.view(*intermediate_shape)
             mixed_layer = mixed_layer.transpose(-2, -3).contiguous()
         else:
-            assert num_splits_index == 2
+            """[s, b, np * hn * num_splits]
+            -->(view) [s, b, np, hn, num_splits]
+            -->(tranpose) [s, b, np, num_splits, hn]
+            -->(view) [s, b, np * num_splits * hn] """
             intermediate_shape = input_shape[:-1] +\
                 (self.num_attention_heads_per_partition,
                  self.hidden_size_per_attention_head, num_splits)
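
For reference, a minimal standalone sketch of the reshape-transpose-reshape pattern described by the two docstrings above, using toy sizes. The free-function form, the toy dimensions, and the else-branch transpose(-1, -2) (implied by the docstring but not shown in this hunk) are illustrative assumptions, not the class's actual code:

    import torch

    # Toy sizes: sequence s, batch b, heads per partition np_, head size hn.
    s, b, np_, hn = 2, 1, 4, 8

    def transpose_last_dim(mixed_layer, num_splits, num_splits_first,
                           num_heads=np_, head_size=hn):
        # Normalize the flattened last dim to [... np * num_splits * hn],
        # whether the input stores the splits first or last.
        input_shape = mixed_layer.size()
        if num_splits_first:
            # [s, b, num_splits * np * hn] --(view)--> [s, b, num_splits, np, hn]
            intermediate_shape = input_shape[:-1] + (num_splits, num_heads, head_size)
            mixed_layer = mixed_layer.view(*intermediate_shape)
            # --(transpose)--> [s, b, np, num_splits, hn]
            mixed_layer = mixed_layer.transpose(-2, -3).contiguous()
        else:
            # [s, b, np * hn * num_splits] --(view)--> [s, b, np, hn, num_splits]
            intermediate_shape = input_shape[:-1] + (num_heads, head_size, num_splits)
            mixed_layer = mixed_layer.view(*intermediate_shape)
            # --(transpose)--> [s, b, np, num_splits, hn]
            mixed_layer = mixed_layer.transpose(-1, -2).contiguous()
        # --(view)--> [s, b, np * num_splits * hn]
        return mixed_layer.view(*input_shape)

    x = torch.randn(s, b, 3 * np_ * hn)
    y = transpose_last_dim(x, 3, num_splits_first=True)
    assert y.shape == x.shape  # same flat shape; only the ordering inside the last dim changes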
...
@@ -213,10 +217,10 @@ class ParallelSelfAttention(MegatronModule):
         if checkpoint_version is not None:
             if checkpoint_version == 0:
                 # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
-                mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, 0)
+                mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)
             elif checkpoint_version == 1:
                 # [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)]
-                mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, 2)
+                mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False)
 
         # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
         new_tensor_shape = mixed_x_layer.size()[:-1] +\
...
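
With the rename, the third argument at these call sites reads as "are the q/k/v splits stored first in the flattened dim" rather than a magic index. Continuing the hypothetical sketch above, a quick check that both checkpoint layouts normalize to the same [s, b, np * 3 * hn] ordering:

    # qkv holds the data in the canonical [s, b, 3, np, hn] arrangement.
    qkv = torch.randn(s, b, 3, np_, hn)
    v0 = qkv.reshape(s, b, -1)                                      # checkpoint_version 0: [s, b, 3 * np * hn]
    v1 = qkv.permute(0, 1, 3, 4, 2).contiguous().reshape(s, b, -1)  # checkpoint_version 1: [s, b, np * hn * 3]
    out0 = transpose_last_dim(v0, 3, num_splits_first=True)
    out1 = transpose_last_dim(v1, 3, num_splits_first=False)
    assert torch.equal(out0, out1)  # both end up as [s, b, np * 3 * hn]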