Megatron-LM (wuxk1) · Commit 1b2db724
authored Mar 18, 2022 by Jared Casper

Addressing comments.

parent 5a1aa663
Showing 1 changed file with 10 additions and 10 deletions.

megatron/checkpointing.py (+10, -10)
@@ -83,20 +83,20 @@ def ensure_directory_exists(filename):
 
 def get_checkpoint_name(checkpoints_path, iteration, release=False,
-                        pipeline_parallel_size=None, tensor_rank=None, pipeline_rank=None):
-    """A unified checkpoint name."""
+                        pipeline_parallel=None, tensor_rank=None, pipeline_rank=None):
+    """Determine the directory name for this rank's checkpoint."""
     if release:
         directory = 'release'
     else:
         directory = 'iter_{:07d}'.format(iteration)
 
     # Use both the tensor and pipeline MP rank.
-    if pipeline_parallel_size is None:
-        pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
+    if pipeline_parallel is None:
+        pipeline_parallel = (mpu.get_pipeline_model_parallel_world_size() > 1)
     if tensor_rank is None:
         tensor_rank = mpu.get_tensor_model_parallel_rank()
     if pipeline_rank is None:
         pipeline_rank = mpu.get_pipeline_model_parallel_rank()
-    if pipeline_parallel_size == 1:
+    if not pipeline_parallel:
         return os.path.join(checkpoints_path, directory,
                             f'mp_rank_{tensor_rank:02d}',
                             'model_optim_rng.pt')
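A minimal usage sketch of the new boolean flag (not part of the diff; the checkpoint root and iteration number are made up for illustration):

    # Illustration only: request the non-pipelined layout explicitly.
    name = get_checkpoint_name('/checkpoints/gpt', 5000,
                               pipeline_parallel=False,
                               tensor_rank=0, pipeline_rank=0)
    # -> /checkpoints/gpt/iter_0005000/mp_rank_00/model_optim_rng.pt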
@@ -116,14 +116,14 @@ def find_checkpoint_rank_0(checkpoints_path, iteration, release=False):
 
     # Look for checkpoint with no pipelining
     filename = get_checkpoint_name(checkpoints_path, iteration, release,
-                                   pipeline_parallel_size=1,
+                                   pipeline_parallel=False,
                                    tensor_rank=0, pipeline_rank=0)
     if os.path.isfile(filename):
         return filename
 
     # Look for checkpoint with pipelining
     filename = get_checkpoint_name(checkpoints_path, iteration, release,
-                                   pipeline_parallel_size=2,
+                                   pipeline_parallel=True,
                                    tensor_rank=0, pipeline_rank=0)
     if os.path.isfile(filename):
         return filename
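A hedged sketch of the probing order find_checkpoint_rank_0 now uses, written against the new boolean flag (not part of the diff; the root directory and iteration are hypothetical):

    # Illustration only: probe the non-pipelined layout first, then the
    # pipelined one, mirroring the hunk above.
    for pipelined in (False, True):
        candidate = get_checkpoint_name('/checkpoints/gpt', 5000, False,
                                        pipeline_parallel=pipelined,
                                        tensor_rank=0, pipeline_rank=0)
        if os.path.isfile(candidate):
            break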
@@ -404,7 +404,7 @@ def load_args_from_checkpoint(args, load_arg='load'):
     if load_dir is None:
         return args
 
-    state_dict, release = _load_base_checkpoint(load_dir, True)
+    state_dict, release = _load_base_checkpoint(load_dir, rank0=True)
 
     if not state_dict:
         return args
@@ -460,7 +460,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
     model = unwrap_model(model)
 
-    state_dict, release = _load_base_checkpoint(load_dir, False)
+    state_dict, release = _load_base_checkpoint(load_dir, rank0=False)
 
     # set checkpoint version
     set_checkpoint_version(state_dict.get('checkpoint_version', 0))
@@ -587,7 +587,7 @@ def load_biencoder_checkpoint(model, only_query_model=False,
     with open(tracker_filename, 'r') as f:
         iteration = int(f.read().strip())
 
-    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
+    checkpoint_name = get_checkpoint_name(load_path, iteration, release=False)
     if mpu.get_data_parallel_rank() == 0:
         print('global rank {} is loading checkpoint {}'.format(
             torch.distributed.get_rank(), checkpoint_name))
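The last three hunks only switch positional booleans to keyword arguments (rank0= for _load_base_checkpoint, release= for get_checkpoint_name), so behavior is unchanged and the intent is readable at the call sites. Assuming _load_base_checkpoint(load_dir, rank0) as its signature (its definition is not shown in this diff):

    # rank0=True: the lighter load used by load_args_from_checkpoint.
    state_dict, release = _load_base_checkpoint(load_dir, rank0=True)
    # rank0=False: the full load used by load_checkpoint.
    state_dict, release = _load_base_checkpoint(load_dir, rank0=False)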