Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
f3f19a5c
Unverified
Commit
f3f19a5c
authored
Nov 01, 2022
by
Frank Lee
Committed by
GitHub
Nov 01, 2022
Browse files
[autoparallel] added matmul handler (#1763)
* [autoparallel] added matmul handler * polish code
parent
4df01949
Changes
7
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
725 additions
and
28 deletions
+725
-28
colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
...salai/auto_parallel/tensor_shard/node_handler/__init__.py
+2
-1
colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py
...auto_parallel/tensor_shard/node_handler/matmul_handler.py
+482
-0
colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
..._shard/node_handler/strategy/matmul_strategy_generator.py
+40
-10
colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
.../tensor_shard/node_handler/strategy/strategy_generator.py
+6
-1
colossalai/auto_parallel/tensor_shard/utils/broadcast.py
colossalai/auto_parallel/tensor_shard/utils/broadcast.py
+26
-15
colossalai/tensor/sharding_spec.py
colossalai/tensor/sharding_spec.py
+3
-1
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py
...est_tensor_shard/test_node_handler/test_matmul_handler.py
+166
-0
No files found.
colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
View file @
f3f19a5c
...
...
@@ -4,6 +4,7 @@ from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
from
.conv_handler
import
ConvFunctionHandler
,
ConvModuleHandler
from
.layer_norm_handler
import
LayerNormModuleHandler
from
.linear_handler
import
LinearFunctionHandler
,
LinearModuleHandler
from
.matmul_handler
import
MatMulHandler
from
.normal_pooling_handler
import
NormPoolingHandler
from
.output_handler
import
OuputHandler
from
.placeholder_handler
import
PlacehodlerHandler
...
...
@@ -16,5 +17,5 @@ __all__ = [
'LinearFunctionHandler'
,
'LinearModuleHandler'
,
'BMMFunctionHandler'
,
'AddBMMFunctionHandler'
,
'LayerNormModuleHandler'
,
'BatchNormModuleHandler'
,
'ConvModuleHandler'
,
'ConvFunctionHandler'
,
'UnaryElementwiseHandler'
,
'ReshapeHandler'
,
'PlacehodlerHandler'
,
'OuputHandler'
,
'WhereHandler'
,
'NormPoolingHandler'
,
'BinaryElementwiseHandler'
,
'operator_registry'
'NormPoolingHandler'
,
'BinaryElementwiseHandler'
,
'MatMulHandler'
,
'operator_registry'
]
colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py
0 → 100644
View file @
f3f19a5c
This diff is collapsed.
Click to expand it.
colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
View file @
f3f19a5c
...
...
@@ -60,12 +60,13 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
def
update_compute_cost
(
self
,
strategy
:
ShardingStrategy
)
->
ShardingStrategy
:
sharded_input_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
'input'
]].
get_sharded_shape_per_device
()
fwd_compute_cost
=
sharded_input_shape
[
0
]
bwd_compute_cost
=
sharded_input_shape
*
2
bwd_compute_cost
=
fwd_compute_cost
*
2
compute_cost
=
TrainCycleItem
(
fwd
=
fwd_compute_cost
,
bwd
=
bwd_compute_cost
,
total
=
fwd_compute_cost
+
bwd_compute_cost
)
return
compute_cost
@
ignore_sharding_exception
def
no_split
(
self
):
name
=
f
'R = R dot R'
dim_partition_dict
=
{
"input"
:
{},
"other"
:
{},
"output"
:
{},
'bias'
:
{}}
...
...
@@ -75,6 +76,7 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
)
@
ignore_sharding_exception
def
split_one_dim
(
self
,
mesh_dim
):
name
=
f
'R = S
{
mesh_dim
}
dot S
{
mesh_dim
}
'
...
...
@@ -93,7 +95,7 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
)
def
generate
(
self
)
->
List
[
ShardingStrategy
]:
def
collate_strategies
(
self
)
->
List
[
ShardingStrategy
]:
strategy_list
=
[]
# do not split dimensions for dot product
...
...
@@ -113,24 +115,50 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
def
validate
(
self
)
->
bool
:
input_op_data
=
self
.
op_data
[
'input'
]
other_op_data
=
self
.
op_data
[
'other'
]
assert
input_op_data
.
data
.
dim
()
>
1
and
other_op_data
.
data
.
dim
()
==
1
assert
input_op_data
.
data
.
dim
()
==
2
and
other_op_data
.
data
.
dim
()
==
1
def
update_compute_cost
(
self
,
strategy
:
ShardingStrategy
)
->
ShardingStrategy
:
sharded_input_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
'input'
]].
get_sharded_shape_per_device
()
fwd_compute_cost
=
sharded_input_shape
[
0
]
bwd_compute_cost
=
fwd_compute_cost
*
2
compute_cost
=
TrainCycleItem
(
fwd
=
fwd_compute_cost
,
bwd
=
bwd_compute_cost
,
total
=
fwd_compute_cost
+
bwd_compute_cost
)
return
compute_cost
@
ignore_sharding_exception
def
no_split
(
self
):
name
=
"R = R x R"
dim_partition_dict
=
{
"input"
:
{},
"other"
:
{},
"output"
:
{},
"bias"
:
{}}
dim_partition_dict
=
{
"input"
:
{},
"other"
:
{},
"output"
:
{}}
if
self
.
has_bias
:
dim_partition_dict
[
'bias'
]
=
{}
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict
)
return
self
.
get_sharding_strategy
(
name
=
name
,
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
{})
@
ignore_sharding_exception
def
split_input_batch
(
self
,
mesh_dim
):
name
=
f
'S
{
mesh_dim
}
R = S
{
mesh_dim
}
R x R'
# get sharding spec
dim_partition_dict
=
{
"input"
:
{
0
:
[
mesh_dim
]},
"other"
:
{},
"output"
:
{
0
:
[
mesh_dim
]},
"bias"
:
{}}
dim_partition_dict
=
{
"input"
:
{
0
:
[
mesh_dim
]
},
"other"
:
{},
"output"
:
{
0
:
[
mesh_dim
]
},
}
if
self
.
has_bias
:
dim_partition_dict
[
'bias'
]
=
{}
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict
)
# get communication action
communication_action_mapping
=
{}
if
self
.
is_param
(
'other'
):
other_comm_action
=
self
.
get_communication_action
(
sharding_spec
=
sharding_spec_mapping
[
'other'
],
...
...
@@ -144,6 +172,8 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
logical_process_axis
=
mesh_dim
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
1
)
communication_action_mapping
[
'other'
]
=
other_comm_action
if
self
.
has_bias
:
if
self
.
is_param
(
'bias'
):
bias_comm_action
=
self
.
get_communication_action
(
...
...
@@ -158,13 +188,13 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
logical_process_axis
=
mesh_dim
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
2
)
communication_action_mapping
=
{
'other'
:
other_comm_action
,
'bias'
:
bias_comm_action
}
communication_action_mapping
[
'bias'
]
=
bias_comm_action
return
self
.
get_sharding_strategy
(
name
=
name
,
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
)
def
generate
(
self
)
->
List
[
ShardingStrategy
]:
def
collate_strategies
(
self
)
->
List
[
ShardingStrategy
]:
strategy_list
=
[]
# no split
...
...
@@ -638,7 +668,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
def
validate
(
self
)
->
bool
:
input_op_data
=
self
.
op_data
[
'input'
]
other_op_data
=
self
.
op_data
[
'other'
]
assert
input_op_data
.
data
.
dim
(
)
==
3
or
other_op_data
.
data
.
dim
(
)
==
3
assert
len
(
input_op_data
.
logical_shape
)
==
3
or
len
(
other_op_data
.
logical_shape
)
==
3
if
'bias'
in
self
.
op_data
:
bias_op_data
=
self
.
op_data
[
'bias'
]
...
...
@@ -816,11 +846,11 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
dim_partition_dict
=
{
"input"
:
{
0
:
[
mesh_dim_0
],
-
1
:
[
mesh_dim_1
]
2
:
[
mesh_dim_1
]
},
"other"
:
{
0
:
[
mesh_dim_0
],
-
2
:
[
mesh_dim_1
]
1
:
[
mesh_dim_1
]
},
"bias"
:
{},
"output"
:
{
...
...
colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
View file @
f3f19a5c
...
...
@@ -186,9 +186,14 @@ class StrategyGenerator(ABC):
"""
op_data
=
self
.
op_data
[
key
]
sharded_shape
=
strategy
.
sharding_specs
[
op_data
].
get_sharded_shape_per_device
()
if
len
(
sharded_shape
)
==
0
:
num_elements
=
1
else
:
num_elements
=
reduce
(
operator
.
mul
,
sharded_shape
)
dtype
=
self
.
op_data
[
key
].
data
.
dtype
size_per_elem_bytes
=
torch
.
tensor
([],
dtype
=
dtype
).
element_size
()
return
reduce
(
operator
.
mul
,
sharded_shape
)
*
size_per_elem_bytes
return
num_elements
*
size_per_elem_bytes
def
generate
(
self
)
->
List
[
ShardingStrategy
]:
"""
...
...
colossalai/auto_parallel/tensor_shard/utils/broadcast.py
View file @
f3f19a5c
...
...
@@ -44,21 +44,7 @@ def get_broadcast_shape(shape1: torch.Size, shape2: torch.Size) -> List[int]:
return
dims
[::
-
1
]
def
recover_sharding_spec_for_broadcast_shape
(
logical_sharding_spec
:
ShardingSpec
,
logical_shape
:
torch
.
Size
,
physical_shape
:
torch
.
Size
)
->
ShardingSpec
:
"""
This function computes the sharding spec for the physical shape of a broadcast tensor.
Args:
logical_sharding_spec (ShardingSpec): the sharding spec for the broadcast tensor
logical_shape (torch.Size): logical shape is the broadcast shape of a tensor
physical_shape (torch.Size): the shape of the tensor before broadcasting
"""
# if the two shapes are the same, no broadcast occurs
# we directly return the current sharding spec
if
list
(
logical_shape
)
==
list
(
physical_shape
):
return
logical_sharding_spec
def
get_broadcast_dim_info
(
logical_shape
,
physical_shape
):
# get the number of dimensions
logical_num_dims
=
len
(
logical_shape
)
physical_num_dims
=
len
(
physical_shape
)
...
...
@@ -85,6 +71,31 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe
else
:
logical_dim_broadcast_info
[
logical_dim_idx
]
=
BroadcastType
.
PADDDING
return
logical_dim_broadcast_info
def
recover_sharding_spec_for_broadcast_shape
(
logical_sharding_spec
:
ShardingSpec
,
logical_shape
:
torch
.
Size
,
physical_shape
:
torch
.
Size
)
->
ShardingSpec
:
"""
This function computes the sharding spec for the physical shape of a broadcast tensor.
Args:
logical_sharding_spec (ShardingSpec): the sharding spec for the broadcast tensor
logical_shape (torch.Size): logical shape is the broadcast shape of a tensor
physical_shape (torch.Size): the shape of the tensor before broadcasting
"""
# if the two shapes are the same, no broadcast occurs
# we directly return the current sharding spec
if
list
(
logical_shape
)
==
list
(
physical_shape
):
return
logical_sharding_spec
# get the number of dimensions
logical_num_dims
=
len
(
logical_shape
)
physical_num_dims
=
len
(
physical_shape
)
# get the broadcast info
logical_dim_broadcast_info
=
get_broadcast_dim_info
(
logical_shape
,
physical_shape
)
# generate the sharding spec for the physical shape
physical_dim_partition
=
{}
logical_dim_partition
=
logical_sharding_spec
.
dim_partition_dict
...
...
colossalai/tensor/sharding_spec.py
View file @
f3f19a5c
import
operator
from
copy
import
deepcopy
from
enum
import
Enum
from
functools
import
reduce
import
torch
...
...
@@ -175,6 +174,9 @@ class ShardingSpec:
dim_partition_dict
=
None
,
sharding_sequence
=
None
):
self
.
device_mesh
=
device_mesh
if
isinstance
(
entire_shape
,
(
list
,
tuple
)):
entire_shape
=
torch
.
Size
(
entire_shape
)
self
.
entire_shape
=
entire_shape
self
.
dim_partition_dict
=
dim_partition_dict
self
.
sharding_sequence
=
sharding_sequence
...
...
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py
0 → 100644
View file @
f3f19a5c
import
torch
import
torch.nn
as
nn
from
colossalai.auto_parallel.tensor_shard.node_handler.matmul_handler
import
(
MatMulHandler
,
MatMulType
,
_get_bmm_logical_shape
,
get_matmul_type
,
)
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
(
OperationData
,
OperationDataType
,
ShardingStrategy
,
StrategiesVector
,
)
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.fx
import
ColoGraphModule
,
ColoTracer
from
colossalai.testing.utils
import
parameterize
class MatMulModule(nn.Module):
    """Minimal module wrapping ``torch.matmul`` so it can be symbolically traced.

    The traced graph will contain a single ``matmul`` call node operating on
    the two positional inputs.
    """

    def forward(self, x1, x2):
        # Keep the explicit torch.matmul call (rather than the @ operator)
        # so the traced node is named "matmul", which the test relies on.
        return torch.matmul(x1, x2)
@parameterize(
    'tensor_shapes',
    [
        [[8], [8]],    # dot product
        [[4, 8], [8]],    # mat-vec product
        [[4, 8], [8, 16]],    # mat-mat product
        [[8], [8, 16]],    # mat-mat product
        [[8], [4, 8, 16]],    # batched mat-mat product with padding + broadcasting
        [[4, 8, 16], [16]],    # batched mat-mat product with padding + broadcasting
        [[4, 8, 16], [16, 32]],    # batched mat-mat product with broadcasting
        [[4, 8, 16], [1, 16, 32]],    # batched mat-mat product with broadcasting
        [[8, 16], [2, 4, 16, 32]],    # batched mat-mat product with broadcasting
        [[4, 8, 16], [2, 4, 16, 32]],    # batched mat-mat product with broadcasting
        [[1, 8, 16], [2, 4, 16, 32]],    # batched mat-mat product with broadcasting
        [[1, 4, 8, 16], [2, 4, 16, 32]],    # batched mat-mat product with broadcasting
        [[2, 1, 8, 16], [2, 4, 16, 32]],    # batched mat-mat product with broadcasting
        [[2, 4, 8, 16], [2, 4, 16, 32]],    # batched mat-mat product without broadcasting
    ])
def test_matmul_node_handler(tensor_shapes):
    """Exercise MatMulHandler over every matmul flavor (DOT/MV/MM/BMM).

    For each pair of input shapes, trace a ``torch.matmul`` call, build the
    handler, and verify (1) the operation-data mapping reports the expected
    physical and logical shapes and (2) each generated sharding strategy is
    internally consistent for its matmul type.
    """
    input_shape, other_shape = tensor_shapes

    # Run a real matmul once to obtain the reference output shape.
    x1 = torch.rand(*input_shape)
    x2 = torch.rand(*other_shape)
    output_shape = list(torch.matmul(x1, x2).shape)

    # Classify the operation from the operand ranks.
    matmul_type = get_matmul_type(x1.dim(), x2.dim())

    model = MatMulModule()

    tracer = ColoTracer()
    graph = tracer.trace(model, meta_args={"x1": x1.to('meta'), 'x2': x2.to('meta')})
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)
    print(graph)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)

    # Node 0 and 1 are the placeholders; node 2 is the matmul call.
    mod_node = list(graph.nodes)[2]
    strategies_vector = StrategiesVector(mod_node)

    # build handler
    handler = MatMulHandler(node=mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)

    # check operation data mapping
    mapping = handler.get_operation_data_mapping()

    for name, op_data in mapping.items():
        op_data: OperationData
        # make sure they have valid values
        assert op_data.logical_shape is not None
        assert op_data.data is not None

    # Derive the logical (post-padding/broadcast) shapes the handler should report.
    logical_input_shape = input_shape
    logical_other_shape = other_shape
    logical_output_shape = output_shape
    if matmul_type == MatMulType.MM and len(input_shape) == 1:
        # 1D input to MM is logically padded to a 1-row matrix.
        logical_input_shape = [1] + input_shape
    elif matmul_type == MatMulType.BMM:
        logical_input_shape, logical_other_shape, logical_output_shape = _get_bmm_logical_shape(
            input_shape, other_shape, handler.transforms)
    else:
        logical_input_shape = input_shape

    # check input operation data
    assert mapping['input'].name == "x1"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size(input_shape)
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size(logical_input_shape)

    # check other operation data
    assert mapping['other'].name == "x2"
    assert mapping['other'].data.is_meta
    assert mapping['other'].data.shape == torch.Size(other_shape)
    assert mapping['other'].type == OperationDataType.ARG
    assert mapping['other'].logical_shape == torch.Size(logical_other_shape)

    # check output operation data
    assert mapping['output'].name == "matmul"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size(output_shape)
    assert mapping['output'].type == OperationDataType.OUTPUT
    assert mapping['output'].logical_shape == torch.Size(logical_output_shape)

    strategies_vector = handler.register_strategy(compute_resharding_cost=False)
    strategy_name_list = [val.name for val in strategies_vector]

    # ensure there is no duplicate strategy (BMM may legitimately produce repeats)
    if matmul_type != MatMulType.BMM:
        assert len(set(strategy_name_list)) == len(strategy_name_list), strategy_name_list

    for strategy in strategies_vector:
        strategy: ShardingStrategy
        input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
        other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
        output_sharding_spec = strategy.get_sharding_spec_by_name('matmul')

        if matmul_type == MatMulType.DOT:
            # dot product will produce a scalar
            # results should fulfill:
            # 1. the input and other operands have the same sharding spec
            # 2. the output has no sharding
            assert input_sharding_spec.sharding_sequence == other_sharding_spec.sharding_sequence
            assert len(output_sharding_spec.sharding_sequence) == 0
        elif matmul_type == MatMulType.MV:
            # matrix-vector product should fulfill
            # 1. the last dim of the input and other operands should have the same sharding
            # 2. the first dim of the input and output should have the same sharding
            # 3. the output should have only 1 dim
            assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1]
            assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
            assert len(output_sharding_spec.sharding_sequence) == 1
        elif matmul_type == MatMulType.MM:
            # matrix-matrix multiplication should fulfill
            # 1. if input is a 2D tensor, the 1st dim of input and output should have the same sharding
            # 2. the input's last dim and the first dim of the other should have the same sharding
            # 3. the last dim of the output and other should have the same sharding
            # 4. the input and output should have the same number of dims
            if len(input_shape) == 2:
                assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
            assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[0]
            assert output_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1]
            assert len(input_sharding_spec.sharding_sequence) == len(output_sharding_spec.sharding_sequence)
        elif matmul_type == MatMulType.BMM:
            # bmm should fulfill
            # 1. if the other tensor is not a 1d tensor, the last dim of other and output have the same sharding
            # 2. if the input has more than 2 dims, the second last dim of input and output have the same sharding
            # 3. if the other has more than 2 dims, the second last dim of other and the last dim of input should have the same sharding
            if len(other_shape) > 1:
                assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
            if len(input_shape) > 1:
                assert input_sharding_spec.sharding_sequence[-2] == output_sharding_spec.sharding_sequence[-2]
            if len(other_shape) > 2:
                assert other_sharding_spec.sharding_sequence[-2] == input_sharding_spec.sharding_sequence[-1]
if __name__ == '__main__':
    # Run the parameterized test directly when invoked as a script.
    test_matmul_node_handler()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment