OpenDAS / Megatron-LM
"src/vscode:/vscode.git/clone" did not exist on "53bc30dd45cf2e4f85bd42c9b69c50347d6bde6d"
Commit 37ca7859, authored Mar 14, 2022 by Lawrence McAfee

    started updating checkpointing.py.

Parent: f6811e28
Showing 2 changed files with 139 additions and 30 deletions:

    megatron/checkpointing.py                  +82  -25
    megatron/optimizer/distrib_optimizer.py    +57  -5
megatron/checkpointing.py  (view file @ 37ca7859)
@@ -81,7 +81,26 @@ def ensure_directory_exists(filename):
         os.makedirs(dirname)
 
 
-def get_checkpoint_name(checkpoints_path, iteration,
-                        release=False):
+# >>
+# def get_checkpoint_name(checkpoints_path, iteration,
+#                         release=False):
+#     """A unified checkpoint name."""
+#     if release:
+#         directory = 'release'
+#     else:
+#         directory = 'iter_{:07d}'.format(iteration)
+#     # Use both the tensor and pipeline MP rank.
+#     if mpu.get_pipeline_model_parallel_world_size() == 1:
+#         return os.path.join(checkpoints_path, directory,
+#                             'mp_rank_{:02d}'.format(
+#                                 mpu.get_tensor_model_parallel_rank()),
+#                             'model_optim_rng.pt')
+#     return os.path.join(checkpoints_path, directory,
+#                         'mp_rank_{:02d}_{:03d}'.format(
+#                             mpu.get_tensor_model_parallel_rank(),
+#                             mpu.get_pipeline_model_parallel_rank()),
+#                         'model_optim_rng.pt')
+def get_checkpoint_names(checkpoints_path, iteration,
+                         release=False):
     """A unified checkpoint name."""
     if release:
@@ -89,16 +108,17 @@ def get_checkpoint_name(checkpoints_path, iteration,
     else:
         directory = 'iter_{:07d}'.format(iteration)
     # Use both the tensor and pipeline MP rank.
-    if mpu.get_pipeline_model_parallel_world_size() == 1:
-        return os.path.join(checkpoints_path, directory,
-                            'mp_rank_{:02d}'.format(
-                                mpu.get_tensor_model_parallel_rank()),
-                            'model_optim_rng.pt')
-    return os.path.join(checkpoints_path, directory,
-                        'mp_rank_{:02d}_{:03d}'.format(
-                            mpu.get_tensor_model_parallel_rank(),
-                            mpu.get_pipeline_model_parallel_rank()),
-                        'model_optim_rng.pt')
+    common_path = os.path.join(checkpoints_path, directory,
+                               "mp_rank_%02d_%03d_%03d" % (
+                                   mpu.get_tensor_model_parallel_rank(),
+                                   mpu.get_pipeline_model_parallel_rank(),
+                                   mpu.get_data_parallel_rank()))
+
+    model_name = os.path.join(common_path, "model_rng.pt")
+    optim_name = os.path.join(common_path, "optim.pt")
+
+    return model_name, optim_name
+# <<<
 
 
 def get_checkpoint_tracker_filename(checkpoints_path):
@@ -177,10 +197,16 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
     print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
         iteration, args.save))
 
-    # collect rng state across data parallel ranks
+    # Collect rng state across data parallel ranks.
     rng_state = get_rng_state()
 
-    if not torch.distributed.is_initialized() or mpu.get_data_parallel_rank() == 0:
+    # Checkpoint file names.
+    model_checkpoint_name, optim_checkpoint_name = \
+        get_checkpoint_names(args.save, iteration)
+
+    # Save args, model, RNG.
+    if not torch.distributed.is_initialized() \
+       or mpu.get_data_parallel_rank() == 0:
 
         # Arguments, iteration, and model.
         state_dict = {}
@@ -194,21 +220,49 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
             mpu.set_virtual_pipeline_model_parallel_rank(i)
             state_dict['model%d' % i] = model[i].state_dict_for_save_checkpoint()
 
-        # Optimizer stuff.
-        if not args.no_save_optim:
-            if optimizer is not None:
-                state_dict['optimizer'] = optimizer.state_dict()
-            if opt_param_scheduler is not None:
-                state_dict['opt_param_scheduler'] = opt_param_scheduler.state_dict()
+        # >>>
+        # # Optimizer stuff.
+        # if not args.no_save_optim:
+        #     if optimizer is not None:
+        #         state_dict['optimizer'] = optimizer.state_dict()
+        #     if opt_param_scheduler is not None:
+        #         state_dict['opt_param_scheduler'] = opt_param_scheduler.state_dict()
+        # <<<
 
         # RNG states.
         if not args.no_save_rng:
             state_dict["rng_state"] = rng_state
 
         # Save.
-        checkpoint_name = get_checkpoint_name(args.save, iteration)
-        ensure_directory_exists(checkpoint_name)
-        torch.save(state_dict, checkpoint_name)
+        ensure_directory_exists(model_checkpoint_name)
+        torch.save(state_dict, model_checkpoint_name)
+
+    # >>>
+    # Save optimizer state.
+    if not args.no_save_optim \
+       and (not torch.distributed.is_initialized()
+            or mpu.get_data_parallel_rank() == 0
+            or args.use_distributed_optimizer):
+
+        # Optimizer stuff.
+        state_dict = {}
+        if optimizer is not None:
+            state_dict['optimizer'] = optimizer.state_dict()
+        if opt_param_scheduler is not None:
+            state_dict['opt_param_scheduler'] = opt_param_scheduler.state_dict()
+
+        # Save.
+        ensure_directory_exists(optim_checkpoint_name)
+        torch.save(state_dict, optim_checkpoint_name)
+
+    # >>>
+    # from lutil import pax
+    # pax({
+    #     "model_checkpoint_name" : model_checkpoint_name,
+    #     "optim_checkpoint_name" : optim_checkpoint_name,
+    #     "state_dict" : state_dict,
+    # })
+    # <<<
+    # <<<
 
     # Wait so everyone is done (necessary)
     if torch.distributed.is_initialized():
@@ -322,12 +376,14 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
     iteration, release = read_metadata(tracker_filename)
 
     # Checkpoint.
-    checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
+    model_checkpoint_name, optim_checkpoint_name = \
+        get_checkpoint_names(load_dir, iteration, release)
     print_rank_0(f' loading checkpoint from {args.load} at iteration {iteration}')
 
     # Load the checkpoint.
     try:
-        state_dict = torch.load(checkpoint_name, map_location='cpu')
+        model_state_dict = torch.load(model_checkpoint_name, map_location='cpu')
+        optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu')
     except ModuleNotFoundError:
         from megatron.fp16_deprecated import loss_scaler
         # For backward compatibility.
@@ -336,7 +392,8 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
             'megatron.fp16_deprecated.loss_scaler']
         sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
            'megatron.fp16_deprecated.loss_scaler']
-        state_dict = torch.load(checkpoint_name, map_location='cpu')
+        model_state_dict = torch.load(model_checkpoint_name, map_location='cpu')
+        optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu')
        sys.modules.pop('fp16.loss_scaler', None)
        sys.modules.pop('megatron.fp16.loss_scaler', None)
    except BaseException as e:
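For orientation only (not part of the commit): the sketch below mirrors the path construction introduced in get_checkpoint_names, using plain os.path.join and made-up rank values, to show how the single model_optim_rng.pt file is split into a per-(tensor, pipeline, data)-parallel-rank directory containing model_rng.pt and optim.pt. The checkpoints path, iteration, and ranks here are hypothetical examples; in Megatron they come from args.save and mpu.get_*_rank().

import os

# Hypothetical example values; in Megatron these come from mpu.get_*_rank().
checkpoints_path = "/checkpoints/gpt"   # stand-in for args.save
iteration = 100
tensor_rank, pipeline_rank, data_rank = 0, 1, 2

# Same formatting as the new get_checkpoint_names().
directory = 'iter_{:07d}'.format(iteration)
common_path = os.path.join(checkpoints_path, directory,
                           "mp_rank_%02d_%03d_%03d"
                           % (tensor_rank, pipeline_rank, data_rank))

model_name = os.path.join(common_path, "model_rng.pt")  # args / model / RNG state
optim_name = os.path.join(common_path, "optim.pt")      # optimizer / scheduler state

print(model_name)  # /checkpoints/gpt/iter_0000100/mp_rank_00_001_002/model_rng.pt
print(optim_name)  # /checkpoints/gpt/iter_0000100/mp_rank_00_001_002/optim.pt

The practical effect is that the data-parallel rank now appears in the checkpoint directory name and model state is separated from optimizer state, which is what lets the new save path in save_checkpoint write an optimizer file from every data-parallel rank when args.use_distributed_optimizer is set, instead of only from data-parallel rank 0.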
megatron/optimizer/distrib_optimizer.py  (view file @ 37ca7859)
@@ -295,12 +295,64 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     def get_main_grad(self, group_index):
         return self.get_main_param(group_index).grad
 
-    def load_state_dict(self):
-        raise Exception("hi.")
-    def reload_model_params(self):
-        raise Exception("hi.")
+    # def load_state_dict(self):
+    #     raise Exception("hi.")
+    # # def reload_model_params(self): # ... done in MixedPrecisionOptimizer
+    # #     raise Exception("hi.")
+    # def state_dict(self):
+    #     raise Exception("hi.")
     def state_dict(self):
-        raise Exception("hi.")
+
+        state_dict = {}
+        state_dict['optimizer'] = self.optimizer.state_dict()
+        if self.grad_scaler:
+            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
+        state_dict['params'] = \
+            [ p for g in self.optimizer.param_groups for p in g["params"] ]
+
+        # pax(0, { # ... only called on model rank 0
+        #     # "optimizer" : self.optimizer,
+        #     "state_dict" : state_dict,
+        #     "state_dict / param_groups" : state_dict["optimizer"]["param_groups"],
+        #     "optimizer / groups" : self.optimizer.param_groups,
+        #     "state_dict / params" : [ p.shape for p in state_dict["params"] ],
+        #     "optimizer / params" :
+        #     [ p.shape for g in self.optimizer.param_groups for p in g["params"] ],
+        # })
+
+        return state_dict
+
+    def load_state_dict(self, state_dict):
+
+        # Optimizer.
+        optimizer_key = 'optimizer'
+        if optimizer_key not in state_dict:
+            optimizer_key = 'optimizer_state_dict'
+            print_rank_0('***WARNING*** loading optimizer from '
+                         'an old checkpoint ...')
+        self.optimizer.load_state_dict(state_dict[optimizer_key])
+
+        pax(0, {
+            "state_dict" : state_dict,
+            "params" : state_dict["params"],
+        })
+
+        # Grad scaler.
+        if 'grad_scaler' not in state_dict:
+            print_rank_0('***WARNING*** found an old checkpoint, will not '
+                         'load grad scaler ...')
+        else:
+            if self.grad_scaler:
+                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
+            else:
+                print_rank_0('***WARNING*** fould the grad scaler in the '
+                             'checkpoint but it is None in the class. '
+                             'Skipping loading grad scaler ...')
+
+        # Copy data for the main params.
+        params_key = 'params'
+        assert params_key in state_dict, "key 'params' not in state_dict."
+        for current_group, saved_group in zip(
+                self.fp32_from_float16_groups,
+                state_dict[fp32_from_float16_params_key]):
+            for current_param, saved_param in zip(current_group, saved_group):
+                current_param.data.copy_(saved_param.data)
 
     def zero_grad(self, set_to_none=True):
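Not part of the commit: a minimal, self-contained sketch (plain PyTorch, with hypothetical parameters and a stock SGD optimizer standing in for DistributedOptimizer's inner optimizer) of the round trip the new state_dict()/load_state_dict() pair appears to be aiming for: serialize the inner optimizer's state plus a flat list of main params, then copy the saved params back element-wise on load.

import torch

# Hypothetical stand-ins for the optimizer's main (fp32) parameter groups.
params = [torch.nn.Parameter(torch.randn(4)) for _ in range(3)]
optimizer = torch.optim.SGD(params, lr=0.1)

# Save: inner optimizer state plus a flat list of main params (cf. state_dict()).
state_dict = {
    'optimizer': optimizer.state_dict(),
    'params': [p for g in optimizer.param_groups for p in g['params']],
}

# Restore into fresh params/optimizer (cf. the "Copy data for the main params"
# loop in load_state_dict()).
new_params = [torch.nn.Parameter(torch.zeros(4)) for _ in range(3)]
new_optimizer = torch.optim.SGD(new_params, lr=0.1)
new_optimizer.load_state_dict(state_dict['optimizer'])
for current_param, saved_param in zip(new_params, state_dict['params']):
    current_param.data.copy_(saved_param.data)

As committed, load_state_dict still contains a pax(0, ...) debugging call and refers to fp32_from_float16_params_key, which is not defined in this hunk, so the load path evidently is still being ported from the float16 optimizer; the commit message ("started updating checkpointing.py.") marks this as work in progress.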