OpenDAS / ColossalAI · Commits

Commit 29386a54 (Unverified)
[DTensor] refactor CommSpec (#3034)

Authored Mar 08, 2023 by YuliangLiu0306, committed by GitHub on Mar 08, 2023
Parent: ea0b52c1
Changes: 3 changed files, with 501 additions and 1 deletion

    colossalai/tensor/d_tensor/comm_spec.py              +310  -0
    colossalai/tensor/d_tensor/sharding_spec.py            +1  -1
    tests/test_tensor/test_dtensor/test_comm_spec.py     +190  -0
colossalai/tensor/d_tensor/comm_spec.py (new file, mode 100644)
from enum import Enum
from typing import Dict

import torch
import torch.distributed as dist
from torch.distributed import ReduceOp

__all__ = [
    'CollectiveCommPattern',
    'CommSpec',
]


class CollectiveCommPattern(Enum):
    GATHER_FWD_SPLIT_BWD = 'gather_fwd_split_bwd'
    ALL2ALL_FWD_ALL2ALL_BWD = 'all2all_fwd_all2all_bwd'
    SPLIT_FWD_GATHER_BWD = 'split_fwd_gather_bwd'
    ALLREDUCE_FWD_IDENTITY_BWD = 'all_reduce_fwd_identity_bwd'
    IDENTITY_FWD_ALLREDUCE_BWD = 'identity_fwd_all_reduce_bwd'
    MIXGATHER_FWD_SPLIT_BWD = "mixgather_fwd_split_bwd"


class CommSpec:
    '''
    Communication spec is used to record the communication action. It converts the communication spec
    to a real action which will be used at runtime. It contains comm_pattern to determine the
    communication method, process_groups_dict to determine the process groups, gather_dim and shard_dim
    to determine the buffer shape, and logical_process_axis.

    Argument:
        comm_pattern(CollectiveCommPattern): describes the communication method used in this spec.
        process_groups_dict(Dict): A dict which contains the process groups used to apply this CommSpec.
        gather_dim(int, Optional): The dim of the tensor which will be gathered.
        shard_dim(int, Optional): The dim of the tensor which will be sharded.
        logical_process_axis(Union(int, List[int]), Optional): The mesh_dim along which to implement the communication action.
    '''

    def __init__(self,
                 comm_pattern: CollectiveCommPattern,
                 process_groups_dict: Dict,
                 gather_dim: int = None,
                 shard_dim: int = None,
                 logical_process_axis: int = None):
        self.comm_pattern = comm_pattern
        self.gather_dim = gather_dim
        self.shard_dim = shard_dim
        self.logical_process_axis = logical_process_axis
        self.process_groups_dict = process_groups_dict

    def __repr__(self):
        res_list = ["CommSpec:("]
        if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD:
            res_list.append(f"comm_pattern:GATHER_FWD_SPLIT_BWD, ")
            res_list.append(f"gather_dim:{self.gather_dim}, ")
            res_list.append(f"shard_dim:{self.shard_dim}, ")
            res_list.append(f"logical_process_axis:{self.logical_process_axis})")
        elif self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD:
            res_list.append(f"comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, ")
            res_list.append(f"gather_dim:{self.gather_dim}, ")
            res_list.append(f"shard_dim:{self.shard_dim}, ")
            res_list.append(f"logical_process_axis:{self.logical_process_axis})")
        elif self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD:
            res_list.append(f"comm_pattern:SPLIT_FWD_GATHER_BWD, ")
            res_list.append(f"gather_dim:{self.gather_dim}, ")
            res_list.append(f"shard_dim:{self.shard_dim}, ")
            res_list.append(f"logical_process_axis:{self.logical_process_axis})")
        elif self.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD:
            res_list.append(f"comm_pattern:ALLREDUCE_FWD_IDENTITY_BWD, ")
            res_list.append(f"logical_process_axis:{self.logical_process_axis})")
        elif self.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD:
            res_list.append(f"comm_pattern:IDENTITY_FWD_ALLREDUCE_BWD, ")
            res_list.append(f"logical_process_axis:{self.logical_process_axis})")
        return ''.join(res_list)

    def covert_spec_to_action(self, tensor):
        '''
        Convert CommSpec into a runtime action, i.e. perform the real collective communication
        on the target tensor. The collective communication action is directed by the CommSpec.

        Argument:
            tensor(torch.Tensor): Tensor stored on each device, which could be different on different ranks.
        '''
        if self.comm_pattern in pattern_to_func_dict:
            tensor = pattern_to_func_dict[self.comm_pattern](tensor, self)
        else:
            tensor = tensor
        return tensor


def _all_gather(tensor: torch.Tensor, comm_spec: CommSpec):
    '''
    Implement the all-gather operation on the device mesh based on the information provided by comm_spec.
    '''
    process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
    for rank_list, process_group in process_groups_list:
        if dist.get_rank() in rank_list:
            tensor_list = [
                torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
            ]
            # without this contiguous operation, the all gather may get some unexpected results.
            tensor = tensor.contiguous()
            dist.all_gather(tensor_list, tensor, group=process_group)
            output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
            return output


def _split(tensor: torch.Tensor, comm_spec: CommSpec):
    '''
    Implement the shard operation on the device mesh based on the information provided by comm_spec.
    '''
    process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
    for rank_list, _ in process_groups_list:
        if dist.get_rank() in rank_list:
            # each rank keeps the slice of shard_dim corresponding to its position in rank_list
            dim = comm_spec.shard_dim
            length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
            start = length * rank_list.index(dist.get_rank())
            output = torch.narrow(tensor, dim, start, length).contiguous()
            return output
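# Worked example (hypothetical values, added for illustration): with
# rank_list = [0, 1], shard_dim = 1 and a (2, 4) input, each rank keeps a
# (2, 2) slice: length = 4 // 2 = 2, and rank 1 reads from
# start = length * rank_list.index(1) = 2, i.e. torch.narrow(tensor, 1, 2, 2).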
def _all_to_all(tensor: torch.Tensor, comm_spec: CommSpec):
    '''
    Implement the all-to-all operation on the device mesh based on the information provided by comm_spec.
    '''
    process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
    for rank_list, process_group in process_groups_list:
        if dist.get_rank() in rank_list:
            new_shape = list(tensor.shape)
            new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list)
            new_shape = torch.Size(new_shape)
            output_tensor_list = [
                torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
            ]
            dim = comm_spec.shard_dim
            length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
            input_tensor_list = [
                torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list))
            ]
            group = process_group
            dist.all_to_all(output_tensor_list, input_tensor_list, group)
            output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
            return output
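# Worked example (hypothetical values, added for illustration): for
# rank_list = [0, 1], shard_dim = 1, gather_dim = 0, each rank splits its
# (2, 2) tensor into two (2, 1) columns, receives from each peer the column
# matching its own position in rank_list via dist.all_to_all, and concatenates
# the received pieces along dim 0 into a (4, 1) result, as exercised by
# check_all_to_all in the test file below.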
def _all_reduce(tensor: torch.Tensor, comm_spec: CommSpec, async_op: bool = False):
    '''
    Implement the all-reduce operation on the device mesh based on the information provided by comm_spec.
    '''
    process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
    for rank_list, process_group in process_groups_list:
        if dist.get_rank() in rank_list:
            if not tensor.is_contiguous():
                tensor = tensor.contiguous()
            dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
            return tensor


class _ReduceGrad(torch.autograd.Function):
    """
    A customized communication operation whose forward is an identity operation and
    whose backward is an all_reduce operation.

    Args:
        input_: input matrix.
        comm_spec: comm_spec will give information like process group, rank list, etc.
    """

    @staticmethod
    def symbolic(graph, input_):
        return input_

    @staticmethod
    def forward(ctx, input_, comm_spec):
        ctx.comm_spec = comm_spec
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        return _all_reduce(grad_output, ctx.comm_spec), None


class _ReduceInput(torch.autograd.Function):
    """
    A customized communication operation whose forward is an all_reduce operation and
    whose backward is an identity operation.

    Args:
        input_: input matrix.
        comm_spec: comm_spec will give information like process group, rank list, etc.
    """

    @staticmethod
    def symbolic(graph, input_):
        return _all_reduce(input_)

    @staticmethod
    def forward(ctx, input_, comm_spec):
        return _all_reduce(input_, comm_spec)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None


class _SplitForwardGatherBackward(torch.autograd.Function):
    """
    A customized communication operation whose forward is a split operation and
    whose backward is an all-gather operation.

    Args:
        input_: input matrix.
        comm_spec: comm_spec will give information like process group, rank list, etc.
    """

    @staticmethod
    def symbolic(graph, input_):
        return _split(input_)

    @staticmethod
    def forward(ctx, input_, comm_spec):
        ctx.comm_spec = comm_spec
        return _split(input_, comm_spec)

    @staticmethod
    def backward(ctx, grad_output):
        return _all_gather(grad_output, ctx.comm_spec), None


class _GatherForwardSplitBackward(torch.autograd.Function):
    """
    A customized communication operation whose forward is an all-gather operation and
    whose backward is a split operation.

    Args:
        input_: input matrix.
        comm_spec: comm_spec will give information like process group, rank list, etc.
    """

    @staticmethod
    def symbolic(graph, input_):
        return _all_gather(input_)

    @staticmethod
    def forward(ctx, input_, comm_spec):
        ctx.comm_spec = comm_spec
        return _all_gather(input_, comm_spec)

    @staticmethod
    def backward(ctx, grad_output):
        return _split(grad_output, ctx.comm_spec), None


class _AllToAll(torch.autograd.Function):
    """
    A customized communication operation whose forward is an all-to-all operation and
    whose backward is an all-to-all operation as well.

    Args:
        input_: input matrix.
        comm_spec: comm_spec will give information like process group, rank list, etc.
    """

    @staticmethod
    def symbolic(graph, input_):
        return _all_to_all(input_)

    @staticmethod
    def forward(ctx, input_, comm_spec):
        output = _all_to_all(input_, comm_spec)
        # the backward all-to-all must undo the forward one, so swap gather_dim and shard_dim
        comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern,
                                          process_groups_dict=comm_spec.process_groups_dict,
                                          gather_dim=comm_spec.shard_dim,
                                          shard_dim=comm_spec.gather_dim,
                                          logical_process_axis=comm_spec.logical_process_axis)
        ctx.comm_spec = comm_spec_for_backward
        return output

    @staticmethod
    def backward(ctx, grad_outputs):
        return _all_to_all(grad_outputs, ctx.comm_spec), None


def reduce_grad(input_, comm_spec):
    return _ReduceGrad.apply(input_, comm_spec)


def reduce_input(input_, comm_spec):
    return _ReduceInput.apply(input_, comm_spec)


def split_forward_gather_backward(input_, comm_spec):
    return _SplitForwardGatherBackward.apply(input_, comm_spec)


def gather_forward_split_backward(input_, comm_spec):
    return _GatherForwardSplitBackward.apply(input_, comm_spec)


def all_to_all(input_, comm_spec):
    return _AllToAll.apply(input_, comm_spec)


# maps each CollectiveCommPattern to the autograd-aware function that realizes it
pattern_to_func_dict = {
    CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: gather_forward_split_backward,
    CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: all_to_all,
    CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: split_forward_gather_backward,
    CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: reduce_input,
    CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: reduce_grad,
}
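To make the dispatch path concrete, here is a minimal usage sketch. The shape of process_groups_dict ({logical_process_axis: [(rank_list, ProcessGroup), ...]}) is inferred from how the helpers above index it; the single-axis dict below is a hypothetical stand-in, and the snippet assumes torch.distributed is already initialized on CUDA devices (as the test file further down sets up) and that the definitions above are in scope.

    import torch
    import torch.distributed as dist

    # hypothetical single-axis layout: one group covering all ranks
    world_ranks = list(range(dist.get_world_size()))
    process_groups_dict = {0: [(world_ranks, dist.new_group(world_ranks))]}

    # gather local (2, 2) shards along dim 1; backward will split the gradient
    spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
                    process_groups_dict,
                    gather_dim=1,
                    logical_process_axis=0)
    local_shard = torch.ones(2, 2, device='cuda', requires_grad=True)
    gathered = spec.covert_spec_to_action(local_shard)
    print(spec)    # CommSpec:(comm_pattern:GATHER_FWD_SPLIT_BWD, ...)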
colossalai/tensor/d_tensor/sharding_spec.py
...
@@ -171,7 +171,7 @@ class ShardingSpec:
             raise ShardingOutOfIndexError(
                 f'sharding_sequence should have {self.dims} elements, but got index {len(self.sharding_sequence)}.')

-        if max(list(self.dim_partition_dict.keys())) >= self.dims:
+        if list(self.dim_partition_dict.keys()) and max(list(self.dim_partition_dict.keys())) >= self.dims:
             raise ShardingOutOfIndexError(
                 f'the key of dim_partition_dict should be less than {self.dims}, but got {max(list(self.dim_partition_dict.keys()))}.')
...
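The one-line guard above exists because Python's max() raises a ValueError on an empty sequence, so a fully replicated ShardingSpec (an empty dim_partition_dict) would previously crash this sanity check. A minimal sketch of the failure mode, with a hypothetical dims value of 4:

    dim_partition_dict = {}    # fully replicated: no dimension is sharded
    keys = list(dim_partition_dict.keys())
    # max(keys) would raise: ValueError: max() arg is an empty sequence
    if keys and max(keys) >= 4:    # the added `keys and` short-circuits safely
        raise IndexError('dim out of range')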
tests/test_tensor/test_dtensor/test_comm_spec.py (new file, mode 100644)
from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed import ReduceOp

from colossalai.core import global_context as gpc
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern, CommSpec
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port


def check_all_gather(process_groups_dict, rank):
    # tensor to comm
    if rank in (0, 2):
        sharded_tensor_to_comm = torch.ones(2, 2).cuda()
    else:
        sharded_tensor_to_comm = torch.zeros(2, 2).cuda()

    # tensor to check
    tensor_to_check = torch.cat((torch.ones(2, 2), torch.zeros(2, 2)), 1).cuda()

    # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1)
    comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
                         process_groups_dict,
                         gather_dim=1,
                         logical_process_axis=1)
    sharded_tensor_to_comm = comm_spec.covert_spec_to_action(sharded_tensor_to_comm)

    assert sharded_tensor_to_comm.equal(tensor_to_check)


def check_shard(process_groups_dict, rank):
    # tensor to comm
    sharded_tensor_to_comm_0 = torch.zeros(2, 2).cuda()
    sharded_tensor_to_comm_1 = torch.ones(2, 2).cuda()
    # tensor([[0., 0., 1., 1.],
    #         [0., 0., 1., 1.]])
    tensor_to_shard = torch.cat((sharded_tensor_to_comm_0, sharded_tensor_to_comm_1), 1)

    # CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1)
    comm_spec = CommSpec(CollectiveCommPattern.SPLIT_FWD_GATHER_BWD,
                         process_groups_dict,
                         shard_dim=1,
                         logical_process_axis=1)
    tensor_to_shard = comm_spec.covert_spec_to_action(tensor_to_shard)

    if rank in (0, 2):
        assert tensor_to_shard.equal(sharded_tensor_to_comm_0)
    if rank in (1, 3):
        assert tensor_to_shard.equal(sharded_tensor_to_comm_1)


def check_all_to_all(process_groups_dict, rank):
    # tensor to comm
    if rank in (0, 1):
        sharded_tensor_0 = torch.zeros(2, 1)
        sharded_tensor_1 = torch.ones(2, 1)
        # tensor([[0., 1.],
        #         [0., 1.]])
        tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()
    if rank in (2, 3):
        sharded_tensor_0 = torch.ones(2, 1) * 2
        sharded_tensor_1 = torch.ones(2, 1) * 3
        # tensor([[2., 3.],
        #         [2., 3.]])
        tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()

    if rank in (0, 1):
        # tensor([[0.],
        #         [0.],
        #         [2.],
        #         [2.]])
        tensor_to_check = torch.tensor([[0], [0], [2], [2]], dtype=tensor_to_comm.dtype).cuda()
    if rank in (2, 3):
        # tensor([[1.],
        #         [1.],
        #         [3.],
        #         [3.]])
        tensor_to_check = torch.tensor([[1], [1], [3], [3]], dtype=tensor_to_comm.dtype).cuda()

    # CommSpec:(comm_pattern:all2all, gather_dim:0, shard_dim:1, logical_process_axis:0)
    comm_spec = CommSpec(CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD,
                         process_groups_dict,
                         gather_dim=0,
                         shard_dim=1,
                         logical_process_axis=0)
    tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)

    assert tensor_to_comm.equal(tensor_to_check)


def check_all_reduce_fwd(process_groups_dict, rank):
    # tensor to comm
    tensor_to_comm = torch.ones(2, 2).cuda() * rank

    # reduce through logical process axis 0
    # tensor to check
    if rank in (0, 2):
        # tensor([[2., 2.],
        #         [2., 2.]])
        tensor_to_check = torch.tensor([[2, 2], [2, 2]], dtype=tensor_to_comm.dtype).cuda()
    if rank in (1, 3):
        # tensor([[4., 4.],
        #         [4., 4.]])
        tensor_to_check = torch.tensor([[4, 4], [4, 4]], dtype=tensor_to_comm.dtype).cuda()

    comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, process_groups_dict, logical_process_axis=0)
    tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)

    assert tensor_to_comm.equal(tensor_to_check)


def check_all_reduce_bwd(process_groups_dict, rank):
    # tensor to comm
    tensor_to_comm = torch.ones(2, 2).cuda() * rank

    # forward of IDENTITY_FWD_ALLREDUCE_BWD is an identity op, so nothing changes
    tensor_to_check = torch.ones(2, 2).cuda() * rank

    comm_spec = CommSpec(CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, process_groups_dict, logical_process_axis=0)
    tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)

    assert tensor_to_comm.equal(tensor_to_check)


def check_all_reduce_in_flatten_device_mesh(process_groups_dict, rank):
    # tensor to comm
    tensor_to_comm = torch.ones(2, 2).cuda() * rank

    # reduce through logical process axis 0 of the flattened device mesh
    # tensor to check: 0 + 1 + 2 + 3 = 6
    # tensor([[6., 6.],
    #         [6., 6.]])
    tensor_to_check = torch.tensor([[6, 6], [6, 6]], dtype=tensor_to_comm.dtype).cuda()

    # CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1])
    comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, process_groups_dict, logical_process_axis=0)
    tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)

    assert tensor_to_comm.equal(tensor_to_check)


def check_comm(rank, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    physical_mesh_id = torch.arange(0, 4)
    assert rank == gpc.get_global_rank()

    mesh_shape = (2, 2)
    # [[0, 1],
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    process_groups_dict = device_mesh.process_groups_dict

    # test all gather
    check_all_gather(process_groups_dict, rank)

    # test shard
    check_shard(process_groups_dict, rank)

    # test all to all
    check_all_to_all(process_groups_dict, rank)

    # test all reduce
    check_all_reduce_fwd(process_groups_dict, rank)
    check_all_reduce_bwd(process_groups_dict, rank)

    flatten_process_groups_dict = device_mesh.flatten_device_mesh.process_groups_dict
    # test all reduce in 1D flatten device mesh
    check_all_reduce_in_flatten_device_mesh(flatten_process_groups_dict, rank)

    gpc.destroy()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_comm_spec():
    world_size = 4
    run_func = partial(check_comm, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_comm_spec()
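For orientation, the checks above index process_groups_dict by mesh axis. The dict below is a hypothetical illustration of the layout device_mesh.process_groups_dict could have for the 2x2 mesh [[0, 1], [2, 3]] used here, inferred from how comm_spec.py consumes it; the real values are live torch.distributed ProcessGroup handles, shown as string placeholders:

    # hypothetical illustration only; real values are ProcessGroup objects
    process_groups_dict = {
        0: [([0, 2], 'pg_col_0'), ([1, 3], 'pg_col_1')],    # axis 0: column groups
        1: [([0, 1], 'pg_row_0'), ([2, 3], 'pg_row_1')],    # axis 1: row groups
    }
    # check_all_gather passes logical_process_axis=1, so ranks 0/1 and 2/3 gather
    # within their rows, which is why ranks 0 and 2 contribute the ones tensor.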