Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wuxk1
Megatron-LM
Commits
cebd3b8b
Commit
cebd3b8b
authored
Dec 02, 2020
by
mohammad
Browse files
addressed Jared's comments
parent
f0a445fa
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
19 deletions
+7
-19
megatron/checkpointing.py
megatron/checkpointing.py
+5
-13
megatron/training.py
megatron/training.py
+2
-6
No files found.
megatron/checkpointing.py
View file @
cebd3b8b
...
...
@@ -89,8 +89,7 @@ def get_checkpoint_tracker_filename(checkpoints_path):
return
os
.
path
.
join
(
checkpoints_path
,
'latest_checkpointed_iteration.txt'
)
def
save_checkpoint
(
iteration
,
model
,
optimizer
,
lr_scheduler
,
consumed_train_samples
=
None
,
consumed_valid_samples
=
None
):
def
save_checkpoint
(
iteration
,
model
,
optimizer
,
lr_scheduler
):
"""Save a model checkpoint."""
args
=
get_args
()
...
...
@@ -104,10 +103,6 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler,
state_dict
[
'args'
]
=
args
state_dict
[
'checkpoint_version'
]
=
2.0
state_dict
[
'iteration'
]
=
iteration
if
consumed_train_samples
:
state_dict
[
'consumed_train_samples'
]
=
consumed_train_samples
if
consumed_valid_samples
:
state_dict
[
'consumed_valid_samples'
]
=
consumed_valid_samples
state_dict
[
'model'
]
=
model
.
state_dict_for_save_checkpoint
()
# Optimizer stuff.
...
...
@@ -219,17 +214,14 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
checkpoint_name
))
sys
.
exit
()
if
'consumed_train_samples'
in
state_dict
:
assert
args
.
consumed_train_samples
==
0
args
.
consumed_train_samples
=
state_dict
[
'consumed_train_samples'
]
if
'consumed_valid_samples'
in
state_dict
:
assert
args
.
consumed_valid_samples
==
0
args
.
consumed_valid_samples
=
state_dict
[
'consumed_valid_samples'
]
# Check arguments.
assert
args
.
consumed_train_samples
==
0
assert
args
.
consumed_valid_samples
==
0
if
'args'
in
state_dict
:
checkpoint_args
=
state_dict
[
'args'
]
check_checkpoint_args
(
checkpoint_args
)
args
.
consumed_train_samples
=
getattr
(
args
,
'consumed_train_samples'
,
0
)
args
.
consumed_valid_samples
=
getattr
(
args
,
'consumed_valid_samples'
,
0
)
else
:
print_rank_0
(
'could not find arguments in the checkpoint ...'
)
...
...
megatron/training.py
View file @
cebd3b8b
...
...
@@ -104,9 +104,7 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
iteration
,
False
)
if
args
.
save
and
iteration
!=
0
:
save_checkpoint
(
iteration
,
model
,
optimizer
,
lr_scheduler
,
consumed_train_samples
=
args
.
consumed_train_samples
,
consumed_valid_samples
=
args
.
consumed_valid_samples
)
save_checkpoint
(
iteration
,
model
,
optimizer
,
lr_scheduler
)
if
args
.
do_test
:
# Run on test data.
...
...
@@ -438,9 +436,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
# Checkpointing
if
args
.
save
and
args
.
save_interval
and
\
iteration
%
args
.
save_interval
==
0
:
save_checkpoint
(
iteration
,
model
,
optimizer
,
lr_scheduler
,
consumed_train_samples
=
args
.
consumed_train_samples
,
consumed_valid_samples
=
args
.
consumed_valid_samples
)
save_checkpoint
(
iteration
,
model
,
optimizer
,
lr_scheduler
)
# Evaluation
if
args
.
eval_interval
and
iteration
%
args
.
eval_interval
==
0
and
\
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment