ColossalAI commit 81330b03 (unverified)
Authored Nov 27, 2022 by YuliangLiu0306; committed by GitHub on Nov 27, 2022

[autoparallel] add experimental permute handler (#2029)

parent 95c4532f
Showing 11 changed files with 657 additions and 37 deletions (+657 / -37)

  colossalai/auto_parallel/passes/runtime_preparation_pass.py (+69 -25)
  colossalai/auto_parallel/tensor_shard/node_handler/__init__.py (+2 -1)
  colossalai/auto_parallel/tensor_shard/node_handler/experimental/__init__.py (+6 -2)
  colossalai/auto_parallel/tensor_shard/node_handler/experimental/permute_handler.py (+76 -0)
  colossalai/auto_parallel/tensor_shard/node_handler/experimental/reshape_generator.py (+94 -3)
  colossalai/auto_parallel/tensor_shard/node_handler/experimental/transpose_handler.py (+65 -0)
  colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_handler.py (+3 -1)
  colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py (+0 -4)
  colossalai/tensor/comm_spec.py (+2 -0)
  tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py (+339 -0)
  tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py (+1 -1)
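Note for readers (illustration, not part of the commit): the new PermuteHandler and TransposeHandler have to cover both argument styles that FX tracing records for these ops, the call_method form where the dims arrive as separate integer arguments and the call_function form where they arrive as a single tuple. A minimal sketch of the calls being handled:

import torch

x = torch.rand(8, 16, 64, 64)

# call_method style: dims are separate int args (registered via torch.Tensor.permute)
y1 = x.permute(0, 2, 1, 3)
# call_function style: dims are passed as one tuple (registered via torch.permute)
y2 = torch.permute(x, (0, 2, 1, 3))
# transpose always swaps exactly two dims, in either style
y3 = x.transpose(1, 2)
y4 = torch.transpose(x, 1, 2)

assert y1.shape == y2.shape == y3.shape == y4.shape == torch.Size([8, 64, 16, 64])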
colossalai/auto_parallel/passes/runtime_preparation_pass.py
...
...
@@ -37,30 +37,6 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
        origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(str(node))

    # experimental pass for torch.Tensor.view
    # Arguments of view op will be divided in the sharded dimensions.
    for node in nodes:
        if node.op == 'call_method' and getattr(node.args[0]._meta_data.__class__, node.target) in (torch.Tensor.view,):
            output_dim_partition_dict = node.sharding_spec.dim_partition_dict
            device_mesh = node.sharding_spec.device_mesh
            new_args = []
            for arg in node.args:
                if isinstance(arg, Node):
                    if isinstance(arg._meta_data, int):
                        new_args.append(arg._meta_data)
                    else:
                        new_args.append(arg)
                else:
                    assert isinstance(arg, int), 'The argument in view node should be either type of Node or int.'
                    new_args.append(arg)

            for dim, shard_dims in output_dim_partition_dict.items():
                total_shard_size = 1
                for shard_dim in shard_dims:
                    total_shard_size *= device_mesh.shape[shard_dim]
                new_args[dim + 1] //= total_shard_size
            node.args = tuple(new_args)

    # the dict to get input sharding specs of user node
    sharding_spec_convert_dict = {}
    # the dict to record comm actions of nodes
...
...
@@ -113,7 +89,74 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
    return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict


def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh):
def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
    """
    This pass will process node args to adapt the distributed tensor layout.
    """
    mod_graph = gm.graph
    nodes = tuple(mod_graph.nodes)

    for node in nodes:
        # skip the placeholder node added in _solution_annotation pass
        if not hasattr(node, 'sharding_spec'):
            continue
        output_dim_partition_dict = node.sharding_spec.dim_partition_dict
        device_mesh = node.sharding_spec.device_mesh
        new_args = []

        if node.op == 'call_method':
            method = getattr(node.args[0]._meta_data.__class__, node.target)
            # process the node with (input, *shape) style args
            if method in (torch.Tensor.view, torch.Tensor.reshape):
                for arg in node.args:
                    if isinstance(arg, Node):
                        if isinstance(arg._meta_data, int):
                            new_args.append(arg._meta_data)
                        else:
                            new_args.append(arg)
                    else:
                        assert isinstance(arg, int), 'The argument in view node should be either type of Node or int.'
                        new_args.append(arg)

                for dim, shard_dims in output_dim_partition_dict.items():
                    # we will skip the dim with -1 value
                    if new_args[dim + 1] == -1:
                        continue
                    total_shard_size = 1
                    for shard_dim in shard_dims:
                        total_shard_size *= device_mesh.shape[shard_dim]
                    new_args[dim + 1] //= total_shard_size
                node.args = tuple(new_args)

        elif node.op == 'call_function':
            target = node.target
            # process the node with (input, torch.Size) style args
            if target in (torch.reshape,):
                for arg in node.args:
                    if isinstance(arg, Node):
                        if isinstance(arg._meta_data, (tuple, list)):
                            new_args.append(list(arg._meta_data))
                        else:
                            new_args.append(arg)
                    else:
                        assert isinstance(arg, (tuple, list)), 'The argument in reshape node should be either type of Node or tuple.'
                        new_args.append(list(arg))

                for dim, shard_dims in output_dim_partition_dict.items():
                    # we will skip the dim with -1 value
                    if new_args[1][dim] == -1:
                        continue
                    total_shard_size = 1
                    for shard_dim in shard_dims:
                        total_shard_size *= device_mesh.shape[shard_dim]
                    new_args[1][dim] //= total_shard_size
                node.args = tuple(new_args)

    return gm


def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
    """
    Apply the sharding action to the module parameters and buffers following the
    instructions of solver solution.
...
...
@@ -216,6 +259,7 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
def runtime_preparation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh: DeviceMesh):
    gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation(gm, solution)
    gm = _node_args_converting(gm, device_mesh)
    # TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed.
    # gm = implicit_comm_action_apply(gm)
    gm = _module_params_sharding(gm, device_mesh)
...
...
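To make the new `_node_args_converting` pass concrete: it rewrites the global shape arguments of view/reshape nodes into per-rank local sizes by dividing each sharded dimension by the number of devices it is sharded over. A standalone sketch of that arithmetic, with a made-up 2x2 mesh and partition dict (illustration only, not code from the commit):

# hypothetical 2x2 device mesh; output dim 0 sharded over mesh axis 0,
# output dim 1 sharded over mesh axis 1 (an [S0, S1, ...] layout)
device_mesh_shape = (2, 2)
output_dim_partition_dict = {0: [0], 1: [1]}

# args of a traced `x.view(8, 16, 64, 32)` call: args[0] is the input node,
# args[1:] are the global target sizes
new_args = ['x', 8, 16, 64, 32]

for dim, shard_dims in output_dim_partition_dict.items():
    if new_args[dim + 1] == -1:    # -1 is inferred by view, so it is left untouched
        continue
    total_shard_size = 1
    for shard_dim in shard_dims:
        total_shard_size *= device_mesh_shape[shard_dim]
    new_args[dim + 1] //= total_shard_size

# each rank now calls view() with its local shard sizes
assert new_args == ['x', 4, 8, 64, 32]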
colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
...
...
@@ -3,6 +3,7 @@ from .batch_norm_handler import BatchNormModuleHandler
from .binary_elementwise_handler import BinaryElementwiseHandler
from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
from .conv_handler import ConvFunctionHandler, ConvModuleHandler
from .experimental import PermuteHandler, ViewHandler
from .getatrr_handler import GetattrHandler
from .getitem_handler import GetItemHandler
from .layer_norm_handler import LayerNormModuleHandler
...
...
@@ -21,5 +22,5 @@ __all__ = [
    'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler', 'UnaryElementwiseHandler',
    'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler', 'NormPoolingHandler', 'BinaryElementwiseHandler',
    'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler', 'GetItemHandler', 'GetattrHandler'
    'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler'
]
colossalai/auto_parallel/tensor_shard/node_handler/experimental/__init__.py
from .view_generator import ViewGenerator
from .permute_handler import PermuteHandler
from .reshape_generator import PermuteGenerator, TransposeGenerator, ViewGenerator
from .transpose_handler import TransposeHandler
from .view_handler import ViewHandler

__all__ = ['ViewGenerator', 'ViewHandler']
__all__ = ['ViewGenerator', 'ViewHandler', 'PermuteGenerator', 'PermuteHandler', 'TransposeGenerator', 'TransposeGenerator']
colossalai/auto_parallel/tensor_shard/node_handler/experimental/permute_handler.py
new file mode 100644
from typing import Dict, List

import torch

from ...sharding_strategy import OperationData, OperationDataType
from ..node_handler import NodeHandler
from ..registry import operator_registry
from ..strategy import StrategyGenerator
from .reshape_generator import PermuteGenerator

__all__ = ['PermuteHandler']


@operator_registry.register(torch.Tensor.permute)
@operator_registry.register(torch.permute)
class PermuteHandler(NodeHandler):
    """
    A PermuteHandler which deals with the sharding strategies for torch.permute or torch.transpose.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(PermuteGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # check if the input operand is a parameter
        if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
            data_type = OperationDataType.PARAM
        else:
            data_type = OperationDataType.ARG

        input_data = self.node.args[0]._meta_data
        physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)

        permute_dims = []
        if self.node.op == 'call_method':
            # torch.Tensor.permute (input, *dims)
            for arg in self.node.args:
                if isinstance(arg, torch.fx.Node):
                    if isinstance(arg._meta_data, int):
                        permute_dims.append(arg._meta_data)
                else:
                    assert isinstance(arg, int), 'The argument in permute node should be either type of Node or int.'
                    permute_dims.append(arg)
        else:
            # torch.permute (input, dims)
            for arg in self.node.args:
                if isinstance(arg, torch.fx.Node):
                    if isinstance(arg._meta_data, (tuple, list)):
                        permute_dims.extend(arg._meta_data)
                else:
                    assert isinstance(arg, (tuple, list)), 'The argument in permute node should be type of Node, Tuple[int] or List[int].'
                    permute_dims.extend(arg)

        num_dims = self.node._meta_data.dim()
        for i in range(num_dims):
            # recover negative value to positive
            if permute_dims[i] < 0:
                permute_dims[i] += num_dims

        physical_shape_operand = OperationData(name='permute_dims', type=OperationDataType.ARG, data=list(permute_dims))

        output_data = self.node._meta_data
        physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)

        mapping = {
            "input": physical_input_operand,
            "permute_dims": physical_shape_operand,
            "output": physical_output_operand
        }

        return mapping
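The `# recover negative value to positive` step above matters because the strategy generator later indexes the predecessor's dim_partition_dict with these values. A tiny standalone illustration (the dims below are invented for the example):

# normalizing negative permute dims for a 4-D tensor, as PermuteHandler does
num_dims = 4
permute_dims = [0, 2, -3, -1]    # e.g. traced from x.permute(0, 2, -3, -1)

for i in range(num_dims):
    if permute_dims[i] < 0:
        permute_dims[i] += num_dims

assert permute_dims == [0, 2, 1, 3]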
colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_generator.py → colossalai/auto_parallel/tensor_shard/node_handler/experimental/reshape_generator.py
...
...
@@ -17,12 +17,12 @@ from colossalai.auto_parallel.tensor_shard.utils import (
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpec

__all__ = ['ViewGenerator']
__all__ = ['ReshapeGenerator', 'ViewGenerator', 'PermuteGenerator', 'TransposeGenerator']


class ViewGenerator(FollowingStrategyGenerator):
class ReshapeGenerator(FollowingStrategyGenerator):
    """
    ViewGenerator which deals with the sharding strategies of view op.
    ReshapeGenerator is the base class for all the reshape operation.
    """

    def validate(self) -> bool:
...
...
@@ -61,6 +61,15 @@ class ViewGenerator(FollowingStrategyGenerator):
        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
        strategy.memory_cost = memory_cost

    def collate_strategies(self) -> List[ShardingStrategy]:
        return super().collate_strategies()


class ViewGenerator(ReshapeGenerator):
    """
    ViewGenerator deals with the sharding strategies of view op.
    """

    def collate_strategies(self) -> List[ShardingStrategy]:
        strategy_list = []
        for index, strategy in enumerate(self.predecessor_node.strategies_vector):
...
...
@@ -136,3 +145,85 @@ class ViewGenerator(FollowingStrategyGenerator):
            strategy_list.append(strategy)

        return strategy_list


class PermuteGenerator(ReshapeGenerator):
    """
    PermuteGenerator deals with the sharding strategies of permute op.
    """

    def collate_strategies(self) -> List[ShardingStrategy]:
        strategy_list = []
        for index, strategy in enumerate(self.predecessor_node.strategies_vector):
            dim_partition_dict_mapping = {}
            communication_action_mapping = {}
            input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]

            permute_dims = self.op_data['permute_dims'].data
            dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
            dim_partition_dict_for_output = {}
            for dim_index, permute_dim in enumerate(permute_dims):
                if permute_dim in dim_partition_dict_for_input:
                    dim_partition_dict_for_output[dim_index] = dim_partition_dict_for_input[permute_dim]

            dim_partition_dict_mapping = {
                "input": dim_partition_dict_for_input,
                "output": dim_partition_dict_for_output,
            }
            sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

            # add index into name to pass the duplicated check
            # we keep same strategies with different name for node merging, and it will not increase the searching space,
            # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
            name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'

            strategy = self.get_sharding_strategy(name=name,
                                                  sharding_spec_mapping=sharding_spec_mapping,
                                                  communication_action_mapping=communication_action_mapping)
            strategy_list.append(strategy)

        return strategy_list


class TransposeGenerator(ReshapeGenerator):
    """
    TransposeGenerator deals with the sharding strategies of permute op.
    """

    def collate_strategies(self) -> List[ShardingStrategy]:
        strategy_list = []
        for index, strategy in enumerate(self.predecessor_node.strategies_vector):
            dim_partition_dict_mapping = {}
            communication_action_mapping = {}
            input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]

            dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
            dim_partition_dict_for_output = {}
            transpose_dims = self.op_data['transpose_dims'].data
            dim_0 = transpose_dims[0]
            dim_1 = transpose_dims[1]
            for dim, sharded_dims in dim_partition_dict_for_input.items():
                if dim == dim_0:
                    dim_partition_dict_for_output[dim_1] = dim_partition_dict_for_input[dim_0]
                elif dim == dim_1:
                    dim_partition_dict_for_output[dim_0] = dim_partition_dict_for_input[dim_1]
                else:
                    dim_partition_dict_for_output[dim] = sharded_dims

            dim_partition_dict_mapping = {
                "input": dim_partition_dict_for_input,
                "output": dim_partition_dict_for_output,
            }
            sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

            # add index into name to pass the duplicated check
            # we keep same strategies with different name for node merging, and it will not increase the searching space,
            # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
            name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'

            strategy = self.get_sharding_strategy(name=name,
                                                  sharding_spec_mapping=sharding_spec_mapping,
                                                  communication_action_mapping=communication_action_mapping)
            strategy_list.append(strategy)

        return strategy_list
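Both generators above only relabel which output dimension carries each sharded mesh axis of the predecessor's output spec: for permute, output dim i inherits the sharding of input dim permute_dims[i]; for transpose, only the two swapped dims exchange their sharding. A hedged sketch of that bookkeeping on plain dicts (the partition dicts below are invented for illustration):

# how PermuteGenerator / TransposeGenerator move sharding labels around

def permute_partition(dim_partition_dict_for_input, permute_dims):
    # output dim i takes whatever sharding input dim permute_dims[i] had
    out = {}
    for dim_index, permute_dim in enumerate(permute_dims):
        if permute_dim in dim_partition_dict_for_input:
            out[dim_index] = dim_partition_dict_for_input[permute_dim]
    return out

def transpose_partition(dim_partition_dict_for_input, dim_0, dim_1):
    # only the two transposed dims swap their sharding, the rest is kept
    out = {}
    for dim, sharded_dims in dim_partition_dict_for_input.items():
        if dim == dim_0:
            out[dim_1] = sharded_dims
        elif dim == dim_1:
            out[dim_0] = sharded_dims
        else:
            out[dim] = sharded_dims
    return out

# input sharded as [S0, S1, R, R] on a 2x2 mesh
input_partition = {0: [0], 1: [1]}
assert permute_partition(input_partition, (0, 2, 1, 3)) == {0: [0], 2: [1]}    # [S0, R, S1, R]
assert transpose_partition(input_partition, 1, 2) == {0: [0], 2: [1]}          # [S0, R, S1, R]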
colossalai/auto_parallel/tensor_shard/node_handler/experimental/transpose_handler.py
new file mode 100644
from typing import Dict, List

import torch

from ...sharding_strategy import OperationData, OperationDataType
from ..node_handler import NodeHandler
from ..registry import operator_registry
from ..strategy import StrategyGenerator
from .reshape_generator import TransposeGenerator

__all__ = ['TransposeHandler']


@operator_registry.register(torch.Tensor.transpose)
@operator_registry.register(torch.transpose)
class TransposeHandler(NodeHandler):
    """
    A TransposeHandler which deals with the sharding strategies for torch.permute or torch.transpose.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(TransposeGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # check if the input operand is a parameter
        if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
            data_type = OperationDataType.PARAM
        else:
            data_type = OperationDataType.ARG

        input_data = self.node.args[0]._meta_data
        physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)

        transpose_dims = []
        # torch.transpose (input, dim0, dim1)
        for arg in self.node.args:
            if isinstance(arg, torch.fx.Node):
                if isinstance(arg._meta_data, int):
                    transpose_dims.append(arg._meta_data)
            else:
                transpose_dims.append(arg)

        num_dims = self.node._meta_data.dim()
        for i in range(2):
            # recover negative value to positive
            if transpose_dims[i] < 0:
                transpose_dims[i] += num_dims

        physical_shape_operand = OperationData(name='transpose_dims', type=OperationDataType.ARG, data=list(transpose_dims))

        output_data = self.node._meta_data
        physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)

        mapping = {
            "input": physical_input_operand,
            "transpose_dims": physical_shape_operand,
            "output": physical_output_operand
        }

        return mapping
colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_handler.py
...
...
@@ -6,11 +6,13 @@ from ...sharding_strategy import OperationData, OperationDataType
from ..node_handler import NodeHandler
from ..registry import operator_registry
from ..strategy import StrategyGenerator
from .view_generator import ViewGenerator
from .reshape_generator import ViewGenerator

__all__ = ['ViewHandler']


@operator_registry.register(torch.Tensor.reshape)
@operator_registry.register(torch.reshape)
@operator_registry.register(torch.Tensor.view)
class ViewHandler(NodeHandler):
    """
...
...
colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py
...
...
@@ -10,13 +10,9 @@ from .strategy import ReshapeGenerator, StrategyGenerator
__all__ = ['ReshapeHandler']


@operator_registry.register(torch.reshape)
@operator_registry.register(torch.Tensor.split)
@operator_registry.register(torch.split)
@operator_registry.register(torch.flatten)
@operator_registry.register(torch.Tensor.transpose)
@operator_registry.register(torch.Tensor.permute)
@operator_registry.register(torch.Tensor.view)
@operator_registry.register(torch.nn.AdaptiveAvgPool2d)
class ReshapeHandler(NodeHandler):
    """
...
...
colossalai/tensor/comm_spec.py
...
...
@@ -23,6 +23,8 @@ def _all_gather(tensor, comm_spec):
        torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
        for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis])
    ]
    # without this contiguous operation, the all gather may get some unexpected results.
    tensor = tensor.contiguous()
    dist.all_gather(tensor_list, tensor, group=process_group)
    output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
    return output
...
...
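The two lines added to `_all_gather` force the tensor to be contiguous before the collective, because gathering a strided view can give unexpected results. A minimal standalone sketch of the same pattern (the helper name is hypothetical and torch.distributed is assumed to be initialized already):

import torch
import torch.distributed as dist

def gather_along_dim(tensor: torch.Tensor, gather_dim: int, group=None) -> torch.Tensor:
    # one receive buffer per rank in the group
    world_size = dist.get_world_size(group)
    tensor_list = [
        torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
        for _ in range(world_size)
    ]
    # without this contiguous(), all_gather may read a strided view incorrectly
    tensor = tensor.contiguous()
    dist.all_gather(tensor_list, tensor, group=group)
    return torch.cat(tuple(tensor_list), gather_dim).contiguous()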
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py
new file mode 100644
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.experimental import PermuteHandler, TransposeHandler
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy


class ConvReshapeModel(nn.Module):

    def __init__(self, reshape_dims, call_function):
        super().__init__()
        self.reshape_dims = reshape_dims
        self.call_function = call_function

    def forward(self, input, other):
        conv_node = nn.functional.conv2d(input, other, bias=None)
        # permute_node = torch.permute(conv_node, self.permute_dims)
        if self.call_function == torch.permute:
            permute_node = self.call_function(conv_node, self.reshape_dims)
        else:
            permute_node = self.call_function(conv_node, *self.reshape_dims)
        return permute_node


class LinearReshapeModel(nn.Module):

    def __init__(self, reshape_dims, call_function):
        super().__init__()
        self.reshape_dims = reshape_dims
        self.call_function = call_function

    def forward(self, input, other):
        linear_node = nn.functional.linear(input, other, bias=None)
        # permute_node = torch.permute(linear_node, self.tgt_shape)
        if self.call_function == torch.permute:
            permute_node = self.call_function(linear_node, self.reshape_dims)
        else:
            permute_node = self.call_function(linear_node, *self.reshape_dims)
        return permute_node
def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    if call_function == torch.permute:
        reshape_dims = reshape_dims[0]
    elif call_function == torch.transpose:
        reshape_dims = reshape_dims[1]
    model = model_cls(reshape_dims, call_function).cuda()

    if model_cls.__name__ == 'ConvReshapeModel':
        input = torch.rand(8, 8, 66, 66).to('cuda')
        other = torch.rand(16, 8, 3, 3).to('cuda')
        # index of conv node in computation graph
        node_index = 2
        # total number of conv strategies
        strategy_number = 16

    if model_cls.__name__ == 'LinearReshapeModel':
        input = torch.rand(8, 16, 64, 32).to('cuda')
        other = torch.rand(64, 32).to('cuda')
        # index of linear node in computation graph
        node_index = 2
        # total number of linear strategies
        strategy_number = 23

    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    numerical_test_for_node_strategy(model=model,
                                     device_mesh=device_mesh,
                                     node_index=node_index,
                                     strategy_number=strategy_number,
                                     input_args=[input, other],
                                     meta_arg_names=['input', 'other'],
                                     node_type='following')
    tracer = ColoTracer()
    if model_cls.__name__ == 'ConvReshapeModel':
        # graph():
        #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
        #     %other : torch.Tensor [#users=1] = placeholder[target=other]
        #     %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {bias: None})
        #     %permute : [#users=1] = call_function[target=torch.permute](args = (%conv2d, (0, 2, 1, 3)), kwargs = {})
        #     return permute
        graph = tracer.trace(model,
                             meta_args={
                                 "input": torch.rand(8, 8, 66, 66).to('meta'),
                                 "other": torch.rand(16, 8, 3, 3).to('meta'),
                             })

    if model_cls.__name__ == 'LinearReshapeModel':
        # graph():
        #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
        #     %other : torch.Tensor [#users=1] = placeholder[target=other]
        #     %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None})
        #     %permute : [#users=1] = call_method[target=view](args = (%linear, 32, 4, 32, 32, 4), kwargs = {})
        #     return permute
        graph = tracer.trace(model,
                             meta_args={
                                 "input": torch.rand(8, 16, 64, 32).to('meta'),
                                 "other": torch.rand(64, 32).to('meta'),
                             })

    gm = ColoGraphModule(model, graph)
    previous_mod_node = list(graph.nodes)[2]
    reshape_node = list(graph.nodes)[3]
    view_strategies_vector = StrategiesVector(reshape_node)
    previous_strategies_vector = StrategiesVector(previous_mod_node)

    # build handler
    if model_cls.__name__ == 'ConvReshapeModel':
        conv_handler = ConvFunctionHandler(node=previous_mod_node,
                                           device_mesh=device_mesh,
                                           strategies_vector=previous_strategies_vector)
        conv_handler.register_strategy(compute_resharding_cost=False)
        setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector)

    if model_cls.__name__ == 'LinearReshapeModel':
        assert len(previous_strategies_vector) == 0
        linear_handler = LinearFunctionHandler(node=previous_mod_node,
                                               device_mesh=device_mesh,
                                               strategies_vector=previous_strategies_vector)
        linear_handler.register_strategy(compute_resharding_cost=False)
        setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector)

    if call_function == torch.permute:
        reshape_handler = PermuteHandler(node=reshape_node,
                                         device_mesh=device_mesh,
                                         strategies_vector=view_strategies_vector)
    else:
        reshape_handler = TransposeHandler(node=reshape_node,
                                           device_mesh=device_mesh,
                                           strategies_vector=view_strategies_vector)

    reshape_handler.register_strategy(compute_resharding_cost=False)
    # check operation data mapping
    mapping = reshape_handler.get_operation_data_mapping()

    for name, op_data in mapping.items():
        op_data: OperationData
        # make sure they have valid values
        assert op_data.data is not None

    if model_cls.__name__ == 'ConvReshapeModel':
        assert mapping['input'].name == "conv2d"
    else:
        assert mapping['input'].name == "linear"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64])

    if call_function == torch.permute:
        assert mapping['output'].name == "permute"
        assert mapping['output'].data.is_meta
        assert mapping['output'].data.shape == torch.permute(torch.rand(8, 16, 64, 64), reshape_dims).shape
        assert mapping['output'].type == OperationDataType.OUTPUT
    else:
        assert mapping['output'].name == "transpose"
        assert mapping['output'].data.is_meta
        assert mapping['output'].data.shape == torch.transpose(torch.rand(8, 16, 64, 64), *reshape_dims).shape
        assert mapping['output'].type == OperationDataType.OUTPUT
    # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node.
    assert len(view_strategies_vector) == len(previous_strategies_vector)
    strategy_name_list = [strategy.name for strategy in view_strategies_vector]

    if rank == 0:
        for name in strategy_name_list:
            print(name)
    if model_cls.__name__ == 'ConvReshapeModel':
        if reshape_dims in ((0, 2, 1, 3), (1, 2)):
            assert '[S0, S1, R, R] -> [S0, R, S1, R]_0' in strategy_name_list
            assert '[S1, S0, R, R] -> [S1, R, S0, R]_1' in strategy_name_list
            assert '[S0, R, R, R] -> [S0, R, R, R]_2' in strategy_name_list
            assert '[S1, R, R, R] -> [S1, R, R, R]_3' in strategy_name_list
            assert '[S0, R, R, R] -> [S0, R, R, R]_4' in strategy_name_list
            assert '[S1, R, R, R] -> [S1, R, R, R]_5' in strategy_name_list
            assert '[R, S1, R, R] -> [R, R, S1, R]_6' in strategy_name_list
            assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list
            assert '[R, S0, R, R] -> [R, R, S0, R]_10' in strategy_name_list
            assert '[R, S1, R, R] -> [R, R, S1, R]_11' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_12' in strategy_name_list
            assert '[S01, R, R, R] -> [S01, R, R, R]_13' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
            assert '[R, S01, R, R] -> [R, R, S01, R]_15' in strategy_name_list

        if reshape_dims == (2, 0, 1, 3):
            assert '[S0, S1, R, R] -> [R, S0, S1, R]_0' in strategy_name_list
            assert '[S1, S0, R, R] -> [R, S1, S0, R]_1' in strategy_name_list
            assert '[S0, R, R, R] -> [R, S0, R, R]_2' in strategy_name_list
            assert '[S1, R, R, R] -> [R, S1, R, R]_3' in strategy_name_list
            assert '[S0, R, R, R] -> [R, S0, R, R]_4' in strategy_name_list
            assert '[S1, R, R, R] -> [R, S1, R, R]_5' in strategy_name_list
            assert '[R, S1, R, R] -> [R, R, S1, R]_6' in strategy_name_list
            assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list
            assert '[R, S0, R, R] -> [R, R, S0, R]_10' in strategy_name_list
            assert '[R, S1, R, R] -> [R, R, S1, R]_11' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_12' in strategy_name_list
            assert '[S01, R, R, R] -> [R, S01, R, R]_13' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
            assert '[R, S01, R, R] -> [R, R, S01, R]_15' in strategy_name_list

        if reshape_dims == (1, 3):
            assert '[S0, S1, R, R] -> [S0, R, R, S1]_0' in strategy_name_list
            assert '[S1, S0, R, R] -> [S1, R, R, S0]_1' in strategy_name_list
            assert '[S0, R, R, R] -> [S0, R, R, R]_2' in strategy_name_list
            assert '[S1, R, R, R] -> [S1, R, R, R]_3' in strategy_name_list
            assert '[S0, R, R, R] -> [S0, R, R, R]_4' in strategy_name_list
            assert '[S1, R, R, R] -> [S1, R, R, R]_5' in strategy_name_list
            assert '[R, S1, R, R] -> [R, R, R, S1]_6' in strategy_name_list
            assert '[R, S0, R, R] -> [R, R, R, S0]_7' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list
            assert '[R, S0, R, R] -> [R, R, R, S0]_10' in strategy_name_list
            assert '[R, S1, R, R] -> [R, R, R, S1]_11' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_12' in strategy_name_list
            assert '[S01, R, R, R] -> [S01, R, R, R]_13' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
            assert '[R, S01, R, R] -> [R, R, R, S01]_15' in strategy_name_list
    if model_cls.__name__ == 'LinearReshapeModel':
        if reshape_dims == ((0, 2, 1, 3), (1, 2)):
            assert '[S0, R, R, S1] -> [S0, R, R, S1]_0' in strategy_name_list
            assert '[R, S0, R, S1] -> [R, R, S0, S1]_1' in strategy_name_list
            assert '[R, R, S0, S1] -> [R, S0, R, S1]_2' in strategy_name_list
            assert '[S1, R, R, S0] -> [S1, R, R, S0]_3' in strategy_name_list
            assert '[R, S1, R, S0] -> [R, R, S1, S0]_4' in strategy_name_list
            assert '[R, R, S1, S0] -> [R, S1, R, S0]_5' in strategy_name_list
            assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list
            assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list
            assert '[R, R, S0, R] -> [R, S0, R, R]_8' in strategy_name_list
            assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list
            assert '[R, S1, R, R] -> [R, R, S1, R]_10' in strategy_name_list
            assert '[R, R, S1, R] -> [R, S1, R, R]_11' in strategy_name_list
            assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list
            assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list
            assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list
            assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list
            assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list
            assert '[R, S01, R, R] -> [R, R, S01, R]_19' in strategy_name_list
            assert '[R, R, S01, R] -> [R, S01, R, R]_20' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list
            assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list

        if reshape_dims == (2, 0, 1, 3):
            assert '[S0, R, R, S1] -> [R, S0, R, S1]_0' in strategy_name_list
            assert '[R, S0, R, S1] -> [R, R, S0, S1]_1' in strategy_name_list
            assert '[R, R, S0, S1] -> [S0, R, R, S1]_2' in strategy_name_list
            assert '[S1, R, R, S0] -> [R, S1, R, S0]_3' in strategy_name_list
            assert '[R, S1, R, S0] -> [R, R, S1, S0]_4' in strategy_name_list
            assert '[R, R, S1, S0] -> [S1, R, R, S0]_5' in strategy_name_list
            assert '[S0, R, R, R] -> [R, S0, R, R]_6' in strategy_name_list
            assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list
            assert '[R, R, S0, R] -> [S0, R, R, R]_8' in strategy_name_list
            assert '[S1, R, R, R] -> [R, S1, R, R]_9' in strategy_name_list
            assert '[R, S1, R, R] -> [R, R, S1, R]_10' in strategy_name_list
            assert '[R, R, S1, R] -> [S1, R, R, R]_11' in strategy_name_list
            assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list
            assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list
            assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list
            assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list
            assert '[S01, R, R, R] -> [R, S01, R, R]_18' in strategy_name_list
            assert '[R, S01, R, R] -> [R, R, S01, R]_19' in strategy_name_list
            assert '[R, R, S01, R] -> [S01, R, R, R]_20' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list
            assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list

        if reshape_dims == (1, 3):
            assert '[S0, R, R, S1] -> [S0, S1, R, R]_0' in strategy_name_list
            assert '[R, S0, R, S1] -> [R, S1, R, S0]_1' in strategy_name_list
            assert '[R, R, S0, S1] -> [R, S1, S0, R]_2' in strategy_name_list
            assert '[S1, R, R, S0] -> [S1, S0, R, R]_3' in strategy_name_list
            assert '[R, S1, R, S0] -> [R, S0, R, S1]_4' in strategy_name_list
            assert '[R, R, S1, S0] -> [R, S0, S1, R]_5' in strategy_name_list
            assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list
            assert '[R, S0, R, R] -> [R, R, R, S0]_7' in strategy_name_list
            assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list
            assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list
            assert '[R, S1, R, R] -> [R, R, R, S1]_10' in strategy_name_list
            assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list
            assert '[R, R, R, S1] -> [R, S1, R, R]_12' in strategy_name_list
            assert '[R, R, R, S0] -> [R, S0, R, R]_13' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list
            assert '[R, R, R, S0] -> [R, S0, R, R]_16' in strategy_name_list
            assert '[R, R, R, S1] -> [R, S1, R, R]_17' in strategy_name_list
            assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list
            assert '[R, S01, R, R] -> [R, R, R, S01]_19' in strategy_name_list
            assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list
            assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list
            assert '[R, R, R, S01] -> [R, S01, R, R]_22' in strategy_name_list
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
@parameterize('call_function', [torch.permute, torch.transpose])
@parameterize('reshape_dims', [((0, 2, 1, 3), (1, 2)), ((2, 0, 1, 3), (1, 3))])
@parameterize('model_cls', [ConvReshapeModel, LinearReshapeModel])
def test_view_handler(call_function, reshape_dims, model_cls):
    world_size = 4
    run_func = partial(check_view_handler,
                       call_function=call_function,
                       reshape_dims=reshape_dims,
                       model_cls=model_cls,
                       world_size=world_size,
                       port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_view_handler()
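A note on the strategy names asserted above (explanation added for readability, not part of the commit): each bracketed sequence has one entry per tensor dimension on the 2x2 mesh, with S0/S1 meaning sharded along mesh axis 0/1, S01 sharded along both, and R replicated; the arrow shows the predecessor's output spec being relabeled by the permute/transpose, and the trailing _<index> is the suffix added to keep otherwise duplicate names distinct. The expected right-hand side can be reproduced with a one-liner:

# reproduce the output sharding sequence asserted above (illustration only)
def permute_sharding_sequence(seq, permute_dims):
    # output dim i inherits the sharding of input dim permute_dims[i]
    return [seq[d] for d in permute_dims]

assert permute_sharding_sequence(['S0', 'S1', 'R', 'R'], (0, 2, 1, 3)) == ['S0', 'R', 'S1', 'R']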
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py
...
...
@@ -84,7 +84,7 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port):
    #     return view
    graph = tracer.trace(model,
                         meta_args={
                             "input": torch.rand(8, 16, 66, 66).to('meta'),
                             "input": torch.rand(8, 8, 66, 66).to('meta'),
                             "other": torch.rand(16, 8, 3, 3).to('meta'),
                         })
...
...