OpenDAS / ColossalAI · Commit 70814dc2
"...git@developer.sourcefind.cn:OpenDAS/dcnv3.git" did not exist on "80e8c1d3d29e7c1e8f3bd9eb02b3c99e874bdb9f"
fix master params dtype

Authored Mar 03, 2022 by ver217; committed by Frank Lee on Mar 11, 2022.
Parent commit: 795210dd
Showing 1 changed file with 9 additions and 9 deletions:

colossalai/zero/sharded_optim/sharded_adam.py  (+9 / -9)
@@ -26,7 +26,7 @@ class ShardedAdam(ColossalaiOptimizer):
     def __init__(self,
                  adam_optim: Optimizer,
-                 sharded_model: nn.Module,
+                 sharded_model: Union[nn.Module, ShardedModelV2],
                  cpu_offload: bool = False,
                  initial_scale: float = 2**32,
                  min_scale: float = 1,
@@ -61,9 +61,11 @@ class ShardedAdam(ColossalaiOptimizer):
             for p in group['params']:
                 if hasattr(p, 'ca_attr'):
                     assert p.ca_attr.is_sharded, 'ShardedAdam can be only used with sharded model'
-                    self.master_params[p] = p.ca_attr.payload(self.device).to(torch.float)
+                    self.master_params[p] = p.ca_attr.payload(self.device)
                 else:
-                    self.master_params[p] = p.data.to(torch.float)
+                    self.master_params[p] = p.data.to(device=self.device)
+                if torch.is_floating_point(self.master_params[p]) and self.master_params[p].dtype != torch.float:
+                    self.master_params[p] = self.master_params[p].to(torch.float)

     def step(self, *args, **kwargs):
         # unscale grads if scaled
@@ -85,8 +87,9 @@ class ShardedAdam(ColossalaiOptimizer):
         # Write master param to payload and set p.data to None
         for group in self.optim.param_groups:
             for p in group['params']:
-                # TODO: update payload
-                p.data = None
+                if hasattr(p, 'ca_attr'):
+                    # TODO: update payload
+                    p.data = None
         return ret

     def backward(self, loss: Tensor) -> None:
@@ -129,10 +132,7 @@ class ShardedAdam(ColossalaiOptimizer):
         # all-reduce over model parallel group
         dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self.mp_process_group)

-        if self._found_overflow.item() > 0:
-            return True
-        else:
-            return False
+        return self._found_overflow.item() > 0

     def _unscale_grads(self):
         assert self.optim_state == OptimState.SCALED
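For readers who want the effect of this change in isolation, the following is a minimal, standalone sketch (not the committed code) of the new master-parameter handling from the second hunk: the master copy is now moved to the target device without an unconditional cast, and is promoted to fp32 only when it is a floating-point tensor of a different dtype, so non-floating-point parameters keep their original dtype. The params dict, tensor names, and device below are made up for illustration.

import torch

# Hypothetical stand-ins for what the optimizer would pull from
# p.ca_attr.payload(self.device) or p.data in the real class.
params = {
    'fp16_weight': torch.randn(4, dtype=torch.half),   # mixed-precision weight
    'fp32_weight': torch.randn(4, dtype=torch.float),  # already full precision
    'int_state': torch.arange(4, dtype=torch.long),    # non-floating-point tensor
}

device = torch.device('cpu')  # assumed here; the real code uses self.device
master_params = {}

for name, p in params.items():
    # Move to the target device without forcing a dtype (the old code cast
    # everything to torch.float at this point)...
    master = p.to(device=device)
    # ...then cast to fp32 only if it is floating point and not already fp32.
    if torch.is_floating_point(master) and master.dtype != torch.float:
        master = master.to(torch.float)
    master_params[name] = master

for name, t in master_params.items():
    print(name, t.dtype)
# expected: fp16_weight torch.float32, fp32_weight torch.float32, int_state torch.int64

Previously, .to(torch.float) was applied unconditionally, so integer or boolean parameter data would also have been converted to float; the added guard avoids that, which appears to be the dtype fix the commit title refers to.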