OpenDAS / ColossalAI

Commit 795210dd
Authored Mar 03, 2022 by ver217
Committed by Frank Lee, Mar 11, 2022
add fp32 master params in sharded adam
Parent: a109225b
1 changed file with 14 additions and 34 deletions

colossalai/zero/sharded_optim/sharded_adam.py (+14 / -34)
 from enum import Enum
-from typing import Optional, Union
+from typing import Dict, Optional, Union
 import torch
 import torch.distributed as dist
...
@@ -11,6 +11,7 @@ from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.zero.sharded_model import ShardedModelV2
 from torch import Tensor
 from torch.distributed import ProcessGroup
+from torch.nn.parameter import Parameter
 from torch.optim import Optimizer
 from ._utils import has_inf_or_nan
...
@@ -39,7 +40,7 @@ class ShardedAdam(ColossalaiOptimizer):
         super().__init__(adam_optim)
         self.model: Union[nn.Module, ShardedModelV2] = sharded_model
         self.model_is_sharded = isinstance(sharded_model, ShardedModelV2)
-        self.state_device = torch.cuda.current_device() if not cpu_offload else torch.device('cpu')
+        self.device = torch.cuda.current_device() if not cpu_offload else torch.device('cpu')
         self.optim_state: OptimState = OptimState.UNSCALED
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
         self.mp_process_group = mp_process_group or gpc.get_group(ParallelMode.MODEL)
...
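Note: the hunk above only renames `self.state_device` to `self.device`; the value is still chosen from the `cpu_offload` flag. A minimal sketch of that selection logic, using a hypothetical helper name and a CPU fallback so the snippet also runs on machines without a GPU:

import torch

def pick_state_device(cpu_offload: bool) -> torch.device:
    # Optimizer state and fp32 master params live on the CPU when offloading,
    # otherwise on the current CUDA device (the CPU fallback is added for this sketch).
    if cpu_offload or not torch.cuda.is_available():
        return torch.device('cpu')
    return torch.device('cuda', torch.cuda.current_device())

print(pick_state_device(cpu_offload=True))   # prints: cpu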
@@ -51,35 +52,18 @@ class ShardedAdam(ColossalaiOptimizer):
                                               growth_interval=growth_interval,
                                               hysteresis=hysteresis,
                                               max_scale=max_scale)
-        self._found_overflow: Tensor = torch.FloatTensor([0]).to(self.state_device)
+        self._found_overflow: Tensor = torch.FloatTensor([0]).to(self.device)
+        # Store fp32 params
+        self.master_params: Dict[Parameter, Tensor] = {}
         # Early state initialization
         for group in adam_optim.param_groups:
             for p in group['params']:
-                state_shape = p.shape
                 if hasattr(p, 'ca_attr'):
                     assert p.ca_attr.is_sharded, 'ShardedAdam can be only used with sharded model'
-                    # TODO: use payload shape
-                    state_shape = p.ca_attr.payload(self.state_device)
-                state = adam_optim.state[p]
-                assert len(state) == 0, 'adam optimizer initialized'
-                state['step'] = 0
-                # Exponential moving average of gradient values
-                state['exp_avg'] = torch.zeros(state_shape, memory_format=torch.preserve_format, dtype=torch.float, device=self.state_device)
-                # Exponential moving average of squared gradient values
-                state['exp_avg_sq'] = torch.zeros(state_shape, memory_format=torch.preserve_format, dtype=torch.float, device=self.state_device)
-                if group['amsgrad']:
-                    # Maintains max of all exp. moving avg. of sq. grad. values
-                    state['max_exp_avg_sq'] = torch.zeros(state_shape, memory_format=torch.preserve_format, dtype=torch.float, device=self.state_device)
+                    self.master_params[p] = p.ca_attr.payload(self.device).to(torch.float)
+                else:
+                    self.master_params[p] = p.data.to(torch.float)

     def step(self, *args, **kwargs):
         # unscale grads if scaled
...
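For context on the __init__ change above: instead of pre-allocating Adam state tensors, the optimizer now keeps an fp32 master copy of every (possibly fp16) parameter and lets Adam build its state lazily from those copies. A minimal sketch of that bookkeeping, assuming plain fp16 Parameters with no ColossalAI `ca_attr` sharding attribute:

from typing import Dict

import torch
from torch import Tensor
from torch.nn import Parameter

# Two toy fp16 parameters standing in for a sharded model's params.
fp16_params = [Parameter(torch.randn(4, 4).half()) for _ in range(2)]

# Full-precision master copy per parameter; Adam will update these copies,
# so its state ends up in fp32 without any early initialization.
master_params: Dict[Parameter, Tensor] = {p: p.data.to(torch.float) for p in fp16_params}

for p, master in master_params.items():
    print(p.dtype, '->', master.dtype)   # torch.float16 -> torch.float32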
@@ -93,19 +77,15 @@ class ShardedAdam(ColossalaiOptimizer):
             self.zero_grad()
             return
-        # Write payload back to p.data
+        # Write master param to p.data
         for group in self.optim.param_groups:
             for p in group['params']:
-                data = p.data
-                if hasattr(p, 'ca_attr'):
-                    data = p.ca_attr.payload(self.state_device)
-                if torch.is_floating_point(data) and data.dtype != torch.float:
-                    data = data.to(torch.float)
-                p.data = data
+                p.data = self.master_params[p]
         ret = self.optim.step(*args, **kwargs)
-        # Set p.data to None
+        # Write master param to payload and set p.data to None
         for group in self.optim.param_groups:
             for p in group['params']:
+                # TODO: update payload
                 p.data = None
         return ret
...
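The reworked step() above points each parameter's data at its fp32 master tensor before calling the wrapped optimizer, then detaches it again; writing the result back into the sharded payload is left as a TODO in this commit. A rough, self-contained sketch of that flow for a single plain fp16 parameter (not the actual ShardedAdam API):

import torch
from torch.nn import Parameter
from torch.optim import Adam

p = Parameter(torch.randn(4).half())      # fp16 model parameter
master = {p: p.data.float()}              # fp32 master copy, as built in __init__
optim = Adam([p], lr=1e-2)

p.data = master[p]                        # optimizer now sees and updates fp32 data
p.grad = torch.ones_like(p.data)          # fp32 gradient matching the master param
optim.step()                              # Adam state is created lazily in fp32 here
p.data = master[p].half()                 # hand fp16 back to the model in this sketch;
                                          # the commit sets p.data = None instead
print(p.data.dtype, master[p].dtype)      # torch.float16 torch.float32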