OpenDAS / ColossalAI · Commits

Commit 6b43c789 (unverified), authored Jul 21, 2022 by ver217; committed by GitHub on Jul 21, 2022

fix zero optim backward_by_grad and save/load (#1353)

Parent: d068af81

Showing 1 changed file with 17 additions and 0 deletions.

colossalai/zero/zero_optimizer.py (+17, -0)
@@ -142,6 +142,7 @@ class ZeroOptimizer(ColossalaiOptimizer):

    def clip_grad_norm(self, model: torch.nn.Module, max_norm: float):
        if self.optim_state == OptimState.SCALED:
            self._unscale_grads()
        # TODO(ver217): fix zero clip grad norm
        return super().clip_grad_norm(model, max_norm)

    def backward(self, loss: torch.Tensor):
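For context only, a minimal self-contained sketch (not ColossalAI code; the loss scale and tensor values are invented) of why grads must be unscaled before clipping: comparing a scaled grad's norm against max_norm would clip at the wrong magnitude.

```python
import torch

loss_scale, max_norm = 1024.0, 1.0

p = torch.nn.Parameter(torch.ones(4))
p.grad = torch.full((4,), 0.002) * loss_scale  # scaled grads; true norm is 0.004

# What _unscale_grads() accomplishes: divide by the loss scale first,
# otherwise clipping would compare the inflated norm (4.096) to max_norm.
p.grad /= loss_scale
torch.nn.utils.clip_grad_norm_([p], max_norm)  # no-op here: 0.004 < 1.0
print(p.grad)  # tensor([0.0020, 0.0020, 0.0020, 0.0020])
```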
@@ -150,6 +151,11 @@ class ZeroOptimizer(ColossalaiOptimizer):

        self.module.backward(loss)

    def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
        # This function is called at every stage of pipeline parallelism except the last.
        # It receives the already-scaled grad from the previous rank,
        # so there is no need to scale the grad again;
        # the grad only needs to be unscaled when optimizing.
        self.optim_state = OptimState.SCALED
        self.module.backward_by_grad(tensor, grad)

    def _maybe_move_fp32_params(self):
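To make the comment concrete, here is a hedged single-process sketch of the scaled-grad handoff (assumptions throughout: plain tensors stand in for pipeline stages, and autograd stands in for the p2p communication a real schedule would use):

```python
import torch

loss_scale = 1024.0  # stand-in for the grad scaler's current scale

# "Last" pipeline stage: scales the loss before backward, so every grad
# that flows to upstream ranks is already multiplied by loss_scale.
stage2_in = torch.randn(4, requires_grad=True)
loss = (stage2_in ** 2).mean()
(loss * loss_scale).backward()
scaled_grad = stage2_in.grad  # this is what the previous rank receives

# "Intermediate" stage: feeds the received grad straight into backward
# (the role backward_by_grad plays here) without rescaling it. Setting
# optim_state = SCALED is what makes the optimizer unscale before step().
stage1_param = torch.randn(4, requires_grad=True)
stage1_out = stage1_param * 2.0
stage1_out.backward(gradient=scaled_grad)
print(stage1_param.grad / loss_scale)  # the unscaled grad step() would use
```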
@@ -184,7 +190,18 @@ class ZeroOptimizer(ColossalaiOptimizer):

                if isinstance(val, torch.Tensor):
                    self.chunk_manager.add_extern_static_tensor(val)

    def state_dict(self):
        optim_state_dict = super().state_dict()
        scaler_state_dict = self.grad_scaler.state_dict()
        optim_state_dict['scaler'] = scaler_state_dict
        return optim_state_dict

    def load_state_dict(self, *args, **kwargs):
        if 'scaler' not in args[0]:
            self._logger.warning('Missing scaler when loading optimizer state dict', ranks=[0])
        else:
            scaler_state_dict = args[0].pop('scaler')
            self.grad_scaler.load_state_dict(scaler_state_dict)
        super().load_state_dict(*args, **kwargs)
        for group in self.optim.param_groups:
            for p in group['params']:
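And a self-contained sketch of the save/load round trip these two methods enable: the scaler's state rides along inside the optimizer's state dict under the 'scaler' key. ToyScaledOptimizer and its scaler_state dict are invented stand-ins for ZeroOptimizer and grad_scaler; this is not ColossalAI code.

```python
import torch

class ToyScaledOptimizer(torch.optim.SGD):
    # Stands in for ZeroOptimizer: a plain SGD that carries toy scaler state.

    def __init__(self, params, lr=0.1):
        super().__init__(params, lr=lr)
        self.scaler_state = {'scale': 1024.0}  # toy stand-in for grad_scaler

    def state_dict(self):
        optim_state_dict = super().state_dict()
        optim_state_dict['scaler'] = dict(self.scaler_state)
        return optim_state_dict

    def load_state_dict(self, state_dict):
        if 'scaler' not in state_dict:
            print('Missing scaler when loading optimizer state dict')
        else:
            # Pop the scaler entry before handing the rest to the base class,
            # mirroring args[0].pop('scaler') in the commit.
            self.scaler_state = state_dict.pop('scaler')
        super().load_state_dict(state_dict)

param = torch.nn.Parameter(torch.zeros(2))
opt = ToyScaledOptimizer([param])
checkpoint = opt.state_dict()        # includes the 'scaler' entry

opt2 = ToyScaledOptimizer([param])
opt2.scaler_state['scale'] = 1.0     # pretend the scale drifted
opt2.load_state_dict(checkpoint)
print(opt2.scaler_state['scale'])    # 1024.0 -- restored with the optimizer
```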