OpenDAS / ColossalAI — commit 001ca624

impl shard optim v2 and add unit test

Authored Mar 04, 2022 by ver217; committed by Frank Lee on Mar 11, 2022
parent 74f77e31

4 changed files with 97 additions and 6 deletions (+97 −6):

    colossalai/engine/ophooks/_shard_param_ophook.py        +13 −2
    colossalai/zero/sharded_optim/__init__.py               +2 −1
    colossalai/zero/sharded_optim/sharded_optim_v2.py       +14 −3
    tests/test_zero_data_parallel/test_sharded_optim_v2.py  +68 −0
colossalai/engine/ophooks/_shard_param_ophook.py
 import torch
-from . import BaseOpHook
+import torch.distributed as dist
 from colossalai.registry import OPHOOKS
+from . import BaseOpHook


 @OPHOOKS.register_module
 class ShardParamHook(BaseOpHook):
     """
     A hook to process a sharded param before and after FWD and BWD operators execute.
     """

     def __init__(self):
         super().__init__()
...
@@ -17,25 +21,32 @@ class ShardParamHook(BaseOpHook):
         for param in module.parameters():
             assert hasattr(param, 'ca_attr')
             param.ca_attr.gather()
+            if dist.get_rank() == 0:
+                print(f'{param._name} pre fwd shape {param.ca_attr.payload("cpu").shape}')

     def post_fwd_exec(self, module: torch.nn.Module, *args):
         for param in module.parameters():
             assert hasattr(param, 'ca_attr')
             param.ca_attr.shard()
+            if dist.get_rank() == 0:
+                print(f'{param._name} post fwd shape {param.ca_attr.payload("cpu").shape}')

     def pre_bwd_exec(self, module: torch.nn.Module, input, output):
         for param in module.parameters():
             assert hasattr(param, 'ca_attr')
             param.ca_attr.gather()
+            if dist.get_rank() == 0:
+                print(f'{param._name} pre bwd shape {param.ca_attr.payload("cpu").shape}')

     def post_bwd_exec(self, module: torch.nn.Module, input):
         for param in module.parameters():
             assert hasattr(param, 'ca_attr')
             param.ca_attr.shard()
+            if dist.get_rank() == 0:
+                print(f'{param._name} post bwd shape {param.ca_attr.payload("cpu").shape}')

     def pre_iter(self):
         pass

     def post_iter(self):
         pass
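What the hook implements is the ZeRO-3 style parameter lifecycle: each sharded parameter is gathered to its full shape right before an operator runs and re-sharded right after, in both the forward and backward passes. Below is a minimal single-process sketch of that lifecycle; `ToyShardAttr` is a stand-in for ColossalAI's real `ca_attr` object and only mimics the three members the hook touches (`gather()`, `shard()`, `payload()`).

import torch


class ToyShardAttr:
    """Stand-in for ColossalAI's `ca_attr` (assumption: only the members the
    hook touches are mimicked). Sharded = this rank keeps half the elements."""

    def __init__(self, tensor: torch.Tensor):
        self._full = tensor.detach().clone()
        self.is_sharded = False

    def gather(self):
        self.is_sharded = False    # pre-op: full payload available for compute

    def shard(self):
        self.is_sharded = True     # post-op: keep only this rank's slice

    def payload(self, device):
        t = self._full.flatten()
        if self.is_sharded:
            t = t[:t.numel() // 2]
        return t.to(device)


# Drive the callbacks the way the engine would around one operator.
linear = torch.nn.Linear(4, 4)
for name, param in linear.named_parameters():
    param.ca_attr = ToyShardAttr(param.data)
    param._name = name

for param in linear.parameters():
    param.ca_attr.gather()         # what pre_fwd_exec / pre_bwd_exec do
    print(f'{param._name} gathered shape {param.ca_attr.payload("cpu").shape}')
for param in linear.parameters():
    param.ca_attr.shard()          # what post_fwd_exec / post_bwd_exec do
    print(f'{param._name} sharded shape {param.ca_attr.payload("cpu").shape}')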
colossalai/zero/sharded_optim/__init__.py
 from .sharded_optim import ShardedOptimizer
+from .sharded_optim_v2 import ShardedOptimizerV2

-__all__ = ['ShardedOptimizer']
+__all__ = ['ShardedOptimizer', 'ShardedOptimizerV2']
\ No newline at end of file
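With the new export, callers can pull both optimizers from the package root:

from colossalai.zero.sharded_optim import ShardedOptimizer, ShardedOptimizerV2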
colossalai/zero/sharded_optim/sharded_optim_v2.py
...
@@ -14,6 +14,7 @@ from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 from torch.optim import Optimizer

+from ..sharded_model._zero3_utils import free_storage
 from ._utils import has_inf_or_nan
...
@@ -62,6 +63,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
             if hasattr(p, 'ca_attr'):
                 assert p.ca_attr.is_sharded, 'ShardedAdam can be only used with sharded model'
                 self.master_params[p] = p.ca_attr.payload(self.device)
+                if dist.get_rank() == 0:
+                    print(f'load payload {p._name} {self.master_params[p].shape}')
             else:
                 self.master_params[p] = p.data.to(device=self.device)
             if torch.is_floating_point(self.master_params[p]) and self.master_params[p].dtype != torch.float:
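The hunk above follows the usual mixed-precision master-weight pattern: every parameter gets a copy on `self.device` (taken from the sharded payload when the param carries `ca_attr`), and any non-fp32 floating copy is promoted, since the inner optimizer steps in fp32 for numerical stability. A minimal sketch of that pattern in isolation (names here are illustrative, not ColossalAI API):

import torch


def build_master_params(params, device='cpu'):
    """Keep an fp32 master copy per parameter; optimizer math stays in fp32
    even when the model weights themselves are fp16/bf16."""
    master = {}
    for p in params:
        master[p] = p.data.to(device=device)
        if torch.is_floating_point(master[p]) and master[p].dtype != torch.float:
            master[p] = master[p].float()   # promote fp16/bf16 copies to fp32
    return master


model = torch.nn.Linear(4, 4).half()
masters = build_master_params(model.parameters())
assert all(m.dtype == torch.float32 for m in masters.values())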
...
@@ -84,19 +87,27 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
             for p in group['params']:
                 p.data = self.master_params[p]
         ret = self.optim.step(*args, **kwargs)
-        # Write master param to payload and set p.data to None
+        # Write master param to payload
         for group in self.optim.param_groups:
             for p in group['params']:
                 if hasattr(p, 'ca_attr'):
-                    # TODO: update payload
-                    p.data = None
+                    if dist.get_rank() == 0:
+                        print(f'write {p._name} {p.shape} orig_shape {p.ca_attr._origin_shape} \
+                            payload shape {p.ca_attr._param_payload.shape} sharded {p.ca_attr.is_sharded}')
+                    p.ca_attr.set_payload(p.data)
+                    # We cannot set p.data to None directly, so we free storage
+                    free_storage(p.data)
         return ret

     def backward(self, loss: Tensor) -> None:
         loss = self.loss_scale * loss
         self.optim_state = OptimState.SCALED
         if self.model_is_sharded:
+            if dist.get_rank() == 0:
+                print('sharded model backward')
             self.model.backward(loss)
+            if dist.get_rank() == 0:
+                print('sharded model backward done')
         else:
             super().backward(loss)
...
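Two details in `step()` above deserve a note. First, `p.data` is pointed at the fp32 master copy so the wrapped optimizer updates that copy, and the result is then written back into the sharded payload with `set_payload`. Second, the dense fp32 tensor cannot simply be dropped with `p.data = None` (a parameter's `.data` must remain a tensor), hence the imported `free_storage`. Its implementation is not part of this diff; the sketch below shows how such a helper is commonly written (an assumption modeled on ZeRO-3 utilities, not the verbatim `_zero3_utils.free_storage`):

import torch


def free_storage_sketch(data: torch.Tensor) -> None:
    """Release the memory behind `data` in place, leaving `p.data` a valid
    zero-sized tensor instead of the forbidden `p.data = None`."""
    if data.storage().size() > 0:
        # Refuse to free a view: other tensors may share this storage.
        assert data.storage_offset() == 0
        data.storage().resize_(0)


p = torch.nn.Parameter(torch.randn(4, 4))
free_storage_sketch(p.data)
assert p.data.storage().size() == 0   # memory released, reference still usable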
tests/test_zero_data_parallel/test_sharded_optim_v2.py  0 → 100644
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import copy
from functools import partial

import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import free_port
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2
from torch.optim import Adam

from common import (CONFIG, Net, check_grads, check_grads_padding, check_params, check_params_padding)


def run_step(model, optimizer, x, enable_autocast=False):
    model.train()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        y = model(x)
        loss = y.sum()
    loss = loss.float()
    if isinstance(model, ShardedModelV2):
        optimizer.backward(loss)
        for p in model.parameters():
            assert p.ca_attr.is_sharded
    else:
        loss.backward()
    optimizer.step()


def run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    model = Net(checkpoint=True).cuda()
    zero_model = copy.deepcopy(model)
    zero_model = ShardedModelV2(zero_model, process_group=gpc.get_group(ParallelMode.DATA))
    for n, p in zero_model.named_parameters():
        p._name = n
    optim = Adam(model.parameters(), lr=1e-3)
    sharded_optim = ShardedOptimizerV2(Adam(zero_model.parameters(), lr=1e-3), zero_model)
    for _ in range(2):
        x = torch.rand(2, 5).cuda()
        run_step(zero_model, sharded_optim, x, False)
        run_step(model, optim, x, False)
        if dist.get_world_size() > 1:
            check_grads_padding(model, zero_model)
            check_params_padding(model, zero_model)
        else:
            check_grads(model, zero_model)
            check_params(model, zero_model)


@pytest.mark.dist
def test_sharded_optim_v2():
    world_size = 2
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_sharded_optim_v2()
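The test's structure is the point: the same inputs go through a vanilla `Net`/`Adam` pair and through the `ShardedModelV2`/`ShardedOptimizerV2` pair, and the run only passes if gradients and parameters stay in agreement. The `check_*` helpers come from the test package's `common` module and are not shown in this diff; a hypothetical single-rank version of such a check, for illustration only:

import torch


def check_params_sketch(model, zero_model, rtol=1e-3, atol=1e-3):
    # Hypothetical stand-in for `check_params` from common.py: bring the
    # (possibly flattened) zero param back to the reference shape and compare.
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        zero_p = zero_p.detach().to(p.device).reshape(p.shape)
        assert torch.allclose(p, zero_p, rtol=rtol, atol=atol), \
            f'max diff {(p - zero_p).abs().max().item()}'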