Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
44e4891f
Commit
44e4891f
authored
Mar 10, 2022
by
Jiarui Fang
Committed by
Frank Lee
Mar 11, 2022
Browse files
[zero] able to place params on cpu after zero init context (#365)
* place params on cpu after zero init context * polish code
parent
b66f3b99
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
58 additions
and
20 deletions
+58
-20
colossalai/engine/ophooks/zero_hook.py
colossalai/engine/ophooks/zero_hook.py
+7
-1
colossalai/zero/init_ctx/init_context.py
colossalai/zero/init_ctx/init_context.py
+18
-10
colossalai/zero/shard_utils/tensor_shard_strategy.py
colossalai/zero/shard_utils/tensor_shard_strategy.py
+2
-1
colossalai/zero/sharded_param/sharded_tensor.py
colossalai/zero/sharded_param/sharded_tensor.py
+7
-0
tests/test_zero_data_parallel/test_init_context.py
tests/test_zero_data_parallel/test_init_context.py
+16
-7
tests/test_zero_data_parallel/test_shard_model_v2.py
tests/test_zero_data_parallel/test_shard_model_v2.py
+8
-1
No files found.
colossalai/engine/ophooks/zero_hook.py
View file @
44e4891f
import
torch
import
torch
from
colossalai.registry
import
OPHOOKS
from
colossalai.registry
import
OPHOOKS
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.utils
import
get_current_device
from
._base_ophook
import
BaseOpHook
from
._base_ophook
import
BaseOpHook
...
@@ -14,11 +14,15 @@ class ZeroHook(BaseOpHook):
...
@@ -14,11 +14,15 @@ class ZeroHook(BaseOpHook):
def
__init__
(
self
,
shard_strategy
:
BaseShardStrategy
):
def
__init__
(
self
,
shard_strategy
:
BaseShardStrategy
):
super
().
__init__
()
super
().
__init__
()
self
.
shard_strategy
=
shard_strategy
self
.
shard_strategy
=
shard_strategy
# NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
self
.
computing_device
=
torch
.
device
(
f
'cuda:
{
get_current_device
()
}
'
)
def
pre_fwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
*
args
):
def
pre_fwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
*
args
):
for
param
in
module
.
parameters
():
for
param
in
module
.
parameters
():
assert
hasattr
(
param
,
'col_attr'
)
assert
hasattr
(
param
,
'col_attr'
)
self
.
shard_strategy
.
gather
([
param
.
col_attr
.
data
])
self
.
shard_strategy
.
gather
([
param
.
col_attr
.
data
])
if
param
.
col_attr
.
data
.
device
!=
self
.
computing_device
:
param
.
col_attr
.
data
.
to
(
self
.
computing_device
)
param
.
data
=
param
.
col_attr
.
data
.
payload
param
.
data
=
param
.
col_attr
.
data
.
payload
def
post_fwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
*
args
):
def
post_fwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
*
args
):
...
@@ -31,6 +35,8 @@ class ZeroHook(BaseOpHook):
...
@@ -31,6 +35,8 @@ class ZeroHook(BaseOpHook):
for
param
in
module
.
parameters
():
for
param
in
module
.
parameters
():
assert
hasattr
(
param
,
'col_attr'
)
assert
hasattr
(
param
,
'col_attr'
)
self
.
shard_strategy
.
gather
([
param
.
col_attr
.
data
])
self
.
shard_strategy
.
gather
([
param
.
col_attr
.
data
])
if
param
.
col_attr
.
data
.
device
!=
self
.
computing_device
:
param
.
col_attr
.
data
.
to
(
self
.
computing_device
)
param
.
data
=
param
.
col_attr
.
data
.
payload
param
.
data
=
param
.
col_attr
.
data
.
payload
# Store local accumulated grad shard
# Store local accumulated grad shard
if
param
.
grad
is
not
None
:
if
param
.
grad
is
not
None
:
...
...
colossalai/zero/init_ctx/init_context.py
View file @
44e4891f
import
functools
import
functools
import
torch
import
torch
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.sharded_param
import
ShardedParamV2
from
colossalai.zero.sharded_param
import
ShardedParamV2
from
colossalai.utils.memory_tracer.allocator
import
GLOBAL_MODEL_DATA_TRACER
from
colossalai.utils.memory_tracer.allocator
import
GLOBAL_MODEL_DATA_TRACER
...
@@ -82,6 +81,12 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
...
@@ -82,6 +81,12 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
1. Convert the model to fp16.
1. Convert the model to fp16.
2. The paramaters of the module are adapted to type ShardedParameter.
2. The paramaters of the module are adapted to type ShardedParameter.
3. Shard the param and grad according to flags.
3. Shard the param and grad according to flags.
target_device: the device where param data after exiting the context
shard_strategy: shard strategy instance
shard_param: is param sharded after exiting the context
shard_grad: is param sharded after exiting the context
rm_torch_payload_on_the_fly:
rm_torch_payload_on_the_fly:
True: remove tensor payload on param.data after module init finished.
True: remove tensor payload on param.data after module init finished.
False: remove tensor payload on param.data afther the context exist.
False: remove tensor payload on param.data afther the context exist.
...
@@ -91,18 +96,19 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
...
@@ -91,18 +96,19 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
def
__init__
(
self
,
def
__init__
(
self
,
convert_fp16
:
bool
,
convert_fp16
:
bool
,
convert_cuda
:
bool
,
target_device
:
torch
.
device
,
shard_strategy
:
BaseShardStrategy
,
shard_strategy
:
BaseShardStrategy
,
shard_param
:
bool
=
False
,
shard_param
:
bool
=
False
,
shard_grad
:
bool
=
False
,
shard_grad
:
bool
=
False
,
rm_torch_payload_on_the_fly
=
False
):
rm_torch_payload_on_the_fly
=
False
):
super
().
__init__
()
super
().
__init__
()
self
.
convert_fp16
=
convert_fp16
self
.
convert_fp16
=
convert_fp16
self
.
convert_cuda
=
convert_cuda
self
.
target_device
=
target_device
self
.
shard_param
=
shard_param
self
.
shard_param
=
shard_param
self
.
shard_grad
=
shard_grad
self
.
shard_grad
=
shard_grad
self
.
shard_strategy
=
shard_strategy
self
.
shard_strategy
=
shard_strategy
self
.
rm_torch_payload_on_the_fly
=
rm_torch_payload_on_the_fly
# FIXME(jiaruifang) now setting it to True is invalid.
self
.
rm_torch_payload_on_the_fly
=
False
self
.
initialized_param_list
=
[]
self
.
initialized_param_list
=
[]
def
_post_context_exec
(
self
):
def
_post_context_exec
(
self
):
...
@@ -123,17 +129,19 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
...
@@ -123,17 +129,19 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
if
hasattr
(
param
,
'col_attr'
):
if
hasattr
(
param
,
'col_attr'
):
continue
continue
if
self
.
convert_cuda
:
target_device
=
self
.
target_device
target_device
=
get_current_device
()
else
:
target_device
=
param
.
data
.
device
# convert to fp16
and cuda
if necessary
# convert to fp16 if necessary
if
self
.
convert_fp16
:
if
self
.
convert_fp16
:
param
.
data
=
param
.
data
.
to
(
torch
.
half
)
.
to
(
target_device
)
param
.
data
=
param
.
data
.
to
(
torch
.
half
)
if
param
.
grad
is
not
None
:
if
param
.
grad
is
not
None
:
param
.
grad
=
param
.
grad
.
to
(
torch
.
half
).
to
(
target_device
)
param
.
grad
=
param
.
grad
.
to
(
torch
.
half
).
to
(
target_device
)
# move torch parameters to the target device
param
.
data
=
param
.
data
.
to
(
target_device
)
if
param
.
grad
is
not
None
:
param
.
grad
=
param
.
grad
.
to
(
target_device
)
param
.
col_attr
=
ShardedParamV2
(
param
,
rm_torch_payload
=
self
.
rm_torch_payload_on_the_fly
)
param
.
col_attr
=
ShardedParamV2
(
param
,
rm_torch_payload
=
self
.
rm_torch_payload_on_the_fly
)
self
.
initialized_param_list
.
append
(
param
)
self
.
initialized_param_list
.
append
(
param
)
...
...
colossalai/zero/shard_utils/tensor_shard_strategy.py
View file @
44e4891f
...
@@ -30,7 +30,7 @@ class TensorShardStrategy(BaseShardStrategy):
...
@@ -30,7 +30,7 @@ class TensorShardStrategy(BaseShardStrategy):
def
_gather_tensor
(
self
,
t
:
ShardedTensor
):
def
_gather_tensor
(
self
,
t
:
ShardedTensor
):
if
not
t
.
is_sharded
:
if
not
t
.
is_sharded
:
return
return
target_device
=
t
.
device
buffer_list
=
[]
buffer_list
=
[]
payload_numel
=
t
.
payload
.
numel
()
payload_numel
=
t
.
payload
.
numel
()
for
i
in
range
(
self
.
world_size
):
for
i
in
range
(
self
.
world_size
):
...
@@ -45,4 +45,5 @@ class TensorShardStrategy(BaseShardStrategy):
...
@@ -45,4 +45,5 @@ class TensorShardStrategy(BaseShardStrategy):
async_op
=
False
)
async_op
=
False
)
gathered_payload
=
torch
.
narrow
(
torch
.
cat
(
buffer_list
),
0
,
0
,
t
.
origin_numel
).
reshape
(
t
.
origin_shape
)
gathered_payload
=
torch
.
narrow
(
torch
.
cat
(
buffer_list
),
0
,
0
,
t
.
origin_numel
).
reshape
(
t
.
origin_shape
)
t
.
reset_payload
(
gathered_payload
)
t
.
reset_payload
(
gathered_payload
)
t
.
to
(
target_device
)
t
.
is_sharded
=
False
t
.
is_sharded
=
False
colossalai/zero/sharded_param/sharded_tensor.py
View file @
44e4891f
...
@@ -47,11 +47,18 @@ class ShardedTensor(object):
...
@@ -47,11 +47,18 @@ class ShardedTensor(object):
del
self
.
_payload
del
self
.
_payload
self
.
_payload
=
tensor
self
.
_payload
=
tensor
@
property
def
device
(
self
):
return
self
.
_payload
.
device
@
property
@
property
def
dtype
(
self
):
def
dtype
(
self
):
assert
self
.
_payload
.
dtype
==
self
.
_origin_dtype
assert
self
.
_payload
.
dtype
==
self
.
_origin_dtype
return
self
.
_origin_dtype
return
self
.
_origin_dtype
def
to
(
self
,
device
:
torch
.
device
):
self
.
_payload
=
self
.
_payload
.
to
(
device
)
@
property
@
property
def
shape
(
self
):
def
shape
(
self
):
return
self
.
_payload
.
shape
return
self
.
_payload
.
shape
tests/test_zero_data_parallel/test_init_context.py
View file @
44e4891f
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
from
functools
import
partial
from
functools
import
partial
import
colossalai
import
colossalai
from
colossalai.utils.cuda
import
get_current_device
import
pytest
import
pytest
import
torch
import
torch
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
...
@@ -17,13 +18,13 @@ from common import CONFIG
...
@@ -17,13 +18,13 @@ from common import CONFIG
from
colossalai.utils.memory_tracer.allocator
import
GLOBAL_MODEL_DATA_TRACER
from
colossalai.utils.memory_tracer.allocator
import
GLOBAL_MODEL_DATA_TRACER
def
run_dist
(
rank
,
world_size
,
port
):
def
run_dist
(
rank
,
world_size
,
port
,
init_device
):
colossalai
.
launch
(
config
=
CONFIG
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
colossalai
.
launch
(
config
=
CONFIG
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
for
get_components_func
in
non_distributed_component_funcs
:
for
get_components_func
in
non_distributed_component_funcs
:
model_builder
,
_
,
_
,
_
,
_
=
get_components_func
()
model_builder
,
_
,
_
,
_
,
_
=
get_components_func
()
with
ZeroInitContext
(
convert_fp16
=
True
,
with
ZeroInitContext
(
convert_fp16
=
True
,
convert_cuda
=
Tru
e
,
target_device
=
init_devic
e
,
shard_strategy
=
TensorShardStrategy
(),
shard_strategy
=
TensorShardStrategy
(),
shard_param
=
True
):
shard_param
=
True
):
model
=
model_builder
(
checkpoint
=
True
)
model
=
model_builder
(
checkpoint
=
True
)
...
@@ -32,18 +33,26 @@ def run_dist(rank, world_size, port):
...
@@ -32,18 +33,26 @@ def run_dist(rank, world_size, port):
assert
hasattr
(
param
,
'col_attr'
)
assert
hasattr
(
param
,
'col_attr'
)
assert
param
.
col_attr
.
data
.
dtype
==
torch
.
half
assert
param
.
col_attr
.
data
.
dtype
==
torch
.
half
assert
param
.
col_attr
.
data
.
is_sharded
assert
param
.
col_attr
.
data
.
is_sharded
assert
param
.
col_attr
.
data
.
payload
.
device
.
type
==
'cuda'
assert
param
.
col_attr
.
data
.
payload
.
device
.
type
==
init_device
.
type
,
\
f
'
{
param
.
col_attr
.
data
.
payload
.
device
.
type
}
vs.
{
init_device
.
type
}
'
print
(
f
'cpu usgae
{
GLOBAL_MODEL_DATA_TRACER
.
cpu_usage
}
'
)
print
(
f
'cuda usgae
{
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
}
'
)
print
(
f
'cuda usgae
{
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
}
'
)
if
init_device
.
type
==
'cuda'
:
assert
(
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
>
0
)
assert
(
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
>
0
)
elif
init_device
.
type
==
'cpu'
:
assert
(
GLOBAL_MODEL_DATA_TRACER
.
cpu_usage
>
0
)
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
1
,
4
])
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
1
,
4
])
def
test_zero_init_context
(
world_size
):
@
pytest
.
mark
.
parametrize
(
"init_device"
,
[
torch
.
device
(
'cpu'
),
torch
.
device
(
f
'cuda:
{
get_current_device
()
}
'
)])
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
())
def
test_zero_init_context
(
world_size
,
init_device
):
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
(),
init_device
=
init_device
)
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
test_zero_init_context
(
2
)
test_zero_init_context
(
2
,
torch
.
device
(
'cpu'
))
test_zero_init_context
(
2
,
torch
.
device
(
f
'cuda:
{
get_current_device
()
}
'
))
tests/test_zero_data_parallel/test_shard_model_v2.py
View file @
44e4891f
...
@@ -5,6 +5,7 @@ import copy
...
@@ -5,6 +5,7 @@ import copy
from
functools
import
partial
from
functools
import
partial
import
pytest
import
pytest
import
torch
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
...
@@ -30,8 +31,14 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):
...
@@ -30,8 +31,14 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
_
,
_
,
criterion
=
get_components_func
()
model_builder
,
train_dataloader
,
_
,
_
,
criterion
=
get_components_func
()
rm_torch_payload_on_the_fly
=
False
if
use_zero_init_ctx
:
if
use_zero_init_ctx
:
with
ZeroInitContext
(
convert_fp16
=
True
,
convert_cuda
=
True
,
shard_strategy
=
shard_strategy
,
shard_param
=
True
):
with
ZeroInitContext
(
convert_fp16
=
True
,
target_device
=
torch
.
device
(
'cpu'
),
shard_strategy
=
shard_strategy
,
shard_param
=
True
,
rm_torch_payload_on_the_fly
=
rm_torch_payload_on_the_fly
):
zero_model
=
model_builder
(
checkpoint
=
True
)
zero_model
=
model_builder
(
checkpoint
=
True
)
zero_model
=
ShardedModelV2
(
zero_model
,
shard_strategy
)
zero_model
=
ShardedModelV2
(
zero_model
,
shard_strategy
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment