OpenDAS / ColossalAI · Commits

Commit 4b03c25f (unverified), authored Aug 25, 2022 by YuliangLiu0306, committed via GitHub on Aug 25, 2022

[tensor]add 1D device mesh (#1492)

Parent: b8d0e39e

Showing 4 changed files with 66 additions and 13 deletions (+66 −13)
Files changed:
- colossalai/device/device_mesh.py (+25 −2)
- colossalai/tensor/shape_consistency.py (+14 −9)
- tests/test_tensor/test_comm_spec_apply.py (+27 −1)
- tests/test_tensor/test_shape_consistency_apply.py (+0 −1)
colossalai/device/device_mesh.py (view file @ 4b03c25f)

```diff
@@ -25,7 +25,13 @@ class DeviceMesh:
             (default: False)
     """

-    def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None, init_process_group=False):
+    def __init__(self,
+                 physical_mesh_id,
+                 mesh_shape,
+                 mesh_alpha=None,
+                 mesh_beta=None,
+                 init_process_group=False,
+                 need_flatten=True):
         self.physical_mesh_id = physical_mesh_id
         self.mesh_shape = mesh_shape
         self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
@@ -39,8 +45,12 @@ class DeviceMesh:
             mesh_beta = [1] * len(self.mesh_shape)
         self.mesh_alpha = tuple(mesh_alpha)
         self.mesh_beta = tuple(mesh_beta)
-        if init_process_group:
+        self.init_process_group = init_process_group
+        self.need_flatten = need_flatten
+        if self.init_process_group:
             self.process_groups_dict = self.create_process_groups_for_logical_mesh()
+        if self.need_flatten:
+            self.flatten_device_mesh = self.flatten()

     @property
     def shape(self):
@@ -54,6 +64,19 @@ class DeviceMesh:
     def logical_mesh_id(self):
         return self._logical_mesh_id

+    def flatten(self):
+        """
+        Flatten the logical mesh into an effective 1d logical mesh,
+        """
+        flatten_mesh_shape_size = len(self.mesh_shape)
+        flatten_mesh_shape = [self.num_devices]
+        return DeviceMesh(self.physical_mesh_id,
+                          tuple(flatten_mesh_shape),
+                          mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
+                          mesh_beta=[min(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
+                          init_process_group=self.init_process_group,
+                          need_flatten=False)
+
     def _global_rank_to_logical_rank_map(self, tensor, index_list):
         '''
         This method is a helper function to build convert_map recursively.
```
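The new `need_flatten` flag makes every `DeviceMesh` eagerly build a 1D view of itself (`flatten_device_mesh`), and the flattened copy is constructed with `need_flatten=False` to stop the recursion. Below is a minimal sketch of the flattening arithmetic only, using plain torch tensors and illustrative stand-in values rather than the real `DeviceMesh` class:

```python
import torch

# Hypothetical values for a 2x2 mesh; the real class stores these as
# attributes set in __init__ above.
physical_mesh_id = torch.arange(4)   # global ranks 0..3
mesh_shape = (2, 2)
mesh_alpha = (0.1, 0.2)              # per-dimension latency coefficients
mesh_beta = (0.5, 1.0)               # per-dimension bandwidth coefficients

# flatten(): collapse the logical mesh into one axis spanning all devices,
# reusing the extreme per-dimension cost coefficients for the single axis.
flatten_mesh_shape_size = len(mesh_shape)
flatten_mesh_shape = (physical_mesh_id.numel(),)                    # (4,)
flatten_alpha = [max(mesh_alpha)] * (flatten_mesh_shape_size - 1)   # [0.2]
flatten_beta = [min(mesh_beta)] * (flatten_mesh_shape_size - 1)     # [0.5]

print(flatten_mesh_shape, flatten_alpha, flatten_beta)
# (4,) [0.2] [0.5] -> one logical axis covering all four devices
```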
colossalai/tensor/shape_consistency.py (view file @ 4b03c25f)

```diff
@@ -3,6 +3,7 @@ from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
 from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
 from enum import Enum
 from copy import deepcopy
+from typing import Dict, List, Optional, Tuple, Union
 import torch.distributed as dist
 import math
 from functools import reduce
@@ -29,9 +30,9 @@ class CommSpec:
     Argument:
         comm_pattern(CollectiveCommPattern): describe the communication method used in this spec.
         sharding_spec(ShardingSpec): This is sharding spec of the tensor which will join the communication action.
-        gather_dim(int, optional): The gather_dim of the tensor will be gathered.
-        shard_dim(int, optional): The shard_dim of the tensor will be sharded.
-        logical_process_axis(int, optional): The mesh_dim to implement the communication action.
+        gather_dim(int, Optional): The gather_dim of the tensor will be gathered.
+        shard_dim(int, Optional): The shard_dim of the tensor will be sharded.
+        logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action.
     '''

     def __init__(self, comm_pattern, sharding_spec, gather_dim=None, shard_dim=None, logical_process_axis=None):
@@ -40,6 +41,11 @@ class CommSpec:
         self.gather_dim = gather_dim
         self.shard_dim = shard_dim
         self.logical_process_axis = logical_process_axis
+        if isinstance(self.logical_process_axis, list):
+            self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh
+            self.logical_process_axis = 0
+        else:
+            self.device_mesh = self.sharding_spec.device_mesh

     def __repr__(self):
         res_list = ["CommSpec:("]
@@ -70,11 +76,11 @@ class CommSpec:
         '''
         comm_size = reduce(operator.mul, self.sharding_spec.get_sharded_shape_per_device(), 1)
         if self.comm_pattern == CollectiveCommPattern.ALLGATHER:
-            return self.sharding_spec.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
+            return self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
         if self.comm_pattern == CollectiveCommPattern.ALLTOALL:
-            return self.sharding_spec.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis)
+            return self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis)
         if self.comm_pattern == CollectiveCommPattern.ALLREDUCE:
-            return self.sharding_spec.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis)
+            return self.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis)
         if self.comm_pattern == CollectiveCommPattern.SHARD:
             return 0
         raise RuntimeError(f"Could not find a matching CollectiveCommPattern for {self.comm_pattern}.")
@@ -87,15 +93,14 @@ class CommSpec:
         Argument:
             tensor(torch.Tensor): Tensor stored in each device, which could be different in different ranks.
         '''
-        device_mesh = self.sharding_spec.device_mesh
-        process_groups_list = device_mesh.process_groups_dict[self.logical_process_axis]
+        process_groups_list = self.device_mesh.process_groups_dict[self.logical_process_axis]
         if self.comm_pattern == CollectiveCommPattern.ALLGATHER:
             for rank_list, process_group in process_groups_list:
                 if dist.get_rank() in rank_list:
                     tensor_list = [
                         torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
-                        for _ in range(self.sharding_spec.device_mesh.mesh_shape[self.logical_process_axis])
+                        for _ in range(self.device_mesh.mesh_shape[self.logical_process_axis])
                     ]
                     tensor = tensor
                     group = process_group
```
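With this change, `CommSpec` decides once in `__init__` which mesh it talks to: a list-valued `logical_process_axis` (e.g. `[0, 1]`) swaps in the pre-built `flatten_device_mesh` and resets the axis to `0`, so the cost and apply paths downstream stay axis-agnostic. A self-contained sketch of that dispatch, with a mock class standing in for the real `DeviceMesh` (all names below are hypothetical):

```python
class FakeMesh:
    """Mock mesh; the real DeviceMesh builds flatten_device_mesh in __init__."""

    def __init__(self, ndim, flattened=None):
        self.ndim = ndim
        # a 1D mesh is its own flattened view, which keeps the mock simple
        self.flatten_device_mesh = flattened if flattened is not None else self


def pick_mesh(device_mesh, logical_process_axis):
    # Mirrors the new branch in CommSpec.__init__: a list of axes means
    # "communicate across the whole mesh", i.e. axis 0 of the 1D flat mesh.
    if isinstance(logical_process_axis, list):
        return device_mesh.flatten_device_mesh, 0
    return device_mesh, logical_process_axis


flat = FakeMesh(ndim=1)
mesh_2d = FakeMesh(ndim=2, flattened=flat)

assert pick_mesh(mesh_2d, 1) == (mesh_2d, 1)    # single axis: unchanged
assert pick_mesh(mesh_2d, [0, 1]) == (flat, 0)  # multi-axis: flat mesh, axis 0
```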
tests/test_tensor/test_comm_spec_apply.py (view file @ 4b03c25f)

```diff
@@ -133,13 +133,36 @@ def check_all_reduce(device_mesh, rank):
     #     device_mesh_shape: (2, 2)
     sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)

-    # CommSpec:CommSpec:(comm_pattern:all_reduce, logical_process_axis:0)
+    # CommSpec:(comm_pattern:all_reduce, logical_process_axis:0)
     comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE, sharding_spec, logical_process_axis=0)
     comm_spec.covert_spec_to_action(tensor_to_comm)

     assert tensor_to_comm.equal(tensor_to_check)


+def check_all_reduce_in_flatten_device_mesh(device_mesh, rank):
+    # tensor to comm
+    tensor_to_comm = torch.ones(2, 2).cuda() * rank
+
+    # reduce through logical process axis 0 at flatten device mesh
+    # tensor to check
+    # tensor([[6., 6.],
+    #         [6., 6.]])
+    tensor_to_check = torch.tensor([[6, 6], [6, 6]], dtype=tensor_to_comm.dtype).cuda()
+
+    dim_partition_dict = {}
+    # DistSpec:
+    #     shard_sequence: R,R
+    #     device_mesh_shape: (2, 2)
+    sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)
+
+    # CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1])
+    comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE, sharding_spec, logical_process_axis=[0, 1])
+    comm_spec.covert_spec_to_action(tensor_to_comm)
+
+    assert tensor_to_comm.equal(tensor_to_check)
+
+
 def check_comm(rank, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
@@ -162,6 +185,9 @@ def check_comm(rank, world_size, port):
     # test all reduce
     check_all_reduce(device_mesh, rank)

+    # test all reduce in 1D flatten device mesh
+    check_all_reduce_in_flatten_device_mesh(device_mesh, rank)
+
     gpc.destroy()
```
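The expected all-6s tensor follows from the flattened mesh covering all four ranks of the (2, 2) mesh: each rank contributes `rank * ones(2, 2)`, and the single all-reduce over the 1D mesh sums 0 + 1 + 2 + 3 = 6 into every element. A quick single-process check of that arithmetic (plain torch, no distributed setup; `world_size` here is just the illustrative constant 4):

```python
import torch

world_size = 4  # the (2, 2) mesh flattened to one axis of 4 ranks
contributions = [torch.ones(2, 2) * rank for rank in range(world_size)]
reduced = sum(contributions)                  # what the all-reduce computes
assert reduced.equal(torch.full((2, 2), 6.))  # 0 + 1 + 2 + 3 = 6
```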
tests/test_tensor/test_shape_consistency_apply.py (view file @ 4b03c25f)

```diff
@@ -64,7 +64,6 @@ def check_apply(rank, world_size, port):
     tensor_to_comm.sharding_spec = sharding_spec_source
     shape_consistency_manager.apply(tensor_to_comm, sharding_spec_target)
-    print(tensor_to_comm)

     assert tensor_to_comm.equal(tensor_to_check)
     assert str(tensor_to_comm.sharding_spec.sharding_sequence) == str(sharding_spec_target.sharding_sequence)
```