OpenDAS / Megatron-LM / Commits

Commit 6abf39be authored Nov 03, 2020 by Deepak Narayanan

Only transpose hidden_states when necessary

parent 57c3b364
Showing 2 changed files with 7 additions and 7 deletions:

megatron/model/transformer.py   +6  -6
megatron/training.py            +1  -1
megatron/model/transformer.py

@@ -552,7 +552,7 @@ class ParallelTransformer(MegatronModule):
     def forward(self, hidden_states, attention_mask, layer_past=None,
                 get_key_value=False):
 
-        # Checks
+        # Checks.
         if layer_past is not None:
             assert get_key_value, \
                 'for not None values in layer_past, ' \
@@ -562,8 +562,9 @@ class ParallelTransformer(MegatronModule):
                 'get_key_value does not work with ' \
                 'activation checkpointing'
 
-        # data format change to avoid explicit tranposes : [b s h] --> [s b h]
-        hidden_states = hidden_states.transpose(0, 1).contiguous()
+        if mpu.is_pipeline_first_stage():
+            # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
+            hidden_states = hidden_states.transpose(0, 1).contiguous()
 
         if self.checkpoint_activations:
             hidden_states = self._checkpointed_forward(hidden_states,
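For readers unfamiliar with the [b s h] --> [s b h] notation above: the transpose swaps the batch and sequence dimensions, and .contiguous() materializes the result so downstream kernels see a densely laid-out sequence-first tensor. A minimal standalone sketch, not part of the commit, with made-up sizes:

import torch

# Hypothetical sizes, for illustration only.
batch_size, seq_length, hidden_size = 4, 16, 8

# Activations arrive batch-first: [b, s, h].
hidden_states = torch.randn(batch_size, seq_length, hidden_size)

# Swap to sequence-first [s, b, h]. transpose() only changes strides,
# so .contiguous() copies the data into the new memory layout.
hidden_states = hidden_states.transpose(0, 1).contiguous()

assert hidden_states.shape == (seq_length, batch_size, hidden_size)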
@@ -584,11 +585,10 @@ class ParallelTransformer(MegatronModule):
                     hidden_states, present = hidden_states
                     presents.append(present)
 
-        # reverting data format change [s b h] --> [b s h]
-        hidden_states = hidden_states.transpose(0, 1).contiguous()
-
         # Final layer norm.
         if mpu.is_pipeline_last_stage():
+            # Reverting data format change [s b h] --> [b s h].
+            hidden_states = hidden_states.transpose(0, 1).contiguous()
             output = self.final_layernorm(hidden_states)
         else:
             output = hidden_states
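Taken together, the two hunks above gate the transposes on pipeline position: only the first stage converts incoming activations from [b s h] to [s b h], and only the last stage converts back, so intermediate stages avoid redundant transposes. A rough sketch of the resulting control flow, not the actual ParallelTransformer.forward (it omits attention_mask, layer_past, get_key_value, and activation checkpointing, and assumes only the mpu helpers named in the diff):

def forward_sketch(hidden_states, mpu, layers, final_layernorm):
    # Only the first pipeline stage receives batch-first [b, s, h] input;
    # later stages already get sequence-first [s, b, h] activations from
    # the previous stage and must not transpose again.
    if mpu.is_pipeline_first_stage():
        hidden_states = hidden_states.transpose(0, 1).contiguous()

    for layer in layers:
        hidden_states = layer(hidden_states)

    # Only the last stage converts back to [b, s, h] and applies the
    # final layer norm; intermediate stages pass [s, b, h] through as-is.
    if mpu.is_pipeline_last_stage():
        hidden_states = hidden_states.transpose(0, 1).contiguous()
        return final_layernorm(hidden_states)
    return hidden_states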
megatron/training.py

@@ -245,7 +245,7 @@ def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward)
     # if needed.
     tensor_recv_prev = None
     tensor_recv_next = None
-    tensor_shape = (args.batch_size, args.seq_length, args.hidden_size)
+    tensor_shape = (args.seq_length, args.batch_size, args.hidden_size)
     if recv_forward:
         tensor_recv_prev = torch.empty(tensor_shape,
                                        requires_grad=True,
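The one-line change in megatron/training.py keeps the point-to-point receive buffer consistent with that layout: activations now travel between pipeline stages in [s b h], so the buffer is allocated sequence-first as well. A minimal sketch with made-up sizes standing in for args.seq_length, args.batch_size, and args.hidden_size:

import torch

# Stand-ins for args.seq_length, args.batch_size, args.hidden_size.
seq_length, batch_size, hidden_size = 16, 4, 8

# Receive buffer for activations coming from the neighbouring pipeline
# stage; it must match the sequence-first [s, b, h] layout now used
# between stages.
tensor_recv_prev = torch.empty((seq_length, batch_size, hidden_size),
                               requires_grad=True)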