OpenDAS / ColossalAI · Commit 6a88bae4
Authored Jun 30, 2023 by Frank Lee

[shardformer] integrate with data parallelism (#4103)

Parent: f3b6aaa6
Showing 11 changed files with 97 additions and 50 deletions:

colossalai/shardformer/shard/shard_config.py           +7   -9
colossalai/shardformer/shard/sharder.py                +3   -8
colossalai/shardformer/shard/shardformer.py            +2   -23
tests/test_shardformer/test_model/_utils.py            +2   -4
tests/test_shardformer/test_model/test_shard_bert.py   +1   -1
tests/test_shardformer/test_model/test_shard_bloom.py  +1   -1
tests/test_shardformer/test_model/test_shard_gpt2.py   +1   -1
tests/test_shardformer/test_model/test_shard_llama.py  +1   -1
tests/test_shardformer/test_model/test_shard_opt.py    +1   -1
tests/test_shardformer/test_model/test_shard_t5.py     +1   -1
tests/test_shardformer/test_with_torch_ddp.py          +77  -0
colossalai/shardformer/shard/shard_config.py
```diff
 from dataclasses import dataclass

+import torch.distributed as dist
+from torch.distributed import ProcessGroup

-from colossalai.cluster.dist_coordinator import DistCoordinator

 __all__ = ['ShardConfig']

@@ -11,10 +14,10 @@ class ShardConfig:
     The config for sharding the huggingface model

     Args:
-        tensor_parallel_size (int): The size of tensor parallel
+        tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group.
         enable_fused_normalization (bool): Whether to use fused layernorm, default is False
     """
-    tensor_parallel_size: int
+    tensor_parallel_process_group: int = None
     enable_fused_normalization: bool = False
     # TODO: add support for tensor parallel

@@ -25,10 +28,5 @@ class ShardConfig:
     # gather_output: bool = True

     def __post_init__(self):
-        coordinator = DistCoordinator()
-
-        # ensure the parallel size can match the world size
-        world_size = coordinator.world_size
-        self.data_parallel_size = world_size // self.tensor_parallel_size
-        assert world_size == self.data_parallel_size * self.tensor_parallel_size, \
-            f"The world size ({world_size}) should be divisible by the data parallel size {self.data_parallel_size} and tensor parallel size {self.tensor_parallel_size}"
+        # get the parallel size
+        self.tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)
```
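With this change, `ShardConfig` no longer takes a `tensor_parallel_size` and derives the data-parallel size itself; it takes a process group and reads the tensor-parallel size off that group in `__post_init__`. A minimal sketch of configuring it under the new interface, assuming `torch.distributed` is already initialized (the two-rank group below is illustrative, not part of the commit):

```python
import torch.distributed as dist

from colossalai.shardformer import ShardConfig

# illustrative layout: on a 4-rank job, ranks 0 and 1 form one
# tensor-parallel group (every rank must execute new_group)
tp_group = dist.new_group(ranks=[0, 1])

# passing None instead would fall back to the global process group
shard_config = ShardConfig(tensor_parallel_process_group=tp_group,
                           enable_fused_normalization=True)

# __post_init__ has already derived the size from the group:
# shard_config.tensor_parallel_size == dist.get_world_size(tp_group) == 2
```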
colossalai/shardformer/shard/sharder.py
```diff
@@ -22,16 +22,10 @@ class ModelSharder(object):
         shard_config: The setting of distributed model
     """

-    def __init__(self,
-                 model: nn.Module,
-                 policy: Policy,
-                 shard_config: ShardConfig = None,    # TODO
-                 pg_manager: ProcessGroupManager = None) -> None:
+    def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None:
         self.model = model
         self.policy = get_autopolicy(self.model) if policy is None else policy
         self.shard_config = shard_config
-        self.pg_manager = pg_manager

     def shard(self) -> None:
         r"""

@@ -198,7 +192,8 @@ class ModelSharder(object):
                 continue

             try:
-                replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'],
+                replace_layer = target_module.from_native_module(native_sub_module,
+                                                                 self.shard_config.tensor_parallel_process_group,
                                                                  **kwargs)
             except Exception as e:
                 raise RuntimeError(
```
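Layer wrappers now receive the tensor-parallel group straight from `shard_config` rather than looking up the `'tp1d'` entry in a `ProcessGroupManager` store. The diff only confirms the call signature `from_native_module(module, process_group, **kwargs)`; the class below is a hypothetical wrapper, sketched to show what a conforming implementation might look like:

```python
import torch.nn as nn
from torch.distributed import ProcessGroup


class HypotheticalParallelLinear(nn.Module):
    """Illustrative stand-in for a shardformer parallel layer (not from the commit)."""

    def __init__(self, weight: nn.Parameter, bias: nn.Parameter, process_group: ProcessGroup):
        super().__init__()
        self.weight = weight
        self.bias = bias
        # the group now arrives from ShardConfig.tensor_parallel_process_group
        self.process_group = process_group

    @classmethod
    def from_native_module(cls, module: nn.Linear, process_group: ProcessGroup, **kwargs):
        # wrap the native module's parameters into the parallel layer
        return cls(module.weight, module.bias, process_group)
```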
colossalai/shardformer/shard/shardformer.py
````diff
 import torch.nn as nn
-from torch.utils.data import Dataset

-from colossalai.cluster import DistCoordinator, ProcessGroupManager
+from colossalai.cluster import DistCoordinator

 from ..policies.basepolicy import Policy
 from .shard_config import ShardConfig

@@ -28,7 +27,6 @@ class ShardFormer:
             tensor_parallel_mode='1d',
         )
         shard_former = ShardFormer(shard_config=shard_config)
-        shard_former.init_distributed()
         model = shard_former.shard_model(org_model)
         ```
     """

@@ -41,19 +39,6 @@ class ShardFormer:
         """
         self.coordinator = DistCoordinator()
         self.shard_config = shard_config
-        self.pg_manager = None
-
-    def init_distributed(self) -> ProcessGroupManager:
-        """
-        Initialize the distributed process group according to the
-        """
-        # create process group manager and 1d process group
-        # TODO: may need to support other parallel mode when the config has such as field
-        pg_manager = ProcessGroupManager()
-        pg_manager.create_process_group(name='tp1d', ranks=range(self.coordinator.world_size))
-        self.pg_manager = pg_manager
-
-        return pg_manager

     def shard_model(self, model: nn.Module, policy: Policy = None):
         r"""

@@ -64,12 +49,6 @@ class ShardFormer:
             shard_config (`ShardConfig`): the config for distribute information
             policy (`Policy`): the custom policy for sharding
         """
-        sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy, pg_manager=self.pg_manager)
+        sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy)
         sharder.shard()
         return model
-
-    def shard_dataset(self, dataset: Dataset):
-        """
-        Shard dataset for DP
-        """
-        pass
````
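With `init_distributed()` removed, sharding collapses to construct-and-call. A minimal usage sketch under the new flow (assumes `colossalai.launch` has already initialized distributed state and `org_model` is any supported HuggingFace model):

```python
from colossalai.shardformer import ShardConfig, ShardFormer

# None falls back to the global process group for tensor parallelism
shard_config = ShardConfig(tensor_parallel_process_group=None,
                           enable_fused_normalization=True)

shard_former = ShardFormer(shard_config=shard_config)
# no shard_former.init_distributed() step anymore
model = shard_former.shard_model(org_model)
```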
tests/test_shardformer/test_model/_utils.py
```diff
@@ -3,17 +3,15 @@ import copy
 from colossalai.shardformer import ShardConfig, ShardFormer


-def build_model(world_size, model_fn):
+def build_model(model_fn):
     # create new model
     org_model = model_fn().cuda()

     # shard model
-    shard_config = ShardConfig(tensor_parallel_size=world_size, enable_fused_normalization=True)
+    shard_config = ShardConfig(enable_fused_normalization=True)
     model_copy = copy.deepcopy(org_model)
     shard_former = ShardFormer(shard_config=shard_config)
-    shard_former.init_distributed()
     sharded_model = shard_former.shard_model(model_copy).cuda()
     return org_model, sharded_model
```
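Since the parallel size now comes from the process group, the helper no longer needs `world_size`, which is why every per-model test below changes its call site in the same one-line way:

```python
# before: org_model, sharded_model = build_model(world_size, model_fn)
org_model, sharded_model = build_model(model_fn)
```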
tests/test_shardformer/test_model/test_shard_bert.py

```diff
@@ -42,7 +42,7 @@ def check_bert(rank, world_size, port):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(world_size, model_fn)
+        org_model, sharded_model = build_model(model_fn)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

     torch.cuda.empty_cache()
```
tests/test_shardformer/test_model/test_shard_bloom.py

```diff
@@ -42,7 +42,7 @@ def check_bloom(rank, world_size, port):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(world_size, model_fn)
+        org_model, sharded_model = build_model(model_fn)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

     torch.cuda.empty_cache()
```
tests/test_shardformer/test_model/test_shard_gpt2.py

```diff
@@ -43,7 +43,7 @@ def check_gpt2(rank, world_size, port):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(world_size, model_fn)
+        org_model, sharded_model = build_model(model_fn)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

     torch.cuda.empty_cache()
```
tests/test_shardformer/test_model/test_shard_llama.py

```diff
@@ -50,7 +50,7 @@ def check_llama(rank, world_size, port):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(world_size, model_fn)
+        org_model, sharded_model = build_model(model_fn)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

     torch.cuda.empty_cache()
```
tests/test_shardformer/test_model/test_shard_opt.py

```diff
@@ -54,7 +54,7 @@ def check_OPTModel(rank, world_size, port):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(world_size, model_fn)
+        org_model, sharded_model = build_model(model_fn)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

     torch.cuda.empty_cache()
```
tests/test_shardformer/test_model/test_shard_t5.py

```diff
@@ -42,7 +42,7 @@ def check_t5(rank, world_size, port):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(world_size, model_fn)
+        org_model, sharded_model = build_model(model_fn)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

     torch.cuda.empty_cache()
```
tests/test_shardformer/test_with_torch_ddp.py (new file, mode 100644)
```python
import pytest
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

import colossalai
from colossalai.cluster import DistCoordinator
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo


def check_shardformer_with_ddp(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')

    # create shardformer
    # ranks: [0, 1, 2, 3]
    # tp ranks = [0, 1], [2, 3]
    # dp ranks = [0, 2], [1, 3]
    dp_process_group_1 = dist.new_group([0, 2])
    dp_process_group_2 = dist.new_group([1, 3])
    tp_process_group_1 = dist.new_group([0, 1])
    tp_process_group_2 = dist.new_group([2, 3])

    coordinator = DistCoordinator()

    if coordinator.rank in [0, 1]:
        tp_process_group = tp_process_group_1
    else:
        tp_process_group = tp_process_group_2

    if coordinator.rank in [0, 2]:
        dp_process_group = dp_process_group_1
    else:
        dp_process_group = dp_process_group_2

    shard_config = ShardConfig(tensor_parallel_process_group=tp_process_group, enable_fused_normalization=True)
    shardformer = ShardFormer(shard_config=shard_config)

    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        # create and shard model
        model = model_fn().cuda()
        sharded_model = shardformer.shard_model(model)

        # add ddp
        sharded_ddp_model = DDP(sharded_model, process_group=dp_process_group)

        # prepare input
        data = data_gen_fn()
        data = {k: v.cuda() for k, v in data.items()}

        # switch to train mode
        sharded_ddp_model.train()

        # run forward
        output = sharded_ddp_model(**data)
        loss = loss_fn(output)

        # backward
        loss.backward()

    torch.cuda.empty_cache()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_gpt2():
    spawn(check_shardformer_with_ddp, 4)


if __name__ == "__main__":
    test_gpt2()
```
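The four-rank layout above is hard-coded: consecutive ranks in a row share a tensor-parallel group, and ranks at the same column offset share a data-parallel group. The same grid generalizes to any world size divisible by the tensor-parallel degree; here is a hypothetical helper (not part of the commit) that builds both groups:

```python
import torch.distributed as dist


def build_tp_dp_groups(rank: int, world_size: int, tp_size: int):
    """Return (tp_group, dp_group) for `rank` on a (world_size // tp_size) x tp_size grid.

    Illustrative helper, not from the commit. Every rank must call
    dist.new_group for every group, in the same order.
    """
    assert world_size % tp_size == 0, "world size must be divisible by tp size"
    tp_group, dp_group = None, None

    # rows of `tp_size` consecutive ranks share a tensor-parallel group
    for start in range(0, world_size, tp_size):
        ranks = list(range(start, start + tp_size))
        group = dist.new_group(ranks)
        if rank in ranks:
            tp_group = group

    # columns (same offset within each row) share a data-parallel group
    for offset in range(tp_size):
        ranks = list(range(offset, world_size, tp_size))
        group = dist.new_group(ranks)
        if rank in ranks:
            dp_group = group

    return tp_group, dp_group
```

For `world_size=4, tp_size=2` this reproduces the test's layout: TP groups [0, 1] and [2, 3], DP groups [0, 2] and [1, 3].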