OpenDAS / Megatron-LM

Commit ee327acd authored Oct 12, 2020 by Vijay Korthikanti

reordering perf fix

parent 42d21122
Showing 1 changed file with 20 additions and 20 deletions

megatron/model/transformer.py  (+20, -20)
@@ -171,20 +171,20 @@ class ParallelSelfAttention(MegatronModule):
             input_is_parallel=True,
             init_method=output_layer_init_method,
             skip_bias_add=True)
 
-    def _transpose_last_dim(self, mixed_layer):
-        """[s, b, 3 * hp] -->(view) [s, b, 3, hp] -->(tranpose)
-           [s, b, hp, 3] -->(view) [s, b, 3 * hp] """
+    def _transpose_last_dim(self, mixed_layer, num_splits):
+        """[s, b, num_splits * np * hn]
+        -->(view) [s, b, num_splits, np, hn]
+        -->(tranpose) [s, b, np, num_splits, hn]
+        -->(view) [s, b, np * num_splits * hn] """
         input_shape = mixed_layer.size();
-        last_dim = input_shape[-1]
-        assert last_dim % 3 == 0, "expected QKV dimension"
-        last_dim_split = last_dim // 3
         intermediate_shape = input_shape[:-1] +\
-            (3, last_dim_split)
+            (num_splits, self.num_attention_heads_per_partition,
+             self.hidden_size_per_attention_head)
         mixed_layer = mixed_layer.view(*intermediate_shape)
-        mixed_layer = mixed_layer.transpose(-1, -2).contiguous()
+        mixed_layer = mixed_layer.transpose(-2, -3).contiguous()
         mixed_layer = mixed_layer.view(*input_shape)
         return mixed_layer
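The reworked helper generalizes the checkpoint-version-0 reordering to an arbitrary number of splits. Below is a minimal standalone sketch (not part of the commit) of the same view/transpose/view sequence; the sizes s, b, num_splits, np_ and hn are illustrative placeholders, not values from the repository:

    import torch

    s, b, num_splits, np_, hn = 4, 2, 3, 8, 16        # toy sizes; np_ = attention heads per partition
    mixed = torch.randn(s, b, num_splits * np_ * hn)  # [s, b, num_splits * np * hn]

    x = mixed.view(s, b, num_splits, np_, hn)         # [s, b, num_splits, np, hn]
    x = x.transpose(-2, -3).contiguous()              # [s, b, np, num_splits, hn]
    reordered = x.view(s, b, np_ * num_splits * hn)   # [s, b, np * num_splits * hn]

    assert reordered.shape == mixed.shape             # only the ordering changes, not the size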
@@ -197,25 +197,25 @@ class ParallelSelfAttention(MegatronModule):
         # Query, Key, and Value
         # =====================
 
-        # Attention heads [sq, b, hp] --> [sq, b, hp * 3]
+        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
         mixed_x_layer, _ = self.query_key_value(hidden_states)
 
         checkpoint_version = get_checkpoint_version()
         if checkpoint_version is not None and \
                checkpoint_version == 0:
-            # [sq, b, 3 * hp] --> [sq, b, hp * 3]
-            mixed_x_layer = self._transpose_last_dim(mixed_x_layer)
+            # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
+            mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3)
 
-        # [sq, b, hp * 3] --> [sq, b, np, hn, 3]
+        # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
         new_tensor_shape = mixed_x_layer.size()[:-1] + \
             (self.num_attention_heads_per_partition,
-             self.hidden_size_per_attention_head, 3)
+             3 * self.hidden_size_per_attention_head)
         mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
 
-        # [sq, b, np, hn, 3] --> 3 [sq, b, np, hn]
-        query_layer = mixed_x_layer[:,:,:,:,0]
-        key_layer = mixed_x_layer[:,:,:,:,1]
-        value_layer = mixed_x_layer[:,:,:,:,2]
+        # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
+        (query_layer,
+         key_layer,
+         value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
 
         # ==================================
         # Adjust key and value for inference
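A hedged reading of why this reordering is a perf fix (interpretation, not text from the commit): with the old [sq, b, np, hn, 3] layout, each of Q, K and V came out of the indexing as an element-interleaved view with last-dim stride 3, whereas the new [sq, b, np, 3 * hn] layout gives each head's Q, K and V a run of hn consecutive elements, so the split produces views with unit stride along the last dimension. A toy check of the strides, using torch.split as a stand-in for mpu.split_tensor_along_last_dim:

    import torch

    sq, b, np_, hn = 4, 2, 8, 16                 # toy sizes, not from the commit

    # Old layout: Q/K/V interleaved at element granularity in the last dim.
    old = torch.randn(sq, b, np_, hn, 3)
    q_old = old[:, :, :, :, 0]
    print(q_old.stride()[-1])                    # 3 -> consecutive Q elements are 3 apart in memory

    # New layout: each head's Q/K/V occupy runs of hn consecutive elements.
    new = torch.randn(sq, b, np_, 3 * hn)
    q_new, k_new, v_new = torch.split(new, hn, dim=-1)
    print(q_new.stride()[-1])                    # 1 -> unit stride within each head's Q block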