wuxk1 / Megatron-LM / Commits

Commit 09d38884, authored Jan 14, 2022 by Vijay Korthikanti
Parent: 343dc97a

    checkpointing rng_state of all data parallel ranks

Showing 1 changed file with 52 additions and 15 deletions:
  megatron/checkpointing.py  +52 -15
megatron/checkpointing.py
@@ -140,6 +140,32 @@ def read_metadata(tracker_filename):
     return max_iter, release
 
 
+def get_rng_state():
+    """ collect rng state across data parallel ranks """
+    rng_state = {
+        'random_rng_state': random.getstate(),
+        'np_rng_state': np.random.get_state(),
+        'torch_rng_state': torch.get_rng_state(),
+        'cuda_rng_state': torch.cuda.get_rng_state(),
+        'rng_tracker_states': mpu.get_cuda_rng_tracker().get_states()}
+
+    rng_state_list = None
+    if torch.distributed.is_initialized() and \
+            mpu.get_data_parallel_world_size() > 1:
+        if mpu.get_data_parallel_rank() == 0:
+            rng_state_list = \
+                [None for i in range(mpu.get_data_parallel_world_size())]
+        torch.distributed.gather_object(
+            rng_state,
+            rng_state_list,
+            dst=mpu.get_data_parallel_src_rank(),
+            group=mpu.get_data_parallel_group())
+    else:
+        rng_state_list = [rng_state]
+
+    return rng_state_list
+
+
 def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     """Save a model checkpoint."""
     args = get_args()
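The new get_rng_state() helper relies on torch.distributed.gather_object, which collects one picklable object per rank into a list that only the destination rank receives. Below is a standalone two-process sketch of that collective (gloo backend, hard-coded localhost port, illustrative names only; this is not Megatron code):

```python
# Standalone sketch of the gather_object pattern used by get_rng_state().
# Assumptions: 2 CPU processes, gloo backend, port 29501 free on localhost.
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank, world_size):
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501",
                            rank=rank, world_size=world_size)
    # Each rank contributes a picklable snapshot of its own RNG state.
    local_state = {"rank": rank, "torch_rng_state": torch.get_rng_state()}
    # Only the destination rank allocates the output list; others pass None.
    gathered = [None] * world_size if rank == 0 else None
    dist.gather_object(local_state, gathered, dst=0)
    if rank == 0:
        print([s["rank"] for s in gathered])  # -> [0, 1]
    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```

In get_rng_state() above, the same pattern runs over the data-parallel group with mpu.get_data_parallel_src_rank() as the destination, so that rank ends up holding one RNG snapshot per data-parallel rank while every other rank keeps rng_state_list as None.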
@@ -150,6 +176,9 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
         iteration, args.save))
 
+    # collect rng state across data parallel ranks
+    rng_state = get_rng_state()
+
     if not torch.distributed.is_initialized() or mpu.get_data_parallel_rank() == 0:
 
         # Arguments, iteration, and model.
@@ -173,12 +202,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
 
         # RNG states.
         if not args.no_save_rng:
-            state_dict['random_rng_state'] = random.getstate()
-            state_dict['np_rng_state'] = np.random.get_state()
-            state_dict['torch_rng_state'] = torch.get_rng_state()
-            state_dict['cuda_rng_state'] = torch.cuda.get_rng_state()
-            state_dict['rng_tracker_states'] \
-                = mpu.get_cuda_rng_tracker().get_states()
+            state_dict["rng_state"] = rng_state
 
         # Save.
         checkpoint_name = get_checkpoint_name(args.save, iteration)
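Taken together, the two save_checkpoint hunks replace the five flat RNG keys with a single "rng_state" entry holding the gathered list, and they call get_rng_state() on every data-parallel rank before the rank-0-only write, since the gather is a collective that all ranks must enter. A minimal, non-distributed sketch of the two layouts (single simulated rank, CUDA state and rng tracker stubbed out, hypothetical values) just to show the shapes involved:

```python
# Sketch only: old vs. new layout of the RNG portion of a checkpoint,
# built for a single simulated data-parallel rank on CPU.
import random
import numpy as np
import torch

def rank_rng_snapshot():
    # The same pieces of state get_rng_state() captures per rank; the CUDA
    # state and mpu rng tracker entries are stubbed so this runs on CPU.
    return {
        'random_rng_state': random.getstate(),
        'np_rng_state': np.random.get_state(),
        'torch_rng_state': torch.get_rng_state(),
        'cuda_rng_state': torch.get_rng_state(),  # stand-in for torch.cuda.get_rng_state()
        'rng_tracker_states': {},  # stand-in for mpu.get_cuda_rng_tracker().get_states()
    }

snapshot = rank_rng_snapshot()

# Old format: one flat key per piece of RNG state.
old_state_dict = {'iteration': 100, **snapshot}

# New format: one 'rng_state' key, a list indexed by data-parallel rank.
new_state_dict = {'iteration': 100, 'rng_state': [snapshot]}

print(sorted(old_state_dict))  # ['cuda_rng_state', 'iteration', 'np_rng_state', ...]
print(sorted(new_state_dict))  # ['iteration', 'rng_state']
```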
@@ -381,15 +405,28 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     # rng states.
     if not release and not args.finetune and not args.no_load_rng:
         try:
-            random.setstate(state_dict['random_rng_state'])
-            np.random.set_state(state_dict['np_rng_state'])
-            torch.set_rng_state(state_dict['torch_rng_state'])
-            torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
-            # Check for empty states array
-            if not state_dict['rng_tracker_states']:
-                raise KeyError
-            mpu.get_cuda_rng_tracker().set_states(
-                state_dict['rng_tracker_states'])
+            if 'rng_state' in state_dict:
+                # access rng_state for data parallel rank
+                rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()]
+                random.setstate(rng_state['random_rng_state'])
+                np.random.set_state(rng_state['np_rng_state'])
+                torch.set_rng_state(rng_state['torch_rng_state'])
+                torch.cuda.set_rng_state(rng_state['cuda_rng_state'])
+                # Check for empty states array
+                if not rng_state['rng_tracker_states']:
+                    raise KeyError
+                mpu.get_cuda_rng_tracker().set_states(
+                    state_dict['rng_tracker_states'])
+            else:  # backward compatability
+                random.setstate(state_dict['random_rng_state'])
+                np.random.set_state(state_dict['np_rng_state'])
+                torch.set_rng_state(state_dict['torch_rng_state'])
+                torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
+                # Check for empty states array
+                if not state_dict['rng_tracker_states']:
+                    raise KeyError
+                mpu.get_cuda_rng_tracker().set_states(
+                    state_dict['rng_tracker_states'])
         except KeyError:
             print_rank_0('Unable to load rng state from checkpoint {}. '
                          'Specify --no-load-rng or --finetune to prevent '
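On the load side, the presence of the "rng_state" key marks a new-format checkpoint, and each rank restores the entry at its own data-parallel index; older checkpoints fall through to the flat keys. A standalone sketch of that dispatch, with a plain dp_rank argument standing in for mpu.get_data_parallel_rank() and the CUDA/tracker states omitted so it runs anywhere:

```python
# Sketch only: load-time dispatch between the new per-rank list and the
# old flat layout. dp_rank stands in for mpu.get_data_parallel_rank().
import random
import numpy as np
import torch

def restore_rng(state_dict, dp_rank=0):
    if 'rng_state' in state_dict:
        # New format: pick this rank's snapshot out of the gathered list.
        rng_state = state_dict['rng_state'][dp_rank]
        random.setstate(rng_state['random_rng_state'])
        np.random.set_state(rng_state['np_rng_state'])
        torch.set_rng_state(rng_state['torch_rng_state'])
    else:
        # Backward compatibility: checkpoints written before this commit.
        random.setstate(state_dict['random_rng_state'])
        np.random.set_state(state_dict['np_rng_state'])
        torch.set_rng_state(state_dict['torch_rng_state'])

# Round-trip example with a freshly captured snapshot.
snapshot = {
    'random_rng_state': random.getstate(),
    'np_rng_state': np.random.get_state(),
    'torch_rng_state': torch.get_rng_state(),
}
restore_rng({'rng_state': [snapshot]}, dp_rank=0)   # new format
restore_rng(dict(snapshot))                         # old flat format
```

The real code additionally restores the CUDA RNG and the mpu rng tracker, treats an empty rng_tracker_states as a KeyError, and reports failures with the hint to pass --no-load-rng or --finetune, as shown in the hunk above.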