OpenDAS / ColossalAI · Commits

Commit 237d08e7 (unverified)
[zero] hybrid cpu adam (#445)
Authored Mar 17, 2022 by Jiarui Fang; committed via GitHub on Mar 17, 2022
Parent: b72b8445
Showing 2 changed files with 81 additions and 55 deletions:

    colossalai/nn/optimizer/cpu_adam.py    +72 -45
    colossalai/zero/__init__.py            +9  -10
colossalai/nn/optimizer/cpu_adam.py (view file @ 237d08e7)

 import torch
+import math


 class CPUAdam(torch.optim.Optimizer):
...
@@ -8,19 +9,18 @@ class CPUAdam(torch.optim.Optimizer):
                  model_params,
                  lr=1e-3,
                  bias_correction=True,
                  betas=(0.9, 0.999),
                  eps=1e-8,
                  weight_decay=0,
                  adamw_mode=True,
                  loss_scale=-1,
                  simd_log=False):
+        """
+        An implementation equivalent to `torch.optim.Adam`.
+        The difference is that model_params are sharded parameters belonging to a ShardedModelV2 instance.
+        The sharded params of model_params can reside on both CPU and CUDA.
+        """
         default_args = dict(lr=lr,
                             betas=betas,
                             eps=eps,
                             weight_decay=weight_decay,
                             bias_correction=bias_correction)
         super(CPUAdam, self).__init__(model_params, default_args)
         self.opt_id = CPUAdam.optimizer_id
         CPUAdam.optimizer_id = CPUAdam.optimizer_id + 1
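For context, the constructor registers each optimizer instance with the C++ backend under a unique opt_id. A minimal usage sketch follows; it assumes the cpu_adam C++ extension has been built from source, and the plain nn.Linear is only a stand-in for the sharded parameters of a ShardedModelV2:

import torch
import torch.nn as nn
from colossalai.nn.optimizer.cpu_adam import CPUAdam

model = nn.Linear(64, 64)                      # parameters live on CPU here
optim = CPUAdam(model.parameters(), lr=1e-3, weight_decay=0)

loss = model(torch.randn(8, 64)).sum()
loss.backward()
optim.step()    # each param is stepped on whichever device it resides on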
...
@@ -31,18 +31,45 @@ class CPUAdam(torch.optim.Optimizer):
         except ImportError:
             raise ImportError('Please install colossalai from source code to use CPUAdam')
         self.cpu_adam_op = cpu_adam
         self.cpu_adam_op.create_adam(self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode, simd_log)

     def __del__(self):
         self.cpu_adam_op.destroy_adam(self.opt_id)

+    def torch_adam_update(self,
+                          data,
+                          grad,
+                          exp_avg,
+                          exp_avg_sq,
+                          lr,
+                          beta1,
+                          beta2,
+                          eps,
+                          weight_decay,
+                          bias_correction1,
+                          bias_correction2,
+                          loss_scale,
+                          use_adamw=False):
+        if loss_scale is not None:
+            grad.div_(loss_scale)
+        if weight_decay != 0:
+            if use_adamw:
+                data.mul_(1 - lr * weight_decay)
+            else:
+                grad = grad.add(data, alpha=weight_decay)
+
+        # Decay the first and second moment running average coefficient
+        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+        # TODO(jiaruifang) does not support amsgrad
+        denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
+
+        step_size = lr / bias_correction1
+
+        data.addcdiv_(exp_avg, denom, value=-step_size)

     @torch.no_grad()
     def step(self, closure=None):
...
@@ -51,47 +78,47 @@ class CPUAdam(torch.optim.Optimizer):
             with torch.enable_grad():
                 loss = closure()

-        # intended device for step
-        device = torch.device('cpu')
-        for _, group in enumerate(self.param_groups):
-            for _, p in enumerate(group['params']):
+        for group_id, group in enumerate(self.param_groups):
+            for param_id, p in enumerate(group['params']):
                 if p.grad is None:
                     continue
-                assert p.device == device, f"CPUAdam param is on {p.device} and must be 'cpu', make " \
-                    "sure the cpu_offload is Ture"

                 state = self.state[p]

+                target_device = p.device
                 if len(state) == 0:
                     state['step'] = 0
                     # gradient momentums
-                    state['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float, device=device)
+                    state['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float, device=target_device)
                     # gradient variances
-                    state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float, device=device)
+                    state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float, device=target_device)
                     # memory_format=torch.preserve_format)

                 state['step'] += 1
                 beta1, beta2 = group['betas']

-                self.cpu_adam_op.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2,
-                                             group['eps'], group['weight_decay'], group['bias_correction'],
-                                             p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'],
-                                             self.loss_scale)
+                if target_device.type == 'cpu':
+                    assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu"
+                    assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg_sq should stay on cpu"
+                    self.cpu_adam_op.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2,
+                                                 group['eps'], group['weight_decay'], group['bias_correction'],
+                                                 p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'],
+                                                 self.loss_scale)
+                elif target_device.type == 'cuda':
+                    # FIXME() prepare grad on cuda
+                    if p.grad.device.type == 'cpu':
+                        p.grad = p.grad.to(target_device)
+
+                    assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
+                    assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg_sq should stay on cuda"
+                    bias_correction1 = 1 - beta1 ** state['step']
+                    bias_correction2 = 1 - beta2 ** state['step']
+
+                    # adam on cuda
+                    self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'],
+                                           group['lr'], beta1, beta2, group['eps'], group['weight_decay'],
+                                           bias_correction1, bias_correction2, self.loss_scale)
+                else:
+                    raise RuntimeError
         return loss
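Taken together, step() now chooses an execution path per parameter: CPU-resident params go through the fused cpu_adam C++ kernel, CUDA-resident params through the pure-PyTorch torch_adam_update. That per-device dispatch is what makes the optimizer "hybrid". The pattern in isolation, as an illustrative sketch with a placeholder update in place of the two real backends:

import torch
import torch.nn as nn

def fake_update(p):    # hypothetical helper, illustration only
    p.data.add_(p.grad, alpha=-1e-3)

params = [nn.Parameter(torch.randn(4))]                          # CPU-resident
if torch.cuda.is_available():
    params.append(nn.Parameter(torch.randn(4, device='cuda')))   # GPU-resident

for p in params:
    p.grad = torch.ones_like(p)
    if p.device.type == 'cpu':
        fake_update(p)    # real code: fused cpu_adam.adam_update kernel
    elif p.device.type == 'cuda':
        fake_update(p)    # real code: torch_adam_update on GPU tensors
    else:
        raise RuntimeError(f"unsupported device: {p.device}")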
colossalai/zero/__init__.py (view file @ 237d08e7)

-from asyncio.log import logger
-from distutils.command.config import config
-from colossalai.context.parallel_mode import ParallelMode
-from typing import Callable, Type
-from colossalai.core import global_context as gpc
+from typing import Callable
 import torch
 import torch.nn as nn
 from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import TensorShardStrategy
 from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
 from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
 from torch.optim import Optimizer
 from .sharded_model import ShardedModel
 from .sharded_optim import ShardedOptimizer


 def convert_to_zero_v2(model_builder: Callable, model_config, optimizer_config) -> (ShardedModelV2, ShardedOptimizerV2):
     """
     ...
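The body of convert_to_zero_v2 is collapsed in this view. Judging from the signature alone, a call would look roughly like the following; the empty config dicts are placeholders, since the real model_config and optimizer_config schemas are not visible in this diff:

import torch.nn as nn
from colossalai.zero import convert_to_zero_v2

def build_model() -> nn.Module:
    # hypothetical model factory passed as `model_builder`
    return nn.Linear(16, 16)

zero_model, zero_optim = convert_to_zero_v2(build_model,
                                            model_config={},
                                            optimizer_config={})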