OpenDAS / Megatron-LM · Commits

Commit b178e6fc
authored Jun 08, 2022 by Lawrence McAfee

    error fixes & tested.

parent 977efdfb
Showing 1 changed file with 30 additions and 20 deletions.
megatron/checkpointing.py  +30 −20  (view file @ b178e6fc)
@@ -194,47 +194,57 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
     model_checkpoint_name, optim_checkpoint_name = \
         get_checkpoint_names(args.save, iteration, args.use_distributed_optimizer)
 
-    # Save args, model, RNG.
+    # Collect args, model, RNG.
+    model_state_dict = {}
     if not torch.distributed.is_initialized() \
        or mpu.get_data_parallel_rank() == 0:
 
         # Arguments, iteration, and model.
-        state_dict = {}
-        state_dict['args'] = args
-        state_dict['checkpoint_version'] = 3.0
-        state_dict['iteration'] = iteration
+        model_state_dict['args'] = args
+        model_state_dict['checkpoint_version'] = 3.0
+        model_state_dict['iteration'] = iteration
         if len(model) == 1:
-            state_dict['model'] = model[0].state_dict_for_save_checkpoint()
+            model_state_dict['model'] = model[0].state_dict_for_save_checkpoint()
         else:
             for i in range(len(model)):
                 mpu.set_virtual_pipeline_model_parallel_rank(i)
-                state_dict['model%d' % i] = model[i].state_dict_for_save_checkpoint()
+                model_state_dict['model%d' % i] = \
+                    model[i].state_dict_for_save_checkpoint()
 
         # RNG states.
         if not args.no_save_rng:
-            state_dict["rng_state"] = rng_state
-
-        # Save.
-        ensure_directory_exists(model_checkpoint_name)
-        torch.save(state_dict, model_checkpoint_name)
+            model_state_dict["rng_state"] = rng_state
 
-    # Save optimizer state. (Optimizer is saved separately from the model, due
+    # Collect optimizer state. (Optimizer is saved separately from the model, due
     # to the conflicting data pattern when using the distributed optimizer.)
+    optim_state_dict = {}
     if not args.no_save_optim \
        and (not torch.distributed.is_initialized()
            or mpu.get_data_parallel_rank() == 0
            or args.use_distributed_optimizer):
 
         # Optimizer stuff.
-        state_dict = {}
         if optimizer is not None:
-            state_dict['optimizer'] = optimizer.state_dict()
+            optim_state_dict['optimizer'] = optimizer.state_dict()
         if opt_param_scheduler is not None:
-            state_dict['opt_param_scheduler'] = opt_param_scheduler.state_dict()
-
-        # Save.
-        ensure_directory_exists(optim_checkpoint_name)
-        torch.save(state_dict, optim_checkpoint_name)
+            optim_state_dict['opt_param_scheduler'] = \
+                opt_param_scheduler.state_dict()
+
+    # Save.
+    if args.use_distributed_optimizer:
+        # Save model separate from optimizer.
+        if model_state_dict:
+            ensure_directory_exists(model_checkpoint_name)
+            torch.save(model_state_dict, model_checkpoint_name)
+        if optim_state_dict:
+            ensure_directory_exists(optim_checkpoint_name)
+            torch.save(optim_state_dict, optim_checkpoint_name)
+    else:
+        # Save model and optimizer together.
+        state_dict = {**model_state_dict, **optim_state_dict}
+        if state_dict: # only saves if populated (i.e., inherits conditions above)
+            ensure_directory_exists(model_checkpoint_name)
+            torch.save(state_dict, model_checkpoint_name)
 
     # Wait so everyone is done (necessary)
     if torch.distributed.is_initialized():
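The substance of the change: save_checkpoint() previously wrote each file as soon as its state_dict was assembled; it now collects model_state_dict and optim_state_dict first and defers every torch.save() to the single "# Save." block at the end. With the distributed optimizer, each data-parallel rank holds its own shard of the optimizer state (the "conflicting data pattern" the comment refers to), so model and optimizer go to separate files; otherwise the two dicts are merged and written as one file. Below is a minimal, standalone sketch of that final block. ensure_dir() is a stand-in for Megatron's ensure_directory_exists(), and the paths in the usage example are illustrative, not the repository's actual naming scheme.

# Minimal sketch of the deferred-save logic from this commit.
# Assumptions: ensure_dir() is a stand-in helper; paths are hypothetical.
import os
import torch

def ensure_dir(filename):
    # Stand-in for Megatron's ensure_directory_exists(): create parent dirs.
    os.makedirs(os.path.dirname(filename), exist_ok=True)

def save_collected_state(model_state_dict, optim_state_dict,
                         model_checkpoint_name, optim_checkpoint_name,
                         use_distributed_optimizer):
    if use_distributed_optimizer:
        # Save model separate from optimizer: every data-parallel rank owns
        # an optimizer shard, but only rank 0 populated model_state_dict.
        if model_state_dict:
            ensure_dir(model_checkpoint_name)
            torch.save(model_state_dict, model_checkpoint_name)
        if optim_state_dict:
            ensure_dir(optim_checkpoint_name)
            torch.save(optim_state_dict, optim_checkpoint_name)
    else:
        # Save model and optimizer together; an empty merged dict means this
        # rank collected nothing above, so it writes nothing.
        state_dict = {**model_state_dict, **optim_state_dict}
        if state_dict:
            ensure_dir(model_checkpoint_name)
            torch.save(state_dict, model_checkpoint_name)

# Hypothetical usage: a rank-0 save without the distributed optimizer
# produces one combined checkpoint file.
save_collected_state({'iteration': 100, 'checkpoint_version': 3.0},
                     {'optimizer': {}},
                     'ckpts/iter_0000100/model_optim_rng.pt',
                     'ckpts/iter_0000100/optim.pt',
                     use_distributed_optimizer=False)

Collect-then-save keeps the gating in one place: the merged-dict branch inherits the rank checks simply because a rank that failed them never put anything into either dict, which is what the "inherits conditions above" comment in the diff points out.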