OpenDAS / ColossalAI · Commits

Commit a45ddf2d (Unverified)
Authored Jul 08, 2022 by ver217, committed by GitHub on Jul 08, 2022

[hotfix] fix sharded optim step and clip_grad_norm (#1226)

Parent: f071b500

Showing 2 changed files with 10 additions and 4 deletions (+10 -4):
  colossalai/utils/common.py (+2 -2)
  colossalai/zero/sharded_optim/sharded_optim_v2.py (+8 -2)
colossalai/utils/common.py

@@ -195,7 +195,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
             # Make sure the grads are in fp32
             assert param.grad.dtype == torch.float, \
                 f'expected gradient to be dtype torch.float, but got {param.grad.type()}'
-            if hasattr(param, 'zero_is_sharded'):
+            if hasattr(param, 'colo_attr') and param.colo_attr.sharded_data_tensor.is_sharded:
                 has_zero_shared_param = True
             params.append(param)
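The condition above is the crux of the first fix: a parameter now counts as ZeRO-sharded only if it carries a colo_attr whose sharded data tensor is actually sharded, instead of relying on the old zero_is_sharded flag. Below is a minimal, self-contained sketch of that predicate using stand-in objects rather than the real ColossalAI types; it is illustrative only.

# Minimal sketch of the updated sharded-parameter check, using SimpleNamespace
# stand-ins instead of ColossalAI's real parameter/attribute classes.
from types import SimpleNamespace


def is_zero_sharded(param) -> bool:
    # Mirrors the new condition in clip_grad_norm_fp32: the parameter must
    # expose `colo_attr` and its sharded data tensor must report is_sharded.
    return hasattr(param, 'colo_attr') and param.colo_attr.sharded_data_tensor.is_sharded


# A ZeRO-wrapped parameter: colo_attr present and its data tensor is sharded.
sharded_param = SimpleNamespace(
    colo_attr=SimpleNamespace(sharded_data_tensor=SimpleNamespace(is_sharded=True)))
# A plain parameter: no colo_attr, so it is not treated as ZeRO-sharded.
plain_param = SimpleNamespace()

assert is_zero_sharded(sharded_param)
assert not is_zero_sharded(plain_param)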
@@ -234,7 +234,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
             if is_model_parallel_parameter(p):
                 reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS)) ** (1 / norm_type)
                 tensor_parallel_grads.append(p.grad.data / reductor)
-            elif hasattr(p, 'zero_is_sharded'):
+            elif hasattr(p, 'colo_attr') and p.colo_attr.sharded_data_tensor.is_sharded:
                 zero_sharded_grads.append(p.grad.data)
             else:
                 no_tensor_parallel_grads.append(p.grad.data)
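For context, this hunk sorts gradients into three buckets before the norm is taken: tensor-parallel grads are divided by a reductor of (tensor-parallel world size / number of partitions) ** (1 / norm_type) so replicated shards are not over-counted, while ZeRO-sharded and ordinary grads are kept as-is. The snippet below is a simplified single-process sketch of how such buckets could be folded into one p-norm; the real clip_grad_norm_fp32 additionally all-reduces partial norms across the parallel groups, and names like tp_world_size and num_partitions here are stand-ins rather than the gpc/NUM_PARTITIONS API.

# Simplified, single-process sketch of combining the three gradient buckets
# into one global norm (illustrative; the real code all-reduces partial norms).
import torch


def combined_grad_norm(tensor_parallel_grads, zero_sharded_grads,
                       other_grads, tp_world_size, num_partitions,
                       norm_type=2.0):
    # Down-weight tensor-parallel grads exactly as in the hunk above:
    # divide by (world_size / partitions) ** (1 / norm_type).
    reductor = (tp_world_size / num_partitions) ** (1.0 / norm_type)
    scaled_tp = [g / reductor for g in tensor_parallel_grads]
    all_grads = scaled_tp + list(zero_sharded_grads) + list(other_grads)
    # Sum of |g|^p over every bucket, then take the p-th root.
    total = sum(g.norm(norm_type) ** norm_type for g in all_grads)
    return total ** (1.0 / norm_type)


# Tiny usage example with dummy gradients.
print(combined_grad_norm([torch.ones(4)], [torch.ones(2)], [],
                         tp_world_size=2, num_partitions=1))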
colossalai/zero/sharded_optim/sharded_optim_v2.py

@@ -169,21 +169,27 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self.model.backward(loss)

     def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
         # This function is called except the last stage of pipeline parallel
         # It receives the scaled grad from the previous rank
         # No need to scale the grad again
         # Need to unscale when optimizing
         self.optim_state = OptimState.SCALED
         self.model.backward_by_grad(tensor, grad)

+    def clip_grad_norm(self, model: nn.Module, max_norm: float):
+        if self.optim_state == OptimState.SCALED:
+            self._prepare_grads()
+            self._unscale_grads()
+        return super().clip_grad_norm(model, max_norm)
+
     def step(self, *args, **kwargs):
-        self._prepare_grads()
-        self._maybe_move_fp32_shards()
         # unscale grads if scaled
         if self.optim_state == OptimState.SCALED:
+            self._prepare_grads()
             self._unscale_grads()
+        self._maybe_move_fp32_shards()

         found_inf = self._check_overflow()
         self.grad_scaler.update(found_inf)
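The second file carries the actual hotfix: clip_grad_norm now prepares and unscales the gradients itself when they are still scaled, and step() unscales before moving fp32 shards and checking for overflow instead of unconditionally preparing grads up front. The stripped-down sketch below only illustrates that control flow; the class, helper bodies, and print statements are placeholders, not the real ShardedOptimizerV2 internals.

# Hypothetical, stripped-down sketch of the control flow after the fix.
from enum import Enum


class OptimState(Enum):
    SCALED = 0
    UNSCALED = 1


class TinyShardedOptim:
    def __init__(self):
        self.optim_state = OptimState.SCALED

    def _prepare_grads(self):
        print('gather fp16 shard grads into param.grad')

    def _unscale_grads(self):
        print('divide grads by the loss scale')
        self.optim_state = OptimState.UNSCALED

    def clip_grad_norm(self, max_norm: float):
        # Fix: clipping must see prepared, unscaled grads, otherwise the norm
        # is computed on loss-scaled values.
        if self.optim_state == OptimState.SCALED:
            self._prepare_grads()
            self._unscale_grads()
        print(f'clip grads to max_norm={max_norm}')

    def step(self):
        # Fix: unscale first (skipped if clip_grad_norm already did it), and
        # only then move fp32 shards and check for overflow.
        if self.optim_state == OptimState.SCALED:
            self._prepare_grads()
            self._unscale_grads()
        print('move fp32 shards, check overflow, update master params')


optim = TinyShardedOptim()
optim.clip_grad_norm(max_norm=1.0)  # prepares + unscales once
optim.step()                        # does not unscale a second time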