OpenDAS / ColossalAI / Commits / 44e4891f

Commit 44e4891f
Authored Mar 10, 2022 by Jiarui Fang; committed by Frank Lee, Mar 11, 2022.

[zero] able to place params on cpu after zero init context (#365)

* place params on cpu after zero init context
* polish code

Parent: b66f3b99

Showing 6 changed files with 58 additions and 20 deletions (+58 −20).
colossalai/engine/ophooks/zero_hook.py                  +7  −1
colossalai/zero/init_ctx/init_context.py                +18 −10
colossalai/zero/shard_utils/tensor_shard_strategy.py    +2  −1
colossalai/zero/sharded_param/sharded_tensor.py         +7  −0
tests/test_zero_data_parallel/test_init_context.py      +16 −7
tests/test_zero_data_parallel/test_shard_model_v2.py    +8  −1
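Taken together, these changes let ZeroInitContext place parameter data on an explicit target_device (for example the CPU) when the context exits, instead of always pushing it to the current CUDA device. A minimal usage sketch follows; it assumes a ColossalAI distributed environment has already been launched (as the tests below do via colossalai.launch), and the import paths and the build_model helper are illustrative assumptions, not taken from this diff.

    import torch
    from colossalai.zero.init_ctx import ZeroInitContext          # assumed import path
    from colossalai.zero.shard_utils import TensorShardStrategy   # assumed import path

    # Keep sharded fp16 parameters on CPU after module construction, so GPU
    # memory stays free until the forward/backward hooks move them over.
    with ZeroInitContext(convert_fp16=True,
                         target_device=torch.device('cpu'),
                         shard_strategy=TensorShardStrategy(),
                         shard_param=True):
        model = build_model()   # build_model is a placeholder for any nn.Module constructor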
colossalai/engine/ophooks/zero_hook.py  (view file @ 44e4891f)

 import torch
 from colossalai.registry import OPHOOKS
 from colossalai.zero.shard_utils import BaseShardStrategy
+from colossalai.utils import get_current_device
 from ._base_ophook import BaseOpHook

@@ -14,11 +14,15 @@ class ZeroHook(BaseOpHook):

     def __init__(self, shard_strategy: BaseShardStrategy):
         super().__init__()
         self.shard_strategy = shard_strategy
+        # NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
+        self.computing_device = torch.device(f'cuda:{get_current_device()}')

     def pre_fwd_exec(self, module: torch.nn.Module, *args):
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
             self.shard_strategy.gather([param.col_attr.data])
+            if param.col_attr.data.device != self.computing_device:
+                param.col_attr.data.to(self.computing_device)
             param.data = param.col_attr.data.payload

     def post_fwd_exec(self, module: torch.nn.Module, *args):

@@ -31,6 +35,8 @@ class ZeroHook(BaseOpHook):
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
             self.shard_strategy.gather([param.col_attr.data])
+            if param.col_attr.data.device != self.computing_device:
+                param.col_attr.data.to(self.computing_device)
             param.data = param.col_attr.data.payload
             # Store local accumulated grad shard
             if param.grad is not None:
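Note that param.col_attr.data.to(self.computing_device) is used as a bare statement: it relies on the ShardedTensor.to method added in colossalai/zero/sharded_param/sharded_tensor.py below, which moves the wrapped payload in place rather than returning a new object. A toy, plain-PyTorch sketch of the pre-forward pattern the hook follows (the helper name and objects here are stand-ins, not ColossalAI APIs):

    import torch

    # the hook always computes forward/backward on the GPU when one is available
    computing_device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    def point_param_at_payload(param: torch.nn.Parameter, payload: torch.Tensor) -> None:
        # mirrors the hook: after gathering the shard, move the full payload to the
        # computing device if needed, then let param.data reference it for this op
        if payload.device != computing_device:
            payload = payload.to(computing_device)
        param.data = payload

    p = torch.nn.Parameter(torch.empty(0))
    point_param_at_payload(p, torch.randn(4, 4))   # e.g. a freshly gathered payload
    print(p.data.device)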
colossalai/zero/init_ctx/init_context.py  (view file @ 44e4891f)

 import functools
 import torch
-from colossalai.utils.cuda import get_current_device
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_param import ShardedParamV2
 from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER

@@ -82,6 +81,12 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         1. Convert the model to fp16.
         2. The parameters of the module are adapted to type ShardedParameter.
         3. Shard the param and grad according to flags.
+
+        target_device: the device where param data is placed after exiting the context
+        shard_strategy: shard strategy instance
+        shard_param: is param sharded after exiting the context
+        shard_grad: is grad sharded after exiting the context
+        rm_torch_payload_on_the_fly:
+            True: remove tensor payload on param.data after module init finished.
+            False: remove tensor payload on param.data after the context exits.

@@ -91,18 +96,19 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     def __init__(self,
                  convert_fp16: bool,
-                 convert_cuda: bool,
+                 target_device: torch.device,
                  shard_strategy: BaseShardStrategy,
                  shard_param: bool = False,
                  shard_grad: bool = False,
                  rm_torch_payload_on_the_fly=False):
         super().__init__()
         self.convert_fp16 = convert_fp16
-        self.convert_cuda = convert_cuda
+        self.target_device = target_device
         self.shard_param = shard_param
         self.shard_grad = shard_grad
         self.shard_strategy = shard_strategy
         self.rm_torch_payload_on_the_fly = rm_torch_payload_on_the_fly
+        # FIXME(jiaruifang) now setting it to True is invalid.
+        self.rm_torch_payload_on_the_fly = False
         self.initialized_param_list = []

     def _post_context_exec(self):

@@ -123,17 +129,19 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             if hasattr(param, 'col_attr'):
                 continue
-            if self.convert_cuda:
-                target_device = get_current_device()
-            else:
-                target_device = param.data.device
+            target_device = self.target_device

-            # convert to fp16 and cuda if necessary
+            # convert to fp16 if necessary
             if self.convert_fp16:
-                param.data = param.data.to(torch.half).to(target_device)
+                param.data = param.data.to(torch.half)
                 if param.grad is not None:
-                    param.grad = param.grad.to(torch.half).to(target_device)
+                    param.grad = param.grad.to(torch.half)

+            # move torch parameters to the target device
+            param.data = param.data.to(target_device)
+            if param.grad is not None:
+                param.grad = param.grad.to(target_device)

             param.col_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)

             self.initialized_param_list.append(param)
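In isolation, the new per-parameter behavior is: optionally cast data and grad to fp16, then move both to the configured target device, rather than deriving the device from a convert_cuda flag. A self-contained sketch of that conversion, with convert_param as an illustrative helper name (not part of this diff):

    import torch

    def convert_param(param: torch.nn.Parameter,
                      target_device: torch.device,
                      convert_fp16: bool = True) -> None:
        if convert_fp16:
            param.data = param.data.to(torch.half)
            if param.grad is not None:
                param.grad = param.grad.to(torch.half)
        # move torch parameters to the target device
        param.data = param.data.to(target_device)
        if param.grad is not None:
            param.grad = param.grad.to(target_device)

    p = torch.nn.Parameter(torch.randn(8, 8))
    convert_param(p, torch.device('cpu'))
    assert p.data.dtype == torch.half and p.data.device.type == 'cpu'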
colossalai/zero/shard_utils/tensor_shard_strategy.py  (view file @ 44e4891f)

@@ -30,7 +30,7 @@ class TensorShardStrategy(BaseShardStrategy):

     def _gather_tensor(self, t: ShardedTensor):
         if not t.is_sharded:
             return
+        target_device = t.device
         buffer_list = []
         payload_numel = t.payload.numel()
         for i in range(self.world_size):

@@ -45,4 +45,5 @@ class TensorShardStrategy(BaseShardStrategy):
                                  async_op=False)
         gathered_payload = torch.narrow(torch.cat(buffer_list), 0, 0, t.origin_numel).reshape(t.origin_shape)
         t.reset_payload(gathered_payload)
+        t.to(target_device)
         t.is_sharded = False
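The unshard step itself is a plain concatenate-trim-reshape, and the new lines simply remember the device the tensor lived on before the gather and move the restored payload back there. A standalone, single-process illustration of the torch.cat / torch.narrow / reshape pattern (the shards are simulated in a list; no process group is involved):

    import torch

    origin = torch.arange(10, dtype=torch.float32).reshape(2, 5)
    origin_numel, origin_shape = origin.numel(), origin.shape

    # two padded shards, as an all_gather over world_size=2 would produce
    flat = torch.cat([origin.flatten(), torch.zeros(2)])   # pad 10 elements up to 12
    buffer_list = [flat[:6], flat[6:]]

    # concatenate, drop the padding, restore the original shape
    gathered_payload = torch.narrow(torch.cat(buffer_list), 0, 0, origin_numel).reshape(origin_shape)
    assert torch.equal(gathered_payload, origin)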
colossalai/zero/sharded_param/sharded_tensor.py  (view file @ 44e4891f)

@@ -47,11 +47,18 @@ class ShardedTensor(object):
         del self._payload
         self._payload = tensor

+    @property
+    def device(self):
+        return self._payload.device

     @property
     def dtype(self):
         assert self._payload.dtype == self._origin_dtype
         return self._origin_dtype

+    def to(self, device: torch.device):
+        self._payload = self._payload.to(device)

     @property
     def shape(self):
         return self._payload.shape
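One point worth noting: unlike torch.Tensor.to, the to method added here mutates the wrapped payload in place and returns nothing, which is why TensorShardStrategy and ZeroHook above invoke it as a statement. A minimal stand-in class (an illustration, not the ColossalAI implementation) showing those semantics:

    import torch

    class PayloadHolder:
        def __init__(self, payload: torch.Tensor):
            self._payload = payload

        @property
        def device(self) -> torch.device:
            # the holder's device is simply the device of the wrapped payload
            return self._payload.device

        def to(self, device: torch.device) -> None:
            # move the payload in place; nothing is returned
            self._payload = self._payload.to(device)

    holder = PayloadHolder(torch.randn(4))
    holder.to(torch.device('cpu'))   # used as a statement, like t.to(target_device) above
    print(holder.device)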
tests/test_zero_data_parallel/test_init_context.py  (view file @ 44e4891f)

@@ -4,6 +4,7 @@
 from functools import partial

 import colossalai
+from colossalai.utils.cuda import get_current_device
 import pytest
 import torch
 import torch.multiprocessing as mp

@@ -17,13 +18,13 @@ from common import CONFIG
 from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER


-def run_dist(rank, world_size, port):
+def run_dist(rank, world_size, port, init_device):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

     for get_components_func in non_distributed_component_funcs:
         model_builder, _, _, _, _ = get_components_func()
-        with ZeroInitContext(convert_fp16=True, convert_cuda=True, shard_strategy=TensorShardStrategy(), shard_param=True):
+        with ZeroInitContext(convert_fp16=True, target_device=init_device, shard_strategy=TensorShardStrategy(), shard_param=True):
             model = model_builder(checkpoint=True)

@@ -32,18 +33,26 @@ def run_dist(rank, world_size, port):
             assert hasattr(param, 'col_attr')
             assert param.col_attr.data.dtype == torch.half
             assert param.col_attr.data.is_sharded
-            assert param.col_attr.data.payload.device.type == 'cuda'
+            assert param.col_attr.data.payload.device.type == init_device.type, \
+                f'{param.col_attr.data.payload.device.type} vs. {init_device.type}'

+    print(f'cpu usage {GLOBAL_MODEL_DATA_TRACER.cpu_usage}')
     print(f'cuda usage {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
-    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
+    if init_device.type == 'cuda':
+        assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
+    elif init_device.type == 'cpu':
+        assert (GLOBAL_MODEL_DATA_TRACER.cpu_usage > 0)


 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 4])
-def test_zero_init_context(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
+@pytest.mark.parametrize("init_device", [torch.device('cpu'), torch.device(f'cuda:{get_current_device()}')])
+def test_zero_init_context(world_size, init_device):
+    run_func = partial(run_dist, world_size=world_size, port=free_port(), init_device=init_device)
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    test_zero_init_context(2)
+    test_zero_init_context(2, torch.device('cpu'))
+    test_zero_init_context(2, torch.device(f'cuda:{get_current_device()}'))
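As a side note on the spawn pattern used in these tests: mp.spawn passes the process index as the first positional argument, so binding the remaining arguments as keywords with functools.partial leaves rank to be filled in by spawn. A CPU-only toy illustration (no colossalai.launch; the port and device values are placeholders):

    from functools import partial
    import torch.multiprocessing as mp

    def run_dist(rank, world_size, port, init_device):
        # in the real test this launches colossalai and builds a model inside ZeroInitContext
        print(f'rank {rank}/{world_size} would init params on {init_device} (port {port})')

    if __name__ == '__main__':
        run_func = partial(run_dist, world_size=2, port=29500, init_device='cpu')
        mp.spawn(run_func, nprocs=2)   # spawn supplies rank=0 and rank=1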
tests/test_zero_data_parallel/test_shard_model_v2.py  (view file @ 44e4891f)

@@ -5,6 +5,7 @@ import copy
 from functools import partial

 import pytest
 import torch
 import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel as DDP

@@ -30,8 +31,14 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, _, _, criterion = get_components_func()

+        rm_torch_payload_on_the_fly = False

         if use_zero_init_ctx:
-            with ZeroInitContext(convert_fp16=True, convert_cuda=True, shard_strategy=shard_strategy, shard_param=True):
+            with ZeroInitContext(convert_fp16=True,
+                                 target_device=torch.device('cpu'),
+                                 shard_strategy=shard_strategy,
+                                 shard_param=True,
+                                 rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly):
                 zero_model = model_builder(checkpoint=True)
             zero_model = ShardedModelV2(zero_model, shard_strategy)