OpenDAS / ColossalAI · Commits

Unverified commit 536560cc, authored Dec 14, 2022 by YuliangLiu0306, committed by GitHub on Dec 14, 2022

[autoparallel] implement softmax handler (#2132)

Parent: c89c66a8

Showing 6 changed files with 349 additions and 4 deletions (+349, -4):
colossalai/auto_parallel/tensor_shard/node_handler/__init__.py (+2, -1)
colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py (+55, -0)
colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py (+2, -1)
colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py (+104, -0)
colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py (+0, -2)
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py (+186, -0)
colossalai/auto_parallel/tensor_shard/node_handler/__init__.py

@@ -15,6 +15,7 @@ from .output_handler import OuputHandler
 from .placeholder_handler import PlacehodlerHandler
 from .registry import operator_registry
 from .reshape_handler import ReshapeHandler
+from .softmax_handler import SoftmaxHandler
 from .sum_handler import SumHandler
 from .tensor_constructor_handler import TensorConstructorHandler
 from .unary_elementwise_handler import UnaryElementwiseHandler
@@ -26,5 +27,5 @@ __all__ = [
     'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
     'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
     'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler',
-    'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler'
+    'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler'
 ]
colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py (new file, mode 100644)

from typing import Dict, List

import torch

from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import SoftmaxGenerator, StrategyGenerator

__all__ = ['SoftmaxHandler']


@operator_registry.register(torch.nn.Softmax)
@operator_registry.register(torch.nn.functional.softmax)
class SoftmaxHandler(NodeHandler):
    """
    A SoftmaxHandler which deals with the sharding strategies for
    torch.nn.Softmax or torch.nn.functional.softmax.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(SoftmaxGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # check if the input operand is a parameter
        if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
            data_type = OperationDataType.PARAM
        else:
            data_type = OperationDataType.ARG

        input_data = self.node.args[0]._meta_data
        physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)

        softmax_dim = self.node.kwargs['dim']
        num_dims = self.node.args[0]._meta_data.dim()
        # recover negative value to positive
        if softmax_dim < 0:
            softmax_dim += num_dims

        physical_dim_operand = OperationData(name='softmax_dim', type=OperationDataType.ARG, data=softmax_dim)

        output_data = self.node._meta_data
        physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)

        mapping = {
            "input": physical_input_operand,
            "softmax_dim": physical_dim_operand,
            "output": physical_output_operand
        }

        return mapping
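The negative-dimension recovery in get_operation_data_mapping follows PyTorch's indexing convention (dim=-1 means the last dimension), so the generator can later compare the softmax dim against the non-negative keys of a partition dict. A minimal standalone sketch of that step (the helper name is hypothetical, not part of the commit):

import torch

def normalize_softmax_dim(dim: int, num_dims: int) -> int:
    # Mirror the handler's "recover negative value to positive" step.
    return dim + num_dims if dim < 0 else dim

x = torch.rand(8, 16, 64, 32)
assert normalize_softmax_dim(-1, x.dim()) == 3
assert normalize_softmax_dim(2, x.dim()) == 2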
colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py

@@ -15,6 +15,7 @@ from .normal_pooling_generator import NormalPoolStrategyGenerator
 from .output_generator import OutputGenerator
 from .placeholder_generator import PlaceholderGenerator
 from .reshape_generator import ReshapeGenerator
+from .softmax_generator import SoftmaxGenerator
 from .strategy_generator import StrategyGenerator
 from .sum_generator import SumGenerator
 from .tensor_constructor_generator import TensorConstructorGenerator
@@ -27,5 +28,5 @@ __all__ = [
     'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator',
     'LayerNormGenerator', 'ReshapeGenerator', 'PlaceholderGenerator', 'OutputGenerator', 'WhereGenerator',
     'ReshapeGenerator', 'NormalPoolStrategyGenerator', 'BinaryElementwiseStrategyGenerator', 'GetattrGenerator',
-    'TensorConstructorGenerator', 'EmbeddingStrategyGenerator', 'SumGenerator'
+    'TensorConstructorGenerator', 'EmbeddingStrategyGenerator', 'SumGenerator', 'SoftmaxGenerator'
 ]
colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py (new file, mode 100644)

import copy
import operator
from functools import reduce
from typing import List

from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    CommAction,
    CommType,
    MemoryCost,
    ShardingStrategy,
    TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import (
    check_keep_sharding_status,
    detect_reshape_mapping,
    infer_output_dim_partition_dict,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern

__all__ = ['SoftmaxGenerator']


class SoftmaxGenerator(FollowingStrategyGenerator):
    """
    SoftmaxGenerator is used to generate strategies for torch.nn.Softmax or F.softmax.
    """

    def validate(self) -> bool:
        return super().validate()

    def update_compute_cost(self, strategy: ShardingStrategy):
        '''
        Compute the computation cost per device with this specific strategy.
        '''
        sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
        sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
        input_size_product = reduce(operator.mul, sharded_input_shape)
        output_size_product = reduce(operator.mul, sharded_output_shape)
        forward_compute_cost = output_size_product * 2
        backward_compute_cost = input_size_product
        total_compute_cost = forward_compute_cost + backward_compute_cost
        compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
        strategy.compute_cost = compute_cost

    def update_memory_cost(self, strategy: ShardingStrategy):
        '''
        Compute the memory cost per device with this specific strategy.
        '''
        forward_size_mapping = {
            'input': self._compute_size_in_bytes(strategy, "input"),
            'output': self._compute_size_in_bytes(strategy, "output")
        }

        backward_size_mapping = copy.deepcopy(forward_size_mapping)
        backward_size_mapping.pop("output")
        # compute fwd cost incurred
        # fwd_cost = input + output
        fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
        fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
        fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)

        # compute bwd cost incurred
        # bwd_cost = input_grad
        bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
        bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
        bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)

        # compute total cost
        total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
                                    parameter=fwd_parameter_cost + bwd_parameter_cost)
        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
        strategy.memory_cost = memory_cost

    def collate_strategies(self) -> List[ShardingStrategy]:
        strategy_list = []
        for index, strategy in enumerate(self.predecessor_node.strategies_vector):
            dim_partition_dict_mapping = {}
            communication_action_mapping = {}
            input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
            dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)

            softmax_dim = self.op_data['softmax_dim'].data
            if softmax_dim in dim_partition_dict_for_input:
                recover_dims = dim_partition_dict_for_input.pop(softmax_dim)

            dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input)
            dim_partition_dict_mapping = {
                "input": dim_partition_dict_for_input,
                "output": dim_partition_dict_for_output,
            }
            sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
            # add index into name to pass the duplicated check
            # we keep same strategies with different name for node merging, and it will not increase the searching space,
            # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
            name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
            strategy = self.get_sharding_strategy(name=name,
                                                  sharding_spec_mapping=sharding_spec_mapping,
                                                  communication_action_mapping=communication_action_mapping)
            strategy_list.append(strategy)

        return strategy_list
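The key move in collate_strategies is inheriting the predecessor's dim_partition_dict and dropping any sharding on the softmax dimension: softmax normalizes across that dimension, so every device needs its full extent. A minimal sketch of that transformation under the same dim-to-mesh-axes convention (hypothetical standalone helper, not part of the commit):

import copy

def drop_softmax_dim(dim_partition_dict: dict, softmax_dim: int) -> dict:
    # dim_partition_dict maps tensor dims to mesh axes, e.g. {0: [0], 3: [1]}
    # means dim 0 is sharded over mesh axis 0 and dim 3 over mesh axis 1.
    out = copy.deepcopy(dim_partition_dict)
    # The softmax dim must end up replicated, so any sharding on it is removed.
    out.pop(softmax_dim, None)
    return out

# With softmax over dim 3, an input sharded as [S0, R, R, S1] becomes [S0, R, R, R].
assert drop_softmax_dim({0: [0], 3: [1]}, softmax_dim=3) == {0: [0]}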
colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py

@@ -16,8 +16,6 @@ __all__ = ['UnaryElementwiseHandler']
 @operator_registry.register(torch.nn.ReLU)
 @operator_registry.register(torch.nn.Tanh)
 @operator_registry.register(torch.tanh)
-# TODO: softmax need to be relocated
-@operator_registry.register(torch.nn.functional.softmax)
 @operator_registry.register(torch.nn.modules.dropout.Dropout)
 @operator_registry.register(torch.Tensor.contiguous)
 @operator_registry.register(torch.nn.functional.dropout)
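Since the dedicated SoftmaxHandler above now registers torch.nn.functional.softmax, the temporary registration on the generic unary elementwise handler (flagged with a TODO in the original code) is removed here. For intuition, a decorator-based operator registry of this kind typically amounts to a dict from op to handler class; a minimal sketch (illustrative only, not the actual colossalai registry implementation):

class Registry:

    def __init__(self):
        self._store = {}

    def register(self, op):
        # Decorator factory: maps `op` to the decorated handler class.
        def wrapper(handler_cls):
            self._store[op] = handler_cls
            return handler_cls
        return wrapper

    def get(self, op):
        return self._store[op]

# usage sketch:
#   registry = Registry()
#   @registry.register(some_op)
#   class SomeHandler: ...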
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py (new file, mode 100644)

from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F

from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.softmax_handler import SoftmaxHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy


class LinearSplitModel(nn.Module):

    def __init__(self, softmax_dim):
        super().__init__()
        self.softmax_dim = softmax_dim

    def forward(self, input, other):
        linear_node = F.linear(input, other, bias=None)
        softmax_node = F.softmax(linear_node, self.softmax_dim)
        return softmax_node


def check_split_handler(rank, softmax_dim, model_cls, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = model_cls(softmax_dim=softmax_dim).cuda()

    input = torch.rand(8, 16, 64, 32).to('cuda')
    other = torch.rand(64, 32).to('cuda')
    # index of linear node in computation graph
    node_index = 2
    # total number of linear strategies
    strategy_number = 23
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    numerical_test_for_node_strategy(model=model,
                                     device_mesh=device_mesh,
                                     node_index=node_index,
                                     strategy_number=strategy_number,
                                     input_args=[input, other],
                                     meta_arg_names=['input', 'other'],
                                     node_type='following')

    tracer = ColoTracer()
    # graph():
    #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
    #     %other : torch.Tensor [#users=1] = placeholder[target=other]
    #     %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None})
    #     %softmax : [#users=1] = call_function[target=torch.nn.functional.softmax](args = (%linear, ...), kwargs = {})
    #     return softmax
    graph = tracer.trace(model,
                         meta_args={
                             "input": torch.rand(8, 16, 64, 32).to('meta'),
                             "other": torch.rand(64, 32).to('meta'),
                         })
    gm = ColoGraphModule(model, graph)

    previous_mod_node = list(graph.nodes)[2]
    split_node = list(graph.nodes)[3]
    split_strategies_vector = StrategiesVector(split_node)
    previous_strategies_vector = StrategiesVector(previous_mod_node)

    # build handler
    assert len(previous_strategies_vector) == 0
    linear_handler = LinearFunctionHandler(node=previous_mod_node,
                                           device_mesh=device_mesh,
                                           strategies_vector=previous_strategies_vector)
    linear_handler.register_strategy(compute_resharding_cost=False)
    setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector)

    softmax_handler = SoftmaxHandler(node=split_node,
                                     device_mesh=device_mesh,
                                     strategies_vector=split_strategies_vector)
    softmax_handler.register_strategy(compute_resharding_cost=False)

    # check operation data mapping
    mapping = softmax_handler.get_operation_data_mapping()

    for name, op_data in mapping.items():
        op_data: OperationData
        # make sure they have valid values
        assert op_data.data is not None

    assert mapping['input'].name == "linear"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64])

    assert mapping['softmax_dim'].name == "softmax_dim"
    assert mapping['softmax_dim'].data == softmax_dim
    assert mapping['softmax_dim'].type == OperationDataType.ARG

    assert mapping['output'].name == "softmax"
    assert mapping['output'].data.shape == torch.Size([8, 16, 64, 64])
    assert mapping['output'].logical_shape == torch.Size([8, 16, 64, 64])
    assert mapping['output'].type == OperationDataType.OUTPUT

    # the softmax handler is a following strategy handler, so the number of strategies is equal to the predecessor node.
    assert len(split_strategies_vector) == len(previous_strategies_vector)
    strategy_name_list = [strategy.name for strategy in split_strategies_vector]

    if softmax_dim == 0:
        assert '[R, R, R, S1] -> [R, R, R, S1]_0' in strategy_name_list
        assert '[R, S0, R, S1] -> [R, S0, R, S1]_1' in strategy_name_list
        assert '[R, R, S0, S1] -> [R, R, S0, S1]_2' in strategy_name_list
        assert '[R, R, R, S0] -> [R, R, R, S0]_3' in strategy_name_list
        assert '[R, S1, R, S0] -> [R, S1, R, S0]_4' in strategy_name_list
        assert '[R, R, S1, S0] -> [R, R, S1, S0]_5' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R, R]_6' in strategy_name_list
        assert '[R, S0, R, R] -> [R, S0, R, R]_7' in strategy_name_list
        assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list
        assert '[R, S1, R, R] -> [R, S1, R, R]_10' in strategy_name_list
        assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list
        assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list
        assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list
        assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list
        assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R, R]_18' in strategy_name_list
        assert '[R, S01, R, R] -> [R, S01, R, R]_19' in strategy_name_list
        assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list
        assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list

    if softmax_dim == 1:
        assert '[S0, R, R, S1] -> [S0, R, R, S1]_0' in strategy_name_list
        assert '[R, R, R, S1] -> [R, R, R, S1]_1' in strategy_name_list
        assert '[R, R, S0, S1] -> [R, R, S0, S1]_2' in strategy_name_list
        assert '[S1, R, R, S0] -> [S1, R, R, S0]_3' in strategy_name_list
        assert '[R, R, R, S0] -> [R, R, R, S0]_4' in strategy_name_list
        assert '[R, R, S1, S0] -> [R, R, S1, S0]_5' in strategy_name_list
        assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list
        assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list
        assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R, R]_10' in strategy_name_list
        assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list
        assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list
        assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list
        assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list
        assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list
        assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R, R]_19' in strategy_name_list
        assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list
        assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list


@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
@parameterize('softmax_dim', [0, 1, 2, 3])
@parameterize('model_cls', [LinearSplitModel])
def test_split_handler(softmax_dim, model_cls):
    world_size = 4
    run_func = partial(check_split_handler,
                       softmax_dim=softmax_dim,
                       model_cls=model_cls,
                       world_size=world_size,
                       port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_split_handler()
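The strategy names checked above encode one symbol per tensor dimension: R for replicated, S0/S1 for sharded along one axis of the (2, 2) device mesh, and S01 for sharded along both axes. A minimal sketch of how such a sequence can be rendered from a dim_partition_dict (hypothetical helper for illustration, not the library's formatter):

def sharding_sequence(num_dims: int, dim_partition_dict: dict) -> str:
    # One symbol per tensor dim: 'R' if unsharded, else 'S' plus the mesh axes.
    symbols = []
    for d in range(num_dims):
        axes = dim_partition_dict.get(d)
        symbols.append('S' + ''.join(str(a) for a in axes) if axes else 'R')
    return '[' + ', '.join(symbols) + ']'

assert sharding_sequence(4, {1: [0, 1]}) == '[R, S01, R, R]'
assert sharding_sequence(4, {}) == '[R, R, R, R]'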