OpenDAS / ColossalAI · Commit 9bee1191

Unverified commit 9bee1191, authored Apr 01, 2022 by ver217, committed by GitHub on Apr 01, 2022

[hotfix] fix sharded optim zero grad (#604)

* fix sharded optim zero grad
* polish comments
parent 297b8baa
Showing 1 changed file with 21 additions and 3 deletions.
colossalai/zero/sharded_optim/sharded_optim_v2.py (view file @ 9bee1191)
@@ -184,7 +184,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         if found_inf:
             self._logger.warning('found inf during ShardedOptimV2 step')
-            self.zero_grad()
+            self._zero_grad(recover_data=True)
             return
         self._prepare_data()
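The one-line change above routes the overflow path through the internal _zero_grad(recover_data=True), so that besides clearing gradients, the fp16 shards that were reused as gradient storage get restored before the step is abandoned. For reference, a minimal plain-PyTorch sketch of the overflow-skip pattern itself (illustrative only, not the sharded implementation):

import torch

# Minimal sketch (plain PyTorch, not ColossalAI): the overflow-skip pattern that
# the found_inf branch implements. If any gradient is non-finite, drop this step
# and clear the gradients instead of applying a corrupted update.
model = torch.nn.Linear(4, 4)
optim = torch.optim.SGD(model.parameters(), lr=0.1)

model(torch.randn(2, 4)).sum().backward()

found_inf = any(not torch.isfinite(p.grad).all()
                for p in model.parameters() if p.grad is not None)
if found_inf:
    optim.zero_grad(set_to_none=True)   # skip the update, discard the bad grads
else:
    optim.step()
    optim.zero_grad(set_to_none=True)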
The second hunk adds the new zero_grad / _zero_grad pair (shown here as the resulting code after the change):

@@ -246,13 +246,31 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
        self.optim_state = OptimState.UNSCALED

    def zero_grad(self, *args, **kwargs):
        self._zero_grad()

    def _zero_grad(self, recover_data: bool = False):
        """Zero grad and maybe recover fp16 params.

        When `reuse_fp16_shard` is enabled, p.colo_attr.sharded_data_tensor stores grad here.
        We have to recover them from fp32 params.

        Args:
            recover_data (bool, optional): Whether to recover fp16 param from fp32 param. Defaults to False.
        """
        # We must set grad to None, because we judge whether local grad
        # accumulation is enabled by checking whether grad is None.
        # The grad here is sharded, but the next backward pass will create
        # a full grad first, which would lead to wrong accumulation.
        self.optim.zero_grad(set_to_none=True)
        for group in self.optim.param_groups:
            for p in group['params']:
                # p.colo_attr.sharded_data_tensor stores grad now;
                # we have to recover the fp16 param
                reuse_fp16_shard = p.colo_attr.saved_grad.data_ptr() == p.colo_attr.sharded_data_tensor.data_ptr()
                p.colo_attr.saved_grad.set_null()
                if recover_data and reuse_fp16_shard:
                    p.colo_attr.sharded_data_tensor.reset_payload(
                        colo_model_tensor_clone(self.master_params[p].payload.half(), torch.cuda.current_device()))

    def sync_grad(self):
        pass
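The comment block above is the core of the hotfix: gradients must be set to None rather than zeroed in place, because the sharded optimizer later decides whether local gradient accumulation is active by checking whether p.grad is None; a stale zeroed shard would be mistaken for an accumulation buffer and corrupt the next full backward pass. A minimal plain-PyTorch sketch of that distinction (illustrative, not ColossalAI code):

import torch

model = torch.nn.Linear(4, 4)
optim = torch.optim.SGD(model.parameters(), lr=0.1)

model(torch.randn(2, 4)).sum().backward()
optim.step()

# Zeroing in place keeps the old grad buffers allocated...
optim.zero_grad(set_to_none=False)
assert all(p.grad is not None for p in model.parameters())

# ...while set_to_none=True frees them, so "p.grad is None" reliably means
# "no accumulation in progress" before the next backward pass.
model(torch.randn(2, 4)).sum().backward()
optim.zero_grad(set_to_none=True)
assert all(p.grad is None for p in model.parameters())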
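The recover_data branch is the other half of the fix: when reuse_fp16_shard is enabled, the fp16 shard's storage was reused to hold gradient data, so after a skipped step the fp16 payload has to be rebuilt from the fp32 master copy. A hedged sketch of that idea in plain PyTorch (the names fp16_shard and fp32_master are illustrative, not ColossalAI APIs):

import torch

def recover_fp16_shard(fp16_shard: torch.Tensor, fp32_master: torch.Tensor) -> None:
    # The fp16 shard may currently hold gradient data; overwrite it with a
    # half-precision copy of the authoritative fp32 master parameters.
    fp16_shard.copy_(fp32_master.detach().half())

fp32_master = torch.randn(8, dtype=torch.float32)   # master weights kept in fp32
fp16_shard = torch.empty(8, dtype=torch.float16)    # shard storage, clobbered by grads
recover_fp16_shard(fp16_shard, fp32_master)
assert torch.allclose(fp16_shard.float(), fp32_master, atol=1e-2)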