OpenDAS / ColossalAI / Commits / 81330b03

Commit 81330b03 (unverified), authored Nov 27, 2022 by YuliangLiu0306, committed by GitHub on Nov 27, 2022.

    [autoparallel] add experimental permute handler (#2029)

Parent: 95c4532f

Showing 11 changed files with 657 additions and 37 deletions (+657, -37).
colossalai/auto_parallel/passes/runtime_preparation_pass.py (+69, -25)
colossalai/auto_parallel/tensor_shard/node_handler/__init__.py (+2, -1)
colossalai/auto_parallel/tensor_shard/node_handler/experimental/__init__.py (+6, -2)
colossalai/auto_parallel/tensor_shard/node_handler/experimental/permute_handler.py (+76, -0)
colossalai/auto_parallel/tensor_shard/node_handler/experimental/reshape_generator.py (+94, -3)
colossalai/auto_parallel/tensor_shard/node_handler/experimental/transpose_handler.py (+65, -0)
colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_handler.py (+3, -1)
colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py (+0, -4)
colossalai/tensor/comm_spec.py (+2, -0)
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py (+339, -0)
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py (+1, -1)
colossalai/auto_parallel/passes/runtime_preparation_pass.py

@@ -37,30 +37,6 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
         origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(str(node))

-    # experimental pass for torch.Tensor.view
-    # Arguments of view op will be divided in the sharded dimensions.
-    for node in nodes:
-        if node.op == 'call_method' and getattr(node.args[0]._meta_data.__class__, node.target) in (torch.Tensor.view,):
-            output_dim_partition_dict = node.sharding_spec.dim_partition_dict
-            device_mesh = node.sharding_spec.device_mesh
-            new_args = []
-            for arg in node.args:
-                if isinstance(arg, Node):
-                    if isinstance(arg._meta_data, int):
-                        new_args.append(arg._meta_data)
-                    else:
-                        new_args.append(arg)
-                else:
-                    assert isinstance(arg, int), 'The argument in view node should be either type of Node or int.'
-                    new_args.append(arg)
-            for dim, shard_dims in output_dim_partition_dict.items():
-                total_shard_size = 1
-                for shard_dim in shard_dims:
-                    total_shard_size *= device_mesh.shape[shard_dim]
-                new_args[dim + 1] //= total_shard_size
-            node.args = tuple(new_args)
-
     # the dict to get input sharding specs of user node
     sharding_spec_convert_dict = {}
     # the dict to record comm actions of nodes

@@ -113,7 +89,74 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
     return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict


-def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh):
+def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
+    """
+    This pass will process node args to adapt the distributed tensor layout.
+    """
+    mod_graph = gm.graph
+    nodes = tuple(mod_graph.nodes)
+
+    for node in nodes:
+        # skip the placeholder node added in _solution_annotation pass
+        if not hasattr(node, 'sharding_spec'):
+            continue
+        output_dim_partition_dict = node.sharding_spec.dim_partition_dict
+        device_mesh = node.sharding_spec.device_mesh
+        new_args = []
+
+        if node.op == 'call_method':
+            method = getattr(node.args[0]._meta_data.__class__, node.target)
+            # process the node with (input, *shape) style args
+            if method in (torch.Tensor.view, torch.Tensor.reshape):
+                for arg in node.args:
+                    if isinstance(arg, Node):
+                        if isinstance(arg._meta_data, int):
+                            new_args.append(arg._meta_data)
+                        else:
+                            new_args.append(arg)
+                    else:
+                        assert isinstance(arg, int), 'The argument in view node should be either type of Node or int.'
+                        new_args.append(arg)
+
+                for dim, shard_dims in output_dim_partition_dict.items():
+                    # we will skip the dim with -1 value
+                    if new_args[dim + 1] == -1:
+                        continue
+                    total_shard_size = 1
+                    for shard_dim in shard_dims:
+                        total_shard_size *= device_mesh.shape[shard_dim]
+                    new_args[dim + 1] //= total_shard_size
+                node.args = tuple(new_args)
+
+        elif node.op == 'call_function':
+            target = node.target
+            # process the node with (input, torch.Size) style args
+            if target in (torch.reshape,):
+                for arg in node.args:
+                    if isinstance(arg, Node):
+                        if isinstance(arg._meta_data, (tuple, list)):
+                            new_args.append(list(arg._meta_data))
+                        else:
+                            new_args.append(arg)
+                    else:
+                        assert isinstance(arg, (tuple, list)), 'The argument in reshape node should be either type of Node or tuple.'
+                        new_args.append(list(arg))
+
+                for dim, shard_dims in output_dim_partition_dict.items():
+                    # we will skip the dim with -1 value
+                    if new_args[1][dim] == -1:
+                        continue
+                    total_shard_size = 1
+                    for shard_dim in shard_dims:
+                        total_shard_size *= device_mesh.shape[shard_dim]
+                    new_args[1][dim] //= total_shard_size
+                node.args = tuple(new_args)
+
+    return gm
+
+
+def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
     """
     Apply the sharding action to the module parameters and buffers following the
     instructions of solver solution.

@@ -216,6 +259,7 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
 def runtime_preparation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh: DeviceMesh):
     gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation(
         gm, solution)
+    gm = _node_args_converting(gm, device_mesh)
    # TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed.
     # gm = implicit_comm_action_apply(gm)
     gm = _module_params_sharding(gm, device_mesh)
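The arg-division rule that `_node_args_converting` applies to `(input, *shape)` style nodes is easy to check in isolation. Below is a minimal standalone sketch of that rule; `divide_view_args` is an illustrative helper, not part of this commit, and the (2, 2) mesh is assumed for the example:

def divide_view_args(args, dim_partition_dict, mesh_shape):
    # args = (input_node, *target_shape); shape entries start at index 1
    new_args = list(args)
    for dim, shard_dims in dim_partition_dict.items():
        if new_args[dim + 1] == -1:    # -1 is inferred by view, so it is left alone
            continue
        total_shard_size = 1
        for shard_dim in shard_dims:
            total_shard_size *= mesh_shape[shard_dim]
        new_args[dim + 1] //= total_shard_size
    return tuple(new_args)

# x.view(32, 4, 64) with output dim 0 sharded over mesh axis 0 of a (2, 2) mesh
# becomes x.view(16, 4, 64) on each device
assert divide_view_args(('x', 32, 4, 64), {0: [0]}, (2, 2)) == ('x', 16, 4, 64)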
colossalai/auto_parallel/tensor_shard/node_handler/__init__.py

@@ -3,6 +3,7 @@ from .batch_norm_handler import BatchNormModuleHandler
 from .binary_elementwise_handler import BinaryElementwiseHandler
 from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
 from .conv_handler import ConvFunctionHandler, ConvModuleHandler
+from .experimental import PermuteHandler, ViewHandler
 from .getatrr_handler import GetattrHandler
 from .getitem_handler import GetItemHandler
 from .layer_norm_handler import LayerNormModuleHandler

@@ -21,5 +22,5 @@ __all__ = [
     'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
     'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
     'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
-    'GetItemHandler', 'GetattrHandler'
+    'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler'
 ]
colossalai/auto_parallel/tensor_shard/node_handler/experimental/__init__.py

-from .view_generator import ViewGenerator
+from .permute_handler import PermuteHandler
+from .reshape_generator import PermuteGenerator, TransposeGenerator, ViewGenerator
+from .transpose_handler import TransposeHandler
 from .view_handler import ViewHandler

-__all__ = ['ViewGenerator', 'ViewHandler']
+__all__ = ['ViewGenerator', 'ViewHandler', 'PermuteGenerator', 'PermuteHandler', 'TransposeGenerator', 'TransposeHandler']
colossalai/auto_parallel/tensor_shard/node_handler/experimental/permute_handler.py (new file, mode 100644)

from typing import Dict, List

import torch

from ...sharding_strategy import OperationData, OperationDataType
from ..node_handler import NodeHandler
from ..registry import operator_registry
from ..strategy import StrategyGenerator
from .reshape_generator import PermuteGenerator

__all__ = ['PermuteHandler']


@operator_registry.register(torch.Tensor.permute)
@operator_registry.register(torch.permute)
class PermuteHandler(NodeHandler):
    """
    A PermuteHandler which deals with the sharding strategies for torch.permute and torch.Tensor.permute.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(PermuteGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # check if the input operand is a parameter
        if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
            data_type = OperationDataType.PARAM
        else:
            data_type = OperationDataType.ARG

        input_data = self.node.args[0]._meta_data
        physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)

        permute_dims = []
        if self.node.op == 'call_method':
            # torch.Tensor.permute (input, *dims)
            for arg in self.node.args:
                if isinstance(arg, torch.fx.Node):
                    if isinstance(arg._meta_data, int):
                        permute_dims.append(arg._meta_data)
                else:
                    assert isinstance(arg, int), 'The argument in permute node should be either type of Node or int.'
                    permute_dims.append(arg)
        else:
            # torch.permute (input, dims)
            for arg in self.node.args:
                if isinstance(arg, torch.fx.Node):
                    if isinstance(arg._meta_data, (tuple, list)):
                        permute_dims.extend(arg._meta_data)
                else:
                    assert isinstance(arg, (tuple, list)), 'The argument in permute node should be type of Node, Tuple[int] or List[int].'
                    permute_dims.extend(arg)

        num_dims = self.node._meta_data.dim()
        for i in range(num_dims):
            # recover negative value to positive
            if permute_dims[i] < 0:
                permute_dims[i] += num_dims

        physical_shape_operand = OperationData(name='permute_dims', type=OperationDataType.ARG, data=list(permute_dims))

        output_data = self.node._meta_data
        physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)

        mapping = {
            "input": physical_input_operand,
            "permute_dims": physical_shape_operand,
            "output": physical_output_operand
        }

        return mapping
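The negative-dim normalization above is the one step both new handlers perform before handing dims to their generators, so a standalone illustration may help. A minimal sketch; `normalize_dims` is a hypothetical helper, not part of the handler:

def normalize_dims(dims, num_dims):
    # recover negative dim values to their positive equivalents
    return [d + num_dims if d < 0 else d for d in dims]

# for a 4-D tensor, permute(0, -2, 1, -1) means permute(0, 2, 1, 3)
assert normalize_dims([0, -2, 1, -1], 4) == [0, 2, 1, 3]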
colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_generator.py → colossalai/auto_parallel/tensor_shard/node_handler/experimental/reshape_generator.py (renamed)
@@ -17,12 +17,12 @@ from colossalai.auto_parallel.tensor_shard.utils import (
 from colossalai.tensor.shape_consistency import CollectiveCommPattern
 from colossalai.tensor.sharding_spec import ShardingSpec

-__all__ = ['ViewGenerator']
+__all__ = ['ReshapeGenerator', 'ViewGenerator', 'PermuteGenerator', 'TransposeGenerator']


-class ViewGenerator(FollowingStrategyGenerator):
+class ReshapeGenerator(FollowingStrategyGenerator):
     """
-    ViewGenerator which deals with the sharding strategies of view op.
+    ReshapeGenerator is the base class for all reshape operations.
     """

     def validate(self) -> bool:

@@ -61,6 +61,15 @@ class ViewGenerator(FollowingStrategyGenerator):
         memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
         strategy.memory_cost = memory_cost

+    def collate_strategies(self) -> List[ShardingStrategy]:
+        return super().collate_strategies()
+
+
+class ViewGenerator(ReshapeGenerator):
+    """
+    ViewGenerator deals with the sharding strategies of view op.
+    """
+
     def collate_strategies(self) -> List[ShardingStrategy]:
         strategy_list = []
         for index, strategy in enumerate(self.predecessor_node.strategies_vector):

@@ -136,3 +145,85 @@ class ViewGenerator(FollowingStrategyGenerator):
             strategy_list.append(strategy)

         return strategy_list
+
+
+class PermuteGenerator(ReshapeGenerator):
+    """
+    PermuteGenerator deals with the sharding strategies of permute op.
+    """
+
+    def collate_strategies(self) -> List[ShardingStrategy]:
+        strategy_list = []
+        for index, strategy in enumerate(self.predecessor_node.strategies_vector):
+            dim_partition_dict_mapping = {}
+            communication_action_mapping = {}
+            input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
+
+            permute_dims = self.op_data['permute_dims'].data
+            dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
+            dim_partition_dict_for_output = {}
+            for dim_index, permute_dim in enumerate(permute_dims):
+                if permute_dim in dim_partition_dict_for_input:
+                    dim_partition_dict_for_output[dim_index] = dim_partition_dict_for_input[permute_dim]
+
+            dim_partition_dict_mapping = {
+                "input": dim_partition_dict_for_input,
+                "output": dim_partition_dict_for_output,
+            }
+            sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+
+            # add index into name to pass the duplicated check
+            # we keep same strategies with different name for node merging, and it will not increase the searching space,
+            # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
+            name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
+
+            strategy = self.get_sharding_strategy(name=name,
+                                                  sharding_spec_mapping=sharding_spec_mapping,
+                                                  communication_action_mapping=communication_action_mapping)
+            strategy_list.append(strategy)
+
+        return strategy_list
+
+
+class TransposeGenerator(ReshapeGenerator):
+    """
+    TransposeGenerator deals with the sharding strategies of transpose op.
+    """
+
+    def collate_strategies(self) -> List[ShardingStrategy]:
+        strategy_list = []
+        for index, strategy in enumerate(self.predecessor_node.strategies_vector):
+            dim_partition_dict_mapping = {}
+            communication_action_mapping = {}
+            input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
+            dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
+            dim_partition_dict_for_output = {}
+
+            transpose_dims = self.op_data['transpose_dims'].data
+            dim_0 = transpose_dims[0]
+            dim_1 = transpose_dims[1]
+            for dim, sharded_dims in dim_partition_dict_for_input.items():
+                if dim == dim_0:
+                    dim_partition_dict_for_output[dim_1] = dim_partition_dict_for_input[dim_0]
+                elif dim == dim_1:
+                    dim_partition_dict_for_output[dim_0] = dim_partition_dict_for_input[dim_1]
+                else:
+                    dim_partition_dict_for_output[dim] = sharded_dims
+
+            dim_partition_dict_mapping = {
+                "input": dim_partition_dict_for_input,
+                "output": dim_partition_dict_for_output,
+            }
+            sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+
+            # add index into name to pass the duplicated check
+            # we keep same strategies with different name for node merging, and it will not increase the searching space,
+            # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
+            name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
+
+            strategy = self.get_sharding_strategy(name=name,
+                                                  sharding_spec_mapping=sharding_spec_mapping,
+                                                  communication_action_mapping=communication_action_mapping)
+            strategy_list.append(strategy)
+
+        return strategy_list
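PermuteGenerator's core step is remapping the predecessor's dim_partition_dict through the permutation: output dim i inherits whatever sharding input dim permute_dims[i] carried. A minimal standalone sketch of that remapping, using the same dict-of-mesh-axes layout as ShardingSpec.dim_partition_dict (the helper name is illustrative):

def permute_dim_partition(dim_partition_dict, permute_dims):
    # output dim i inherits the sharding of input dim permute_dims[i]
    return {
        dim_index: dim_partition_dict[permute_dim]
        for dim_index, permute_dim in enumerate(permute_dims)
        if permute_dim in dim_partition_dict
    }

# an input sharded [S0, R, S1, R] permuted with (0, 2, 1, 3) comes out [S0, S1, R, R]
assert permute_dim_partition({0: [0], 2: [1]}, (0, 2, 1, 3)) == {0: [0], 1: [1]}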
colossalai/auto_parallel/tensor_shard/node_handler/experimental/transpose_handler.py (new file, mode 100644)

from typing import Dict, List

import torch

from ...sharding_strategy import OperationData, OperationDataType
from ..node_handler import NodeHandler
from ..registry import operator_registry
from ..strategy import StrategyGenerator
from .reshape_generator import TransposeGenerator

__all__ = ['TransposeHandler']


@operator_registry.register(torch.Tensor.transpose)
@operator_registry.register(torch.transpose)
class TransposeHandler(NodeHandler):
    """
    A TransposeHandler which deals with the sharding strategies for torch.transpose and torch.Tensor.transpose.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(TransposeGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # check if the input operand is a parameter
        if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
            data_type = OperationDataType.PARAM
        else:
            data_type = OperationDataType.ARG

        input_data = self.node.args[0]._meta_data
        physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)

        transpose_dims = []
        # torch.transpose (input, dim0, dim1)
        for arg in self.node.args:
            if isinstance(arg, torch.fx.Node):
                if isinstance(arg._meta_data, int):
                    transpose_dims.append(arg._meta_data)
            else:
                transpose_dims.append(arg)

        num_dims = self.node._meta_data.dim()
        for i in range(2):
            # recover negative value to positive
            if transpose_dims[i] < 0:
                transpose_dims[i] += num_dims

        physical_shape_operand = OperationData(name='transpose_dims', type=OperationDataType.ARG, data=list(transpose_dims))

        output_data = self.node._meta_data
        physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)

        mapping = {
            "input": physical_input_operand,
            "transpose_dims": physical_shape_operand,
            "output": physical_output_operand
        }

        return mapping
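The transpose_dims collected here feed TransposeGenerator, which only swaps the sharding of the two named dims and leaves every other dim untouched. A minimal sketch of that swap in isolation (illustrative helper, same dict layout as above):

def transpose_dim_partition(dim_partition_dict, dim_0, dim_1):
    # sharding moves with the two swapped dims; all other dims keep their sharding
    output_dict = {}
    for dim, sharded_dims in dim_partition_dict.items():
        if dim == dim_0:
            output_dict[dim_1] = sharded_dims
        elif dim == dim_1:
            output_dict[dim_0] = sharded_dims
        else:
            output_dict[dim] = sharded_dims
    return output_dict

# an input sharded [S0, R, S1, R] transposed on dims (1, 2) comes out [S0, S1, R, R]
assert transpose_dim_partition({0: [0], 2: [1]}, 1, 2) == {0: [0], 1: [1]}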
colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_handler.py

@@ -6,11 +6,13 @@ from ...sharding_strategy import OperationData, OperationDataType
 from ..node_handler import NodeHandler
 from ..registry import operator_registry
 from ..strategy import StrategyGenerator
-from .view_generator import ViewGenerator
+from .reshape_generator import ViewGenerator

 __all__ = ['ViewHandler']


+@operator_registry.register(torch.Tensor.reshape)
+@operator_registry.register(torch.reshape)
 @operator_registry.register(torch.Tensor.view)
 class ViewHandler(NodeHandler):
     """
colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py

@@ -10,13 +10,9 @@ from .strategy import ReshapeGenerator, StrategyGenerator
 __all__ = ['ReshapeHandler']


-@operator_registry.register(torch.reshape)
 @operator_registry.register(torch.Tensor.split)
 @operator_registry.register(torch.split)
 @operator_registry.register(torch.flatten)
-@operator_registry.register(torch.Tensor.transpose)
-@operator_registry.register(torch.Tensor.permute)
-@operator_registry.register(torch.Tensor.view)
 @operator_registry.register(torch.nn.AdaptiveAvgPool2d)
 class ReshapeHandler(NodeHandler):
     """
colossalai/tensor/comm_spec.py

@@ -23,6 +23,8 @@ def _all_gather(tensor, comm_spec):
         torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
         for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis])
     ]
+    # without this contiguous operation, the all gather may get some unexpected results.
+    tensor = tensor.contiguous()
     dist.all_gather(tensor_list, tensor, group=process_group)
     output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
     return output
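The two added lines guard against non-contiguous send buffers, which can arise after slicing or transposing a tensor; collective ops assume dense memory. A minimal sketch of the same pattern in isolation, assuming an already-initialized default process group (the function name and shapes are illustrative, not the ColossalAI API):

import torch
import torch.distributed as dist

def gather_along_dim(tensor, world_size, gather_dim):
    # make the send buffer dense before communicating
    tensor = tensor.contiguous()
    tensor_list = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(tensor_list, tensor)
    return torch.cat(tuple(tensor_list), dim=gather_dim).contiguous()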
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py (new file, mode 100644)

from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.experimental import PermuteHandler, TransposeHandler
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy


class ConvReshapeModel(nn.Module):

    def __init__(self, reshape_dims, call_function):
        super().__init__()
        self.reshape_dims = reshape_dims
        self.call_function = call_function

    def forward(self, input, other):
        conv_node = nn.functional.conv2d(input, other, bias=None)
        # permute_node = torch.permute(conv_node, self.permute_dims)
        if self.call_function == torch.permute:
            permute_node = self.call_function(conv_node, self.reshape_dims)
        else:
            permute_node = self.call_function(conv_node, *self.reshape_dims)
        return permute_node


class LinearReshapeModel(nn.Module):

    def __init__(self, reshape_dims, call_function):
        super().__init__()
        self.reshape_dims = reshape_dims
        self.call_function = call_function

    def forward(self, input, other):
        linear_node = nn.functional.linear(input, other, bias=None)
        # permute_node = torch.permute(linear_node, self.tgt_shape)
        if self.call_function == torch.permute:
            permute_node = self.call_function(linear_node, self.reshape_dims)
        else:
            permute_node = self.call_function(linear_node, *self.reshape_dims)
        return permute_node


def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    if call_function == torch.permute:
        reshape_dims = reshape_dims[0]
    elif call_function == torch.transpose:
        reshape_dims = reshape_dims[1]
    model = model_cls(reshape_dims, call_function).cuda()

    if model_cls.__name__ == 'ConvReshapeModel':
        input = torch.rand(8, 8, 66, 66).to('cuda')
        other = torch.rand(16, 8, 3, 3).to('cuda')
        # index of conv node in computation graph
        node_index = 2
        # total number of conv strategies
        strategy_number = 16
    if model_cls.__name__ == 'LinearReshapeModel':
        input = torch.rand(8, 16, 64, 32).to('cuda')
        other = torch.rand(64, 32).to('cuda')
        # index of linear node in computation graph
        node_index = 2
        # total number of linear strategies
        strategy_number = 23

    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

    numerical_test_for_node_strategy(model=model,
                                     device_mesh=device_mesh,
                                     node_index=node_index,
                                     strategy_number=strategy_number,
                                     input_args=[input, other],
                                     meta_arg_names=['input', 'other'],
                                     node_type='following')

    tracer = ColoTracer()
    if model_cls.__name__ == 'ConvReshapeModel':
        # graph():
        #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
        #     %other : torch.Tensor [#users=1] = placeholder[target=other]
        #     %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {bias: None})
        #     %permute : [#users=1] = call_function[target=torch.permute](args = (%conv2d, (0, 2, 1, 3)), kwargs = {})
        #     return permute
        graph = tracer.trace(model,
                             meta_args={
                                 "input": torch.rand(8, 8, 66, 66).to('meta'),
                                 "other": torch.rand(16, 8, 3, 3).to('meta'),
                             })
    if model_cls.__name__ == 'LinearReshapeModel':
        # graph():
        #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
        #     %other : torch.Tensor [#users=1] = placeholder[target=other]
        #     %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None})
        #     %permute : [#users=1] = call_method[target=view](args = (%linear, 32, 4, 32, 32, 4), kwargs = {})
        #     return permute
        graph = tracer.trace(model,
                             meta_args={
                                 "input": torch.rand(8, 16, 64, 32).to('meta'),
                                 "other": torch.rand(64, 32).to('meta'),
                             })

    gm = ColoGraphModule(model, graph)

    previous_mod_node = list(graph.nodes)[2]
    reshape_node = list(graph.nodes)[3]
    view_strategies_vector = StrategiesVector(reshape_node)
    previous_strategies_vector = StrategiesVector(previous_mod_node)

    # build handler
    if model_cls.__name__ == 'ConvReshapeModel':
        conv_handler = ConvFunctionHandler(node=previous_mod_node,
                                           device_mesh=device_mesh,
                                           strategies_vector=previous_strategies_vector)
        conv_handler.register_strategy(compute_resharding_cost=False)
        setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector)

    if model_cls.__name__ == 'LinearReshapeModel':
        assert len(previous_strategies_vector) == 0
        linear_handler = LinearFunctionHandler(node=previous_mod_node,
                                               device_mesh=device_mesh,
                                               strategies_vector=previous_strategies_vector)
        linear_handler.register_strategy(compute_resharding_cost=False)
        setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector)

    if call_function == torch.permute:
        reshape_handler = PermuteHandler(node=reshape_node,
                                         device_mesh=device_mesh,
                                         strategies_vector=view_strategies_vector)
    else:
        reshape_handler = TransposeHandler(node=reshape_node,
                                           device_mesh=device_mesh,
                                           strategies_vector=view_strategies_vector)

    reshape_handler.register_strategy(compute_resharding_cost=False)

    # check operation data mapping
    mapping = reshape_handler.get_operation_data_mapping()

    for name, op_data in mapping.items():
        op_data: OperationData
        # make sure they have valid values
        assert op_data.data is not None

    if model_cls.__name__ == 'ConvReshapeModel':
        assert mapping['input'].name == "conv2d"
    else:
        assert mapping['input'].name == "linear"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64])

    if call_function == torch.permute:
        assert mapping['output'].name == "permute"
        assert mapping['output'].data.is_meta
        assert mapping['output'].data.shape == torch.permute(torch.rand(8, 16, 64, 64), reshape_dims).shape
        assert mapping['output'].type == OperationDataType.OUTPUT
    else:
        assert mapping['output'].name == "transpose"
        assert mapping['output'].data.is_meta
        assert mapping['output'].data.shape == torch.transpose(torch.rand(8, 16, 64, 64), *reshape_dims).shape
        assert mapping['output'].type == OperationDataType.OUTPUT

    # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node.
    assert len(view_strategies_vector) == len(previous_strategies_vector)
    strategy_name_list = [strategy.name for strategy in view_strategies_vector]
    if rank == 0:
        for name in strategy_name_list:
            print(name)

    if model_cls.__name__ == 'ConvReshapeModel':
        if reshape_dims in ((0, 2, 1, 3), (1, 2)):
            assert '[S0, S1, R, R] -> [S0, R, S1, R]_0' in strategy_name_list
            assert '[S1, S0, R, R] -> [S1, R, S0, R]_1' in strategy_name_list
            assert '[S0, R, R, R] -> [S0, R, R, R]_2' in strategy_name_list
            assert '[S1, R, R, R] -> [S1, R, R, R]_3' in strategy_name_list
            assert '[S0, R, R, R] -> [S0, R, R, R]_4' in strategy_name_list
            assert '[S1, R, R, R] -> [S1, R, R, R]_5' in strategy_name_list
            assert '[R, S1, R, R] -> [R, R, S1, R]_6' in strategy_name_list
            assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list
            assert '[R, S0, R, R] -> [R, R, S0, R]_10' in strategy_name_list
            assert '[R, S1, R, R] -> [R, R, S1, R]_11' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_12' in strategy_name_list
            assert '[S01, R, R, R] -> [S01, R, R, R]_13' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
            assert '[R, S01, R, R] -> [R, R, S01, R]_15' in strategy_name_list
        if reshape_dims == (2, 0, 1, 3):
            assert '[S0, S1, R, R] -> [R, S0, S1, R]_0' in strategy_name_list
            assert '[S1, S0, R, R] -> [R, S1, S0, R]_1' in strategy_name_list
            assert '[S0, R, R, R] -> [R, S0, R, R]_2' in strategy_name_list
            assert '[S1, R, R, R] -> [R, S1, R, R]_3' in strategy_name_list
            assert '[S0, R, R, R] -> [R, S0, R, R]_4' in strategy_name_list
            assert '[S1, R, R, R] -> [R, S1, R, R]_5' in strategy_name_list
            assert '[R, S1, R, R] -> [R, R, S1, R]_6' in strategy_name_list
            assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list
            assert '[R, S0, R, R] -> [R, R, S0, R]_10' in strategy_name_list
            assert '[R, S1, R, R] -> [R, R, S1, R]_11' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_12' in strategy_name_list
            assert '[S01, R, R, R] -> [R, S01, R, R]_13' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
            assert '[R, S01, R, R] -> [R, R, S01, R]_15' in strategy_name_list
        if reshape_dims == (1, 3):
            assert '[S0, S1, R, R] -> [S0, R, R, S1]_0' in strategy_name_list
            assert '[S1, S0, R, R] -> [S1, R, R, S0]_1' in strategy_name_list
            assert '[S0, R, R, R] -> [S0, R, R, R]_2' in strategy_name_list
            assert '[S1, R, R, R] -> [S1, R, R, R]_3' in strategy_name_list
            assert '[S0, R, R, R] -> [S0, R, R, R]_4' in strategy_name_list
            assert '[S1, R, R, R] -> [S1, R, R, R]_5' in strategy_name_list
            assert '[R, S1, R, R] -> [R, R, R, S1]_6' in strategy_name_list
            assert '[R, S0, R, R] -> [R, R, R, S0]_7' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list
            assert '[R, S0, R, R] -> [R, R, R, S0]_10' in strategy_name_list
            assert '[R, S1, R, R] -> [R, R, R, S1]_11' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_12' in strategy_name_list
            assert '[S01, R, R, R] -> [S01, R, R, R]_13' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
            assert '[R, S01, R, R] -> [R, R, R, S01]_15' in strategy_name_list

    if model_cls.__name__ == 'LinearReshapeModel':
        if reshape_dims in ((0, 2, 1, 3), (1, 2)):
            assert '[S0, R, R, S1] -> [S0, R, R, S1]_0' in strategy_name_list
            assert '[R, S0, R, S1] -> [R, R, S0, S1]_1' in strategy_name_list
            assert '[R, R, S0, S1] -> [R, S0, R, S1]_2' in strategy_name_list
            assert '[S1, R, R, S0] -> [S1, R, R, S0]_3' in strategy_name_list
            assert '[R, S1, R, S0] -> [R, R, S1, S0]_4' in strategy_name_list
            assert '[R, R, S1, S0] -> [R, S1, R, S0]_5' in strategy_name_list
            assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list
            assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list
            assert '[R, R, S0, R] -> [R, S0, R, R]_8' in strategy_name_list
            assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list
            assert '[R, S1, R, R] -> [R, R, S1, R]_10' in strategy_name_list
            assert '[R, R, S1, R] -> [R, S1, R, R]_11' in strategy_name_list
            assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list
            assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list
            assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list
            assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list
            assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list
            assert '[R, S01, R, R] -> [R, R, S01, R]_19' in strategy_name_list
            assert '[R, R, S01, R] -> [R, S01, R, R]_20' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list
            assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list
        if reshape_dims == (2, 0, 1, 3):
            assert '[S0, R, R, S1] -> [R, S0, R, S1]_0' in strategy_name_list
            assert '[R, S0, R, S1] -> [R, R, S0, S1]_1' in strategy_name_list
            assert '[R, R, S0, S1] -> [S0, R, R, S1]_2' in strategy_name_list
            assert '[S1, R, R, S0] -> [R, S1, R, S0]_3' in strategy_name_list
            assert '[R, S1, R, S0] -> [R, R, S1, S0]_4' in strategy_name_list
            assert '[R, R, S1, S0] -> [S1, R, R, S0]_5' in strategy_name_list
            assert '[S0, R, R, R] -> [R, S0, R, R]_6' in strategy_name_list
            assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list
            assert '[R, R, S0, R] -> [S0, R, R, R]_8' in strategy_name_list
            assert '[S1, R, R, R] -> [R, S1, R, R]_9' in strategy_name_list
            assert '[R, S1, R, R] -> [R, R, S1, R]_10' in strategy_name_list
            assert '[R, R, S1, R] -> [S1, R, R, R]_11' in strategy_name_list
            assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list
            assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list
            assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list
            assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list
            assert '[S01, R, R, R] -> [R, S01, R, R]_18' in strategy_name_list
            assert '[R, S01, R, R] -> [R, R, S01, R]_19' in strategy_name_list
            assert '[R, R, S01, R] -> [S01, R, R, R]_20' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list
            assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list
        if reshape_dims == (1, 3):
            assert '[S0, R, R, S1] -> [S0, S1, R, R]_0' in strategy_name_list
            assert '[R, S0, R, S1] -> [R, S1, R, S0]_1' in strategy_name_list
            assert '[R, R, S0, S1] -> [R, S1, S0, R]_2' in strategy_name_list
            assert '[S1, R, R, S0] -> [S1, S0, R, R]_3' in strategy_name_list
            assert '[R, S1, R, S0] -> [R, S0, R, S1]_4' in strategy_name_list
            assert '[R, R, S1, S0] -> [R, S0, S1, R]_5' in strategy_name_list
            assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list
            assert '[R, S0, R, R] -> [R, R, R, S0]_7' in strategy_name_list
            assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list
            assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list
            assert '[R, S1, R, R] -> [R, R, R, S1]_10' in strategy_name_list
            assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list
            assert '[R, R, R, S1] -> [R, S1, R, R]_12' in strategy_name_list
            assert '[R, R, R, S0] -> [R, S0, R, R]_13' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list
            assert '[R, R, R, S0] -> [R, S0, R, R]_16' in strategy_name_list
            assert '[R, R, R, S1] -> [R, S1, R, R]_17' in strategy_name_list
            assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list
            assert '[R, S01, R, R] -> [R, R, R, S01]_19' in strategy_name_list
            assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list
            assert '[R, R, R, S01] -> [R, S01, R, R]_22' in strategy_name_list


@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
@parameterize('call_function', [torch.permute, torch.transpose])
@parameterize('reshape_dims', [((0, 2, 1, 3), (1, 2)), ((2, 0, 1, 3), (1, 3))])
@parameterize('model_cls', [ConvReshapeModel, LinearReshapeModel])
def test_view_handler(call_function, reshape_dims, model_cls):
    world_size = 4
    run_func = partial(check_view_handler,
                       call_function=call_function,
                       reshape_dims=reshape_dims,
                       model_cls=model_cls,
                       world_size=world_size,
                       port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_view_handler()
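A quick sanity check (not part of the test) of why the first parameterized pair shares one expected strategy list: for a 4-D tensor, permute(0, 2, 1, 3) and transpose(1, 2) are the same layout transform, which is why the conv branch accepts either value with a single assert block:

import torch

x = torch.rand(2, 3, 4, 5)
assert torch.equal(torch.permute(x, (0, 2, 1, 3)), torch.transpose(x, 1, 2))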
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py

@@ -84,7 +84,7 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port):
     #     return view
     graph = tracer.trace(model,
                          meta_args={
-                             "input": torch.rand(8, 16, 66, 66).to('meta'),
+                             "input": torch.rand(8, 8, 66, 66).to('meta'),
                              "other": torch.rand(16, 8, 3, 3).to('meta'),
                          })