OpenDAS / Megatron-LM · Commits

Commit 9f64f5f4, authored Mar 25, 2022 by Lawrence McAfee

    working: checkpoint save/load.

Parent: 82491e4b
Showing 2 changed files with 51 additions and 43 deletions:

    megatron/optimizer/distrib_optimizer.py   +47 -43
    megatron/training.py                       +4  -0
megatron/optimizer/distrib_optimizer.py @ 9f64f5f4

...
@@ -201,6 +201,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
            for model_param in group_range["params"]:
                assert model_param.requires_grad
                model_index, dtype = param_gbuf_map[model_param]
                gbuf_range = model_gbuf_ranges[model_index][dtype]
                param_range = gbuf_range["param_map"][model_param]["param"]
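
For orientation, the lookup chain in this hunk walks from a model parameter to this rank's slice of the distributed gradient buffer. Below is a toy sketch of the structures being consulted; the concrete values and the plain-tuple range are invented for illustration, and the real entries built by DistributedOptimizer carry more metadata.

import torch

# Hypothetical, simplified stand-ins for the maps used above.
model_param = torch.nn.Parameter(torch.empty(1024))

param_gbuf_map = {model_param: (0, torch.float16)}   # param -> (model index, grad-buffer dtype)
model_gbuf_ranges = [{                               # one dict per model chunk
    torch.float16: {
        "param_map": {
            model_param: {"param": (128, 256)},      # this rank's shard of the param (toy values)
        },
    },
}]

model_index, dtype = param_gbuf_map[model_param]
gbuf_range = model_gbuf_ranges[model_index][dtype]
param_range = gbuf_range["param_map"][model_param]["param"]
assert param_range == (128, 256)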
...
@@ -310,50 +312,44 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         return None

-    # >>>
-    # def state_dict(self):
-    #     state_dict = {}
-    #     state_dict['optimizer'] = self.optimizer.state_dict()
-    #     if self.grad_scaler:
-    #         state_dict['grad_scaler'] = self.grad_scaler.state_dict()
-    #     state_dict['groups'] = [g['params'] for g in self.optimizer.param_groups]
-    #     return state_dict
     def state_dict(self):
-        raise Exception("fix me.")
-    # <<<
-
-    # >>>
-    # def load_state_dict(self, state_dict):
-    #     # Optimizer.
-    #     optimizer_key = 'optimizer'
-    #     if optimizer_key not in state_dict:
-    #         optimizer_key = 'optimizer_state_dict'
-    #         print_rank_0('***WARNING*** loading optimizer from '
-    #                      'an old checkpoint ...')
-    #     self.optimizer.load_state_dict(state_dict[optimizer_key])
-    #     # Grad scaler.
-    #     if 'grad_scaler' not in state_dict:
-    #         print_rank_0('***WARNING*** found an old checkpoint, will not '
-    #                      'load grad scaler ...')
-    #     else:
-    #         if self.grad_scaler:
-    #             self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
-    #         else:
-    #             print_rank_0('***WARNING*** fould the grad scaler in the '
-    #                          'checkpoint but it is None in the class. '
-    #                          'Skipping loading grad scaler ...')
-    #     # Copy data for the main params.
-    #     current_groups = [ g["params"] for g in self.optimizer.param_groups ]
-    #     assert "groups" in state_dict, "key 'groups' not in state_dict."
-    #     for current_group, saved_group in zip(current_groups, state_dict["groups"]):
-    #         for current_param, saved_param in zip(current_group, saved_group):
-    #             current_param.data.copy_(saved_param.data)
+        state_dict = {}
+        state_dict['optimizer'] = self.optimizer.state_dict()
+        if self.grad_scaler:
+            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
+        state_dict['shard_fp32_from_float16_groups'] = \
+            self.shard_fp32_from_float16_groups
+        return state_dict
+
     def load_state_dict(self, state_dict):
-        raise Exception("hi.")
-    # <<<
+        # Optimizer.
+        optimizer_key = 'optimizer'
+        if optimizer_key not in state_dict:
+            optimizer_key = 'optimizer_state_dict'
+            print_rank_0('***WARNING*** loading optimizer from '
+                         'an old checkpoint ...')
+        self.optimizer.load_state_dict(state_dict[optimizer_key])
+
+        # Grad scaler.
+        if 'grad_scaler' not in state_dict:
+            print_rank_0('***WARNING*** found an old checkpoint, will not '
+                         'load grad scaler ...')
+        else:
+            if self.grad_scaler:
+                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
+            else:
+                print_rank_0('***WARNING*** fould the grad scaler in the '
+                             'checkpoint but it is None in the class. '
+                             'Skipping loading grad scaler ...')
+
+        # Copy data for the main params.
+        for current_group, saved_group in zip(
+                self.shard_fp32_from_float16_groups,
+                state_dict["shard_fp32_from_float16_groups"]):
+            for current_param, saved_param in zip(current_group, saved_group):
+                current_param.data.copy_(saved_param.data)

     def zero_grad(self, set_to_none=True):
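
With the stubs replaced, the distributed optimizer can round-trip its state through a checkpoint. A minimal usage sketch, assuming an already-constructed DistributedOptimizer instance and a hand-chosen path; the helper names here are hypothetical, and Megatron normally drives this from megatron/checkpointing.py rather than by hand.

import torch

def save_distrib_optimizer_state(optimizer, path):
    # state_dict() now returns the inner optimizer state, the grad scaler
    # (when present), and the fp32 copies of the float16 param shards.
    torch.save(optimizer.state_dict(), path)

def load_distrib_optimizer_state(optimizer, path):
    # load_state_dict() restores the inner optimizer and grad scaler, then
    # copies the saved fp32 shards back into shard_fp32_from_float16_groups.
    optimizer.load_state_dict(torch.load(path))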
...
@@ -362,11 +358,19 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         fp32_from_float16_groups as a memory optimization to reduce
         fragmentation; in the case of set_to_none==True, the space
         used by this field can be safely deallocated at this point."""
+        # >>>
+        # params = [ p for g in self.shard_fp32_groups for p in g ]
+        # pax(0, {
+        #     "shard_fp32_groups" : self.shard_fp32_groups,
+        #     "params" : params,
+        #     "grads" : [ p.grad for p in params ],
+        # })
+        # <<<
         for groups in (
                 self.full_float16_groups,
                 self.full_fp32_groups,
                 self.shard_float16_groups, # grad empty/unused here?
-                self.shard_fp32_groups,
+                self.shard_fp32_groups, # throws grad-access warning
                 self.shard_fp32_from_float16_groups):
             for group in groups:
                 _zero_grad_group_helper(group, set_to_none)
...
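
zero_grad() above delegates the per-group work to _zero_grad_group_helper, which is defined elsewhere in megatron/optimizer. A rough sketch of what such a helper does, mirroring torch.optim.Optimizer.zero_grad; the actual implementation may differ in detail:

def _zero_grad_group_helper(group, set_to_none):
    """Zero (or drop) the .grad of every parameter in a group."""
    for param in group:
        if param.grad is not None:
            if set_to_none:
                param.grad = None
            else:
                if param.grad.grad_fn is not None:
                    param.grad.detach_()
                else:
                    param.grad.requires_grad_(False)
                param.grad.zero_()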
megatron/training.py @ 9f64f5f4

...
@@ -52,6 +52,10 @@ from megatron.schedules import get_forward_backward_func
 from megatron.utils import report_memory
 from megatron.model.vision.knn_monitor import compute_feature_bank
+# >>>
+from lutil import pax, tp, print_seq
+# <<<


 def print_datetime(string):
     """Note that this call will sync across all ranks."""
...
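
lutil is an out-of-tree debugging helper used on this branch (see the commented-out pax(0, {...}) dump in distrib_optimizer.py above); it is not part of the repository. If it is unavailable, a stand-in with the same call shape could look like the sketch below. Only pax is sketched, and its print-then-exit behavior is an assumption based on how the call is used.

import sys
import torch

def pax(rank, values):
    # Assumed behavior: dump a dict of debug values, then stop the run.
    print(f"[pax {rank}]")
    for key, value in values.items():
        shown = tuple(value.shape) if torch.is_tensor(value) else value
        print(f"  {key} : {shown}")
    sys.exit(0)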