wuxk1 / Megatron-LM / Commits / a8d47812

Commit a8d47812, authored Sep 29, 2020 by Vijay Korthikanti
parent 7d4ad51e

    checkpoint versioning
Showing 3 changed files with 19 additions and 5 deletions:

    megatron/arguments.py          +0  -2
    megatron/checkpointing.py      +14 -0
    megatron/model/transformer.py  +5  -3
megatron/arguments.py

@@ -313,8 +313,6 @@ def _add_checkpointing_args(parser):
                        help='Load model for finetuning. Do not load optimizer '
                        'or rng state from checkpoint and set iteration to 0. '
                        'Assumed when loading a release checkpoint.')
-    group.add_argument('--old-checkpoint-format', action='store_true',
-                       help='load old checkpoint format[Q[]K[]V[]].')
     return parser
megatron/checkpointing.py

@@ -27,6 +27,15 @@ from megatron import mpu, get_args
 from megatron import get_args
 from megatron import print_rank_0
 
+_CHECKPOINT_VERSION = None
+
+def set_checkpoint_version(value):
+    global _CHECKPOINT_VERSION
+    _CHECKPOINT_VERSION = value
+
+def get_checkpoint_version():
+    global _CHECKPOINT_VERSION
+    return _CHECKPOINT_VERSION
 
 def check_checkpoint_args(checkpoint_args):
     """Ensure fixed arguments for a model are the same for the input
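The new helpers are a small module-level registry: load_checkpoint() records which on-disk format it found, and distant consumers such as the transformer layers can query it later without the value being threaded through call signatures. A minimal standalone sketch of the pattern (runnable outside Megatron; the demo lines at the bottom are illustrative):

    # Module-level registry, as in the hunk above. None means no
    # checkpoint has been loaded in this process yet.
    _CHECKPOINT_VERSION = None

    def set_checkpoint_version(value):
        global _CHECKPOINT_VERSION
        _CHECKPOINT_VERSION = value

    def get_checkpoint_version():
        return _CHECKPOINT_VERSION

    assert get_checkpoint_version() is None  # fresh process: nothing loaded
    set_checkpoint_version(0)                # simulate loading a legacy checkpoint
    assert get_checkpoint_version() == 0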
@@ -90,6 +99,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
         # Arguments, iteration, and model.
         state_dict = {}
         state_dict['args'] = args
+        state_dict['checkpoint_version'] = 1
         state_dict['iteration'] = iteration
         state_dict['model'] = model.state_dict_for_save_checkpoint()
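From this commit on, every checkpoint carries a 'checkpoint_version' entry set to 1 next to the weights. A toy round trip of just the tagging, assuming PyTorch; the path and the two-entry dict are illustrative, since save_checkpoint() stores much more:

    import torch

    state_dict = {'checkpoint_version': 1, 'iteration': 1000}
    torch.save(state_dict, '/tmp/tagged_ckpt.pt')   # illustrative path

    loaded = torch.load('/tmp/tagged_ckpt.pt')
    # Checkpoints written before this commit simply lack the key, which
    # is why load_checkpoint() below reads it with a default of 0.
    print(loaded.get('checkpoint_version', 0))      # -> 1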
@@ -184,6 +194,9 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
         print_rank_0('could not load the checkpoint')
         sys.exit()
 
+    # set checkpoint version
+    set_checkpoint_version(state_dict.get('checkpoint_version', 0))
+
     # Set iteration.
     if args.finetune or release:
         iteration = 0
@@ -198,6 +211,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
                 'iteration from checkpoint {}, exiting'.format(
                     checkpoint_name))
             sys.exit()
+
     # Check arguments.
     if 'args' in state_dict:
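After these changes, get_checkpoint_version() has three possible answers that downstream code can act on. A small sketch of that contract, assuming megatron is importable; the loop only simulates the three cases:

    from megatron.checkpointing import (set_checkpoint_version,
                                        get_checkpoint_version)

    # None -> load_checkpoint() never ran (e.g. training from scratch)
    # 0    -> legacy checkpoint that predates the 'checkpoint_version' key
    # 1    -> checkpoint written by this commit's save_checkpoint()
    for simulated in (None, 0, 1):
        set_checkpoint_version(simulated)
        v = get_checkpoint_version()
        needs_qkv_transpose = v is not None and v == 0
        print(simulated, '->', needs_qkv_transpose)  # only 0 triggers the fix-up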
megatron/model/transformer.py

@@ -23,6 +23,7 @@ from megatron import get_args
 from megatron import mpu
 from megatron.mpu import LayerNorm
 from megatron.module import MegatronModule
+from megatron.checkpointing import get_checkpoint_version
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import openai_gelu, erf_gelu
@@ -120,7 +121,6 @@ class ParallelSelfAttention(MegatronModule):
         super(ParallelSelfAttention, self).__init__()
         args = get_args()
         self.fp16 = args.fp16
-        self.old_checkpoint_format = args.old_checkpoint_format
         self.attention_mask_func = attention_mask_func
         self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
@@ -178,7 +178,7 @@ class ParallelSelfAttention(MegatronModule):
         input_shape = mixed_layer.size();
         last_dim = input_shape[-1]
-        assert last_dim % 3 == 0
+        assert last_dim % 3 == 0, "expected QKV dimension"
         last_dim_split = last_dim // 3
         intermediate_shape = input_shape[:-1] + \
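Only the assert changes in _transpose_last_dim here, and most of the method body falls outside the hunk. Going by the shape comment in the next hunk ([s, b, 3 * hp] --> [s, b, hp * 3]), it regroups the last dimension from three contiguous Q/K/V blocks into per-position (q, k, v) triples. A hedged standalone reconstruction of that regrouping, assuming PyTorch; not the verbatim method:

    import torch

    def transpose_qkv_last_dim(mixed_layer):
        # [..., 3 * hp]: three contiguous blocks [Q | K | V], width hp each
        input_shape = mixed_layer.size()
        last_dim = input_shape[-1]
        assert last_dim % 3 == 0, "expected QKV dimension"
        last_dim_split = last_dim // 3
        # view as (3, hp), swap to (hp, 3), flatten back to hp * 3
        t = mixed_layer.view(*input_shape[:-1], 3, last_dim_split)
        t = t.transpose(-1, -2).contiguous()
        return t.view(*input_shape)

    x = torch.arange(6.).view(1, 1, 6)  # s = b = 1, hp = 2
    print(transpose_qkv_last_dim(x))    # tensor([[[0., 2., 4., 1., 3., 5.]]])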
@@ -200,7 +200,9 @@ class ParallelSelfAttention(MegatronModule):
         # Attention heads [s, b, hp] --> [s, b, hp * 3]
         mixed_x_layer, _ = self.query_key_value(hidden_states)
 
-        if self.old_checkpoint_format:
+        checkpoint_version = get_checkpoint_version()
+        if checkpoint_version is not None and \
+           checkpoint_version == 0:
             # [s, b, 3 * hp] --> [s, b, hp * 3]
             mixed_x_layer = self._transpose_last_dim(mixed_x_layer)
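Net effect of this hunk: the Q/K/V re-layout that previously required users to remember --old-checkpoint-format now fires automatically whenever a version-0 (untagged) checkpoint was loaded. The extra "is not None" guard keeps freshly initialized models, where nothing was loaded and the version is still None, off the conversion path.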