OpenDAS / ColossalAI · Commits

Commit ce470ba3 (unverified)
Authored Jul 21, 2022 by ver217, committed by GitHub on Jul 21, 2022
Parent: 05fae1fd

[checkpoint] sharded optim save/load grad scaler (#1350)

Showing 1 changed file with 11 additions and 0 deletions:

colossalai/zero/sharded_optim/sharded_optim_v2.py (+11, -0)
@@ -363,7 +363,18 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 self.master_params[p].trans_state(TensorState.HOLD)
 
+    def state_dict(self):
+        optim_state_dict = super().state_dict()
+        scaler_state_dict = self.grad_scaler.state_dict()
+        optim_state_dict['scaler'] = scaler_state_dict
+        return optim_state_dict
+
     def load_state_dict(self, *args, **kwargs):
+        if 'scaler' not in args[0]:
+            self._logger.warning('Missing scaler when loading optimizer state dict', ranks=[0])
+        else:
+            scaler_state_dict = args[0].pop('scaler')
+            self.grad_scaler.load_state_dict(scaler_state_dict)
         super().load_state_dict(*args, **kwargs)
         for group in self.optim.param_groups:
             for p in group['params']:
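For context, here is a minimal sketch of how the round trip introduced by this commit could be exercised when checkpointing. The helper names save_checkpoint and load_checkpoint, the `path`/`optimizer` parameters, and the surrounding setup are hypothetical illustrations, not part of the commit; only the state_dict()/load_state_dict() behavior described in the comments comes from the change above.

    # Minimal usage sketch. Assumptions: `optimizer` is a ShardedOptimizerV2
    # instance; `save_checkpoint`/`load_checkpoint` are hypothetical helpers,
    # not part of this commit.
    import torch

    def save_checkpoint(path, optimizer):
        # state_dict() now embeds the grad scaler's state under the 'scaler'
        # key, next to the usual optimizer state from super().state_dict().
        torch.save(optimizer.state_dict(), path)

    def load_checkpoint(path, optimizer):
        state = torch.load(path)
        # load_state_dict() pops 'scaler' (when present) and restores the
        # grad scaler before delegating to the parent class; a checkpoint
        # missing the key only triggers a rank-0 warning instead of failing.
        optimizer.load_state_dict(state)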