OpenDAS / Megatron-LM, commit 9b0083ea

Incrementing checkpoint version to 2.0

Authored Oct 12, 2020 by Vijay Korthikanti
Parent: ee327acd
Showing 2 changed files with 23 additions and 11 deletions:
megatron/checkpointing.py (+1, -1)
megatron/model/transformer.py (+22, -10)
megatron/checkpointing.py (view file @ 9b0083ea)

@@ -101,7 +101,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     # Arguments, iteration, and model.
     state_dict = {}
     state_dict['args'] = args
-    state_dict['checkpoint_version'] = 1.0
+    state_dict['checkpoint_version'] = 2.0
     state_dict['iteration'] = iteration
     state_dict['model'] = model.state_dict_for_save_checkpoint()
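For context, a minimal sketch of how a state dict carrying this 'checkpoint_version' field can be written and read back with plain PyTorch. This is illustrative only, not Megatron-LM's actual save/load path; the helper names and file path are made up, and only the dictionary layout mirrors the hunk above.

import torch

def save_versioned_checkpoint(path, model, iteration, args=None):
    # Mirror the layout used by save_checkpoint: args, version, iteration, weights.
    state_dict = {
        'args': args,
        'checkpoint_version': 2.0,   # bumped from 1.0 in this commit
        'iteration': iteration,
        'model': model.state_dict(),
    }
    torch.save(state_dict, path)

def load_versioned_checkpoint(path):
    state_dict = torch.load(path, map_location='cpu')
    # Checkpoints written before the field existed carry no version; the
    # transformer.py hunks below check for a None version before converting.
    version = state_dict.get('checkpoint_version', None)
    return state_dict, version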
megatron/model/transformer.py (view file @ 9b0083ea)

@@ -172,19 +172,28 @@ class ParallelSelfAttention(MegatronModule):
             init_method=output_layer_init_method,
             skip_bias_add=True)

-    def _transpose_last_dim(self, mixed_layer, num_splits):
+    def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_index):
         """[s, b, num_splits * np * hn]
         -->(view) [s, b, num_splits, np, hn]
         -->(tranpose) [s, b, np, num_splits, hn]
         -->(view) [s, b, np * num_splits * hn] """
         input_shape = mixed_layer.size();
-        intermediate_shape = input_shape[:-1] +\
-            (num_splits, self.num_attention_heads_per_partition,
-             self.hidden_size_per_attention_head)
+        if num_splits_index == 0:
+            intermediate_shape = input_shape[:-1] +\
+                (num_splits, self.num_attention_heads_per_partition,
+                 self.hidden_size_per_attention_head)

-        mixed_layer = mixed_layer.view(*intermediate_shape)
-        mixed_layer = mixed_layer.transpose(-2, -3).contiguous()
+            mixed_layer = mixed_layer.view(*intermediate_shape)
+            mixed_layer = mixed_layer.transpose(-2, -3).contiguous()
+        else:
+            assert num_splits_index == 2
+            intermediate_shape = input_shape[:-1] +\
+                (self.num_attention_heads_per_partition,
+                 self.hidden_size_per_attention_head, num_splits)
+
+            mixed_layer = mixed_layer.view(*intermediate_shape)
+            mixed_layer = mixed_layer.transpose(-1, -2).contiguous()
         mixed_layer = mixed_layer.view(*input_shape)

         return mixed_layer
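To make the two branches concrete, here is a standalone sketch of the same view/transpose/view sequence, with hypothetical sizes and a made-up helper name (the real method instead reads num_attention_heads_per_partition and hidden_size_per_attention_head from the module). With num_splits_index == 0 the last dimension is read as a (num_splits, np, hn) grouping, with num_splits_index == 2 as (np, hn, num_splits); both branches reorder it to (np, num_splits, hn) and flatten it back.

import torch

# Hypothetical sizes: sequence s, batch b, heads per partition np, head size hn,
# and 3 splits for the fused query/key/value projection.
s, b, np_, hn, num_splits = 2, 1, 4, 8, 3

def reorder_last_dim(x, num_splits, np_, hn, num_splits_index):
    """Standalone version of the view/transpose/view sequence in _transpose_last_dim."""
    input_shape = x.size()
    if num_splits_index == 0:
        # Last dim grouped as (num_splits, np, hn): move num_splits behind np.
        x = x.view(*input_shape[:-1], num_splits, np_, hn)
        x = x.transpose(-2, -3).contiguous()
    else:
        assert num_splits_index == 2
        # Last dim grouped as (np, hn, num_splits): move num_splits in front of hn.
        x = x.view(*input_shape[:-1], np_, hn, num_splits)
        x = x.transpose(-1, -2).contiguous()
    # Either way the result is grouped as (np, num_splits, hn), flattened back.
    return x.view(*input_shape)

x = torch.randn(s, b, num_splits * np_ * hn)
out0 = reorder_last_dim(x, num_splits, np_, hn, 0)  # from a (3, np, hn) grouping
out2 = reorder_last_dim(x, num_splits, np_, hn, 2)  # from a (np, hn, 3) grouping
print(out0.shape, out2.shape)  # both stay [s, b, 3 * np * hn]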
@@ -201,10 +210,13 @@ class ParallelSelfAttention(MegatronModule):
         mixed_x_layer, _ = self.query_key_value(hidden_states)

         checkpoint_version = get_checkpoint_version()
-        if checkpoint_version is not None and \
-           checkpoint_version == 0:
-            # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
-            mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3)
+        if checkpoint_version is not None:
+            if checkpoint_version == 0:
+                # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
+                mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, 0)
+            elif checkpoint_version == 1:
+                # [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)]
+                mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, 2)

         # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
         new_tensor_shape = mixed_x_layer.size()[:-1] + \
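Read together with the first hunk, this dispatch converts the fused QKV activations of older checkpoints into the grouping that version 2.0 produces natively, (np * 3 * hn); version 0 used (3 * np * hn) and version 1 used (np * hn * 3). A small hypothetical check of the version-0 case, assuming the three splits of the fused query_key_value projection are query, key, and value in that order (an assumption this page does not confirm):

import torch

s, b, np_, hn = 2, 1, 4, 8  # hypothetical sizes
q = torch.randn(s, b, np_, hn)
k = torch.randn(s, b, np_, hn)
v = torch.randn(s, b, np_, hn)

# Version-0 grouping, (3 * np * hn): q, k, v stacked on a leading split axis.
fused_v0 = torch.stack([q, k, v], dim=2).reshape(s, b, 3 * np_ * hn)

# Apply the num_splits_index == 0 branch, as the checkpoint_version == 0 case does.
converted = fused_v0.view(s, b, 3, np_, hn).transpose(-2, -3) \
                    .contiguous().view(s, b, 3 * np_ * hn)

# Version-2 grouping, (np * 3 * hn): per head, q/k/v sit next to each other.
fused_v2 = torch.stack([q, k, v], dim=3).reshape(s, b, 3 * np_ * hn)
assert torch.equal(converted, fused_v2)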