OpenDAS / Megatron-LM · Commit 37ae6646

Commit 37ae6646, authored Sep 28, 2020 by Vijay Korthikanti

support for old checkpoint format

Parent: 4bf923d5
Showing 2 changed files with 23 additions and 1 deletion (+23 -1):

megatron/arguments.py            +2   -0
megatron/model/transformer.py    +21  -1
megatron/arguments.py

@@ -313,6 +313,8 @@ def _add_checkpointing_args(parser):
                        help='Load model for finetuning. Do not load optimizer '
                        'or rng state from checkpoint and set iteration to 0. '
                        'Assumed when loading a release checkpoint.')
+    group.add_argument('--old-checkpoint-format', action='store_true',
+                       help='load old checkpoint format[Q[]K[]V[]].')
     return parser
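The new --old-checkpoint-format switch is a standard argparse store_true flag: it is False unless explicitly passed on the command line. A minimal standalone sketch of that behavior (the parser wiring below is illustrative, not Megatron's full parse_args):

```python
import argparse

# Minimal stand-in for _add_checkpointing_args: only the new flag is shown.
def add_checkpointing_args(parser):
    group = parser.add_argument_group(title='checkpointing')
    group.add_argument('--old-checkpoint-format', action='store_true',
                       help='load old checkpoint format[Q[]K[]V[]].')
    return parser

parser = add_checkpointing_args(argparse.ArgumentParser())

args = parser.parse_args([])                            # flag omitted
assert args.old_checkpoint_format is False

args = parser.parse_args(['--old-checkpoint-format'])   # flag given
assert args.old_checkpoint_format is True
```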
megatron/model/transformer.py

@@ -120,6 +120,7 @@ class ParallelSelfAttention(MegatronModule):
         super(ParallelSelfAttention, self).__init__()
         args = get_args()
         self.fp16 = args.fp16
+        self.old_checkpoint_format = args.old_checkpoint_format
         self.attention_mask_func = attention_mask_func
         self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
@@ -170,7 +171,23 @@ class ParallelSelfAttention(MegatronModule):
             input_is_parallel=True,
             init_method=output_layer_init_method,
             skip_bias_add=True)
 
+    def _transpose_last_dim(self, mixed_layer):
+        """[s, b, 3 * hp] -->(view) [s, b, 3, hp] -->(transpose)
+           [s, b, hp, 3] -->(view) [s, b, 3 * hp]"""
+        input_shape = mixed_layer.size()
+        last_dim = input_shape[-1]
+        assert last_dim % 3 == 0
+        last_dim_split = last_dim // 3
+        intermediate_shape = input_shape[:-1] + \
+            (3, last_dim_split)
+        mixed_layer = mixed_layer.view(*intermediate_shape)
+        mixed_layer = mixed_layer.transpose(-1, -2).contiguous()
+        mixed_layer = mixed_layer.view(*input_shape)
+        return mixed_layer
+
     def forward(self, hidden_states, attention_mask, layer_past=None,
                 get_key_value=False):
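What the new _transpose_last_dim helper does to the fused QKV output: it views the last dimension as three equal blocks and interleaves them element-wise, turning a [block0..., block1..., block2...] layout into [block0[0], block1[0], block2[0], block0[1], ...]. A self-contained sketch of that view/transpose/view sequence on a toy tensor (the Q/K/V block contents and sizes below are illustrative assumptions, not taken from the commit):

```python
import torch

def transpose_last_dim(mixed_layer):
    """Standalone copy of the new helper: [.., 3 * hp] -> view [.., 3, hp]
    -> transpose [.., hp, 3] -> view back to [.., 3 * hp]."""
    input_shape = mixed_layer.size()
    last_dim = input_shape[-1]
    assert last_dim % 3 == 0
    last_dim_split = last_dim // 3
    mixed_layer = mixed_layer.view(*input_shape[:-1], 3, last_dim_split)
    mixed_layer = mixed_layer.transpose(-1, -2).contiguous()
    return mixed_layer.view(*input_shape)

# Toy sizes (illustrative): s=1 sequence, b=1 batch, hp=4 per-partition hidden.
hp = 4
q = torch.arange(hp).float()            # stands in for the Q block
k = torch.arange(hp).float() + 100      # stands in for the K block
v = torch.arange(hp).float() + 200      # stands in for the V block
old_layout = torch.cat([q, k, v]).view(1, 1, 3 * hp)    # [Q... K... V...]

new_layout = transpose_last_dim(old_layout)

# The three blocks are now interleaved element-wise: [q0, k0, v0, q1, k1, v1, ...]
expected = torch.stack([q, k, v], dim=-1).view(1, 1, 3 * hp)
assert torch.equal(new_layout, expected)
print(new_layout)
```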
@@ -182,6 +199,9 @@ class ParallelSelfAttention(MegatronModule):
         # Attention heads [s, b, hp] --> [s, b, 3 * hp]
         mixed_x_layer, _ = self.query_key_value(hidden_states)
 
+        if self.old_checkpoint_format:
+            mixed_x_layer = self._transpose_last_dim(mixed_x_layer)
+
         # [s, b, 3 * hp] --> [s, b, np, 3 * hn]
         new_tensor_shape = mixed_x_layer.size()[:-1] + \
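Putting the two transformer.py hunks together, the forward path projects hidden states to [s, b, 3 * hp], optionally remaps the old checkpoint layout, then reshapes to [s, b, np, 3 * hn] and splits into query/key/value. A hedged, self-contained sketch of that flow under toy dimensions (nn.Linear and torch.chunk stand in for Megatron's tensor-parallel query_key_value layer and its last-dim split helper; all names and sizes here are illustrative):

```python
import torch
import torch.nn as nn

# Toy dimensions (illustrative): np_ = heads per partition, hn = size per head.
s, b, np_, hn = 2, 1, 2, 4
hp = np_ * hn

# nn.Linear stands in for the tensor-parallel query_key_value projection.
query_key_value = nn.Linear(hp, 3 * hp)

def transpose_last_dim(mixed_layer):
    # Same view/transpose/view sequence as the new _transpose_last_dim helper.
    input_shape = mixed_layer.size()
    mixed_layer = mixed_layer.view(*input_shape[:-1], 3, input_shape[-1] // 3)
    return mixed_layer.transpose(-1, -2).contiguous().view(*input_shape)

def attention_qkv(hidden_states, old_checkpoint_format=False):
    # [s, b, hp] --> [s, b, 3 * hp]
    mixed_x_layer = query_key_value(hidden_states)
    if old_checkpoint_format:
        mixed_x_layer = transpose_last_dim(mixed_x_layer)
    # [s, b, 3 * hp] --> [s, b, np, 3 * hn]
    mixed_x_layer = mixed_x_layer.view(s, b, np_, 3 * hn)
    # torch.chunk stands in for Megatron's split-along-last-dim helper.
    query, key, value = torch.chunk(mixed_x_layer, 3, dim=-1)
    return query, key, value

q, k, v = attention_qkv(torch.randn(s, b, hp), old_checkpoint_format=True)
print(q.shape, k.shape, v.shape)   # each [s, b, np, hn]
```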