ColossalAI · Commit f3f19a5c (unverified)
Authored Nov 01, 2022 by Frank Lee; committed by GitHub on Nov 01, 2022
Parent: 4df01949

[autoparallel] added matmul handler (#1763)

* [autoparallel] added matmul handler
* polish code
Showing 7 changed files with 725 additions and 28 deletions (+725, -28):
colossalai/auto_parallel/tensor_shard/node_handler/__init__.py (+2, -1)
colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py (+482, -0)
colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py (+40, -10)
colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py (+6, -1)
colossalai/auto_parallel/tensor_shard/utils/broadcast.py (+26, -15)
colossalai/tensor/sharding_spec.py (+3, -1)
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py (+166, -0)
colossalai/auto_parallel/tensor_shard/node_handler/__init__.py

@@ -4,6 +4,7 @@ from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
 from .conv_handler import ConvFunctionHandler, ConvModuleHandler
 from .layer_norm_handler import LayerNormModuleHandler
 from .linear_handler import LinearFunctionHandler, LinearModuleHandler
+from .matmul_handler import MatMulHandler
 from .normal_pooling_handler import NormPoolingHandler
 from .output_handler import OuputHandler
 from .placeholder_handler import PlacehodlerHandler
@@ -16,5 +17,5 @@ __all__ = [
     'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
     'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
     'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
-    'NormPoolingHandler', 'BinaryElementwiseHandler', 'operator_registry'
+    'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry'
 ]
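With the new export in place, the handler can be imported directly from the node_handler package alongside the existing handlers. A trivial sketch (not part of the diff):

# illustrative sketch: the handler is now re-exported from the package
from colossalai.auto_parallel.tensor_shard.node_handler import MatMulHandler

print(MatMulHandler.__name__)    # -> 'MatMulHandler'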
colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py (new file, 0 → 100644)

import operator
from abc import ABC, abstractmethod
from copy import deepcopy
from enum import Enum
from functools import reduce
from typing import Dict, List, Union

import torch

from colossalai.auto_parallel.tensor_shard.utils.broadcast import (
    BroadcastType,
    get_broadcast_dim_info,
    get_broadcast_shape,
)
from colossalai.tensor.sharding_spec import ShardingSpecException

from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import (
    BatchedMatMulStrategyGenerator,
    DotProductStrategyGenerator,
    LinearProjectionStrategyGenerator,
    MatVecStrategyGenerator,
    StrategyGenerator,
)
class MatMulType(Enum):
    """
    The MatMulType is categorized into 4 types based on the reference of torch.matmul
    in https://pytorch.org/docs/stable/generated/torch.matmul.html.

    DOT: dot product, both tensors are 1D, these two tensors need to have the same number of elements
    MM: matrix-matrix product, both tensors are 2D or the 1st tensor is 1D and the 2nd tensor is 2D
    MV: matrix-vector product: the 1st tensor is 2D and the 2nd tensor is 1D
    BMM: batched matrix-matrix multiplication, one tensor is at least 1D and the other is at least 3D
    """
    DOT = 0
    MM = 1
    MV = 2
    BMM = 3


def get_matmul_type(input_dim: int, other_dim: int):
    """
    Determine which type of matmul operation should be executed for the given tensor dimensions.

    Args:
        input_dim (int): the number of dimensions for the input tensor
        other_dim (int): the number of dimensions for the other tensor
    """
    if input_dim == 1 and other_dim == 1:
        matmul_type = MatMulType.DOT
    elif input_dim in [1, 2] and other_dim == 2:
        matmul_type = MatMulType.MM
    elif input_dim == 2 and other_dim == 1:
        matmul_type = MatMulType.MV
    elif input_dim >= 1 and other_dim >= 1 and (input_dim > 2 or other_dim > 2):
        matmul_type = MatMulType.BMM
    else:
        raise ValueError(
            f"The input and other tensors have {input_dim} and {other_dim} dimensions, which cannot be used to execute a matmul operation"
        )
    return matmul_type
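For reference, the classification above maps concrete dimensionalities as follows (an illustrative sketch, not part of the committed file; the shapes in the comments mirror cases exercised by the test file later in this commit):

# illustrative sketch of get_matmul_type; not part of the committed source
assert get_matmul_type(1, 1) == MatMulType.DOT    # [8] @ [8]: dot product
assert get_matmul_type(1, 2) == MatMulType.MM     # [8] @ [8, 16]: 1D input handled as matrix-matrix
assert get_matmul_type(2, 2) == MatMulType.MM     # [4, 8] @ [8, 16]: matrix-matrix product
assert get_matmul_type(2, 1) == MatMulType.MV     # [4, 8] @ [8]: matrix-vector product
assert get_matmul_type(3, 4) == MatMulType.BMM    # [4, 8, 16] @ [2, 4, 16, 32]: batched matmul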
class BmmTransform(ABC):
    """
    BmmTransform is an abstraction of the shape conversion between logical and physical operation data
    during the strategy generation.
    """

    @abstractmethod
    def apply(self, shape_mapping: Dict[str, List[int]]):
        pass

    @abstractmethod
    def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
        pass
class Padder(BmmTransform):
    """
    Add padding to the matrix dimensions for batched matrix multiplication.
    """

    def __init__(self) -> None:
        # keep the padding dim, op_name -> padded_dim
        self.padded_dim_mapping = {}

    def apply(self, shape_mapping: Dict[str, List[int]]):
        mapping_copy = deepcopy(shape_mapping)
        input_shape = mapping_copy['input']
        other_shape = mapping_copy['other']

        if len(input_shape) == 1:
            # if the input is a 1D tensor, 1 is prepended to its shape
            # and it will be removed afterwards
            input_shape.insert(0, 1)
            self.padded_dim_mapping['input'] = -2
            self.padded_dim_mapping['output'] = -2
        elif len(other_shape) == 1:
            # if the other is a 1D tensor, 1 is appended to its shape
            # and it will be removed afterwards
            other_shape.append(1)
            self.padded_dim_mapping['other'] = -1
            self.padded_dim_mapping['output'] = -1
        return mapping_copy

    def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
        input_op_data = op_data_mapping['input']
        other_op_data = op_data_mapping['other']

        def _remove_padded_dim(key, strategy):
            op_data = op_data_mapping[key]
            sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
            tensor_shape = list(sharding_spec.entire_shape)
            dim_partition_list = [None] * len(tensor_shape)

            # padded dim is a negative number as the padded dim must be a matrix dim
            padded_dim = self.padded_dim_mapping[key]

            # compute the new dim partition
            for tensor_dim, mesh_dims in sharding_spec.dim_partition_dict.items():
                dim_partition_list[tensor_dim] = mesh_dims
            dim_partition_list.pop(padded_dim)
            unpadded_dim_partition_list = {k: v for k, v in enumerate(dim_partition_list) if v is not None}

            # compute unpadded tensor shape
            tensor_shape.pop(padded_dim)

            assert tensor_shape == list(op_data.data.shape), f'{tensor_shape} vs {list(op_data.data.shape)}'

            # update sharding spec
            sharding_spec.__init__(sharding_spec.device_mesh, tensor_shape, unpadded_dim_partition_list)

        # enumerate all sharding strategies
        strategies = []
        try:
            strategy_copy = strategy.clone()

            # only one of input and other will be padded
            if 'input' in self.padded_dim_mapping:
                _remove_padded_dim('input', strategy_copy)
                _remove_padded_dim('output', strategy_copy)
            elif 'other' in self.padded_dim_mapping:
                _remove_padded_dim('other', strategy_copy)
                _remove_padded_dim('output', strategy_copy)

            strategies.append(strategy_copy)
        except ShardingSpecException as e:
            pass
        return strategies
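As a concrete illustration of the padding step (a sketch, not part of the diff; the shapes come from the test cases added in this commit):

# illustrative sketch: Padder.apply on a 1D `other` operand, e.g. [4, 8, 16] @ [16]
padder = Padder()
padded = padder.apply({'input': [4, 8, 16], 'other': [16]})
# the 1D `other` gets a trailing 1 appended ([16] -> [16, 1]) and the padded matrix dim
# is recorded so that recover() can later strip it from the generated sharding specs
# padded['other'] == [16, 1]; padder.padded_dim_mapping == {'other': -1, 'output': -1}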
class Broadcaster(BmmTransform):
    """
    Broadcast the non-matrix dimensions for batched matrix multiplication.
    """

    def __init__(self) -> None:
        self.broadcast_dim_info = {}

    def apply(self, shape_mapping: Dict[str, List[int]]):
        mapping_copy = shape_mapping.copy()

        # get shapes
        input_shape = mapping_copy['input']
        other_shape = mapping_copy['other']

        # sanity check
        assert len(input_shape) > 1 and len(other_shape) > 1

        # broadcast the batch dim and record
        bcast_non_matrix_dims = get_broadcast_shape(input_shape[:-2], other_shape[:-2])

        # store the broadcast dim info
        input_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, input_shape[:-2])
        other_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, other_shape[:-2])
        self.broadcast_dim_info['input'] = input_broadcast_dim_info
        self.broadcast_dim_info['other'] = other_broadcast_dim_info

        # create the full logical shape
        input_shape = bcast_non_matrix_dims + input_shape[-2:]
        other_shape = bcast_non_matrix_dims + other_shape[-2:]
        assert len(input_shape) == len(other_shape)

        mapping_copy['input'] = input_shape
        mapping_copy['other'] = other_shape

        return mapping_copy

    def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
        # remove sharding on the broadcast dim
        def _remove_sharding_on_broadcast_dim(key, strategy):
            op_data = op_data_mapping[key]
            sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
            tensor_shape = list(sharding_spec.entire_shape)

            for dim_idx, broadcast_type in self.broadcast_dim_info[key].items():
                if broadcast_type == BroadcastType.MULTIPLE:
                    # if the dim is originally 1 and multiplied during broadcast
                    # we set its sharding to R
                    # e.g. [1, 2, 4] x [4, 4, 8] -> [4, 2, 8]
                    # the dim 0 of [1, 2, 4] is multiplied to 4
                    tensor_shape[dim_idx] = 1
                elif broadcast_type == BroadcastType.PADDDING:
                    # if the dim is padded
                    # we remove its sharding
                    tensor_shape[dim_idx] = None

            tensor_shape_before_broadcast = [dim for dim in tensor_shape if dim is not None]

            physical_sharding_spec = recover_sharding_spec_for_broadcast_shape(
                logical_sharding_spec=sharding_spec,
                logical_shape=sharding_spec.entire_shape,
                physical_shape=tensor_shape_before_broadcast)
            strategy.sharding_specs[op_data] = physical_sharding_spec

        # enumerate all sharding strategies
        strategies = []
        try:
            strategy_copy = strategy.clone()
            _remove_sharding_on_broadcast_dim('input', strategy_copy)
            _remove_sharding_on_broadcast_dim('other', strategy_copy)
            strategies.append(strategy_copy)
        except ShardingSpecException as e:
            pass
        return strategies
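Following the same pattern, the broadcasting step aligns the non-matrix (batch) dimensions of both operands (a sketch, not part of the diff; shapes taken from the test cases):

# illustrative sketch: Broadcaster.apply on [1, 8, 16] @ [2, 4, 16, 32]
broadcaster = Broadcaster()
bcast = broadcaster.apply({'input': [1, 8, 16], 'other': [2, 4, 16, 32]})
# batch dims [1] and [2, 4] broadcast to [2, 4]; the matrix dims are left untouched
# bcast['input'] == [2, 4, 8, 16]; bcast['other'] == [2, 4, 16, 32]
# broadcaster.broadcast_dim_info records, per operand, which batch dims were padded or multiplied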
class Viewer(BmmTransform):
    """
    Change the shape of the tensor from N-D to 3D
    """

    def __init__(self) -> None:
        self.batch_dims_before_view = None

    def apply(self, shape_mapping: Dict[str, List[int]]):
        mapping_copy = shape_mapping.copy()
        self.batch_dims_before_view = list(mapping_copy['input'][:-2])

        # get shapes
        input_shape = shape_mapping['input']
        other_shape = shape_mapping['other']

        # view to 3d tensor
        assert len(input_shape) >= 3 and len(other_shape) >= 3
        input_shape = [reduce(operator.mul, input_shape[:-2])] + input_shape[-2:]
        other_shape = [reduce(operator.mul, other_shape[:-2])] + other_shape[-2:]
        output_shape = input_shape[:2] + other_shape[2:]
        mapping_copy['input'] = input_shape
        mapping_copy['other'] = other_shape
        mapping_copy['output'] = output_shape
        return mapping_copy

    def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
        # get operation data
        def _update_sharding_spec(key, strategy, physical_batch_dim):
            """
            Map the logical batch dim to the physical batch dim
            """
            op_data = op_data_mapping[key]
            sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
            dim_partition_dict = sharding_spec.dim_partition_dict
            entire_shape = sharding_spec.entire_shape

            # update the dimension index for the matrix dimensions
            if 2 in dim_partition_dict:
                dim_partition_dict[len(self.batch_dims_before_view) + 1] = dim_partition_dict.pop(2)
            if 1 in dim_partition_dict:
                dim_partition_dict[len(self.batch_dims_before_view)] = dim_partition_dict.pop(1)

            # map the logical batch dim to the physical batch dim
            if 0 in dim_partition_dict:
                batch_dim_shard = dim_partition_dict.pop(0)
                dim_partition_dict[physical_batch_dim] = batch_dim_shard

            # the new shape will be the batch dims + the last 2 matrix dims
            shape_before_view = self.batch_dims_before_view + list(entire_shape[-2:])
            sharding_spec.__init__(sharding_spec.device_mesh, shape_before_view, dim_partition_dict)

        num_batch_dim_before_view = len(self.batch_dims_before_view)

        # enumerate all sharding strategies
        strategies = []
        for i in range(num_batch_dim_before_view):
            # create a new strategy
            strategy_copy = strategy.clone()
            try:
                _update_sharding_spec('input', strategy_copy, i)
                _update_sharding_spec('other', strategy_copy, i)
                _update_sharding_spec('output', strategy_copy, i)
                strategies.append(strategy_copy)
            except ShardingSpecException as e:
                continue
        return strategies
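The last transform flattens all batch dimensions into a single leading dimension so that the batched matmul has a canonical 3D form (a sketch, not part of the diff):

# illustrative sketch: Viewer.apply on the broadcast shapes [2, 4, 8, 16] @ [2, 4, 16, 32]
viewer = Viewer()
viewed = viewer.apply({'input': [2, 4, 8, 16], 'other': [2, 4, 16, 32]})
# the batch dims [2, 4] are collapsed into 2 * 4 = 8, giving 3D logical shapes
# viewed['input'] == [8, 8, 16]; viewed['other'] == [8, 16, 32]; viewed['output'] == [8, 8, 32]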
def _get_bmm_logical_shape(input_shape, other_shape, transforms):
    """
    Compute the logical shapes for BMM operation. BMM has a general representation
    [b, i, k] = [b, i, j] x [b, j, k]

    The dimension b is called non-matrix (batch) dimension and the remaining dimensions are called matrix dimensions.
    The logical shape for the bmm operands will undergo three stages
        1. append/prepend the 1 to the 1D tensor if there is any
        2. broadcast the non-matrix dimensions
        3. reshape to 3 dimensions
    """
    shape_mapping = {'input': input_shape, 'other': other_shape}

    for transform in transforms:
        shape_mapping = transform.apply(shape_mapping)

    input_shape = shape_mapping.get('input', None)
    other_shape = shape_mapping.get('other', None)
    output_shape = shape_mapping.get('output', None)

    return input_shape, other_shape, output_shape
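Chaining the three transforms gives the logical BMM shapes end to end. A worked sketch (not part of the diff) using one of the shape pairs from the test file, [8, 16] @ [2, 4, 16, 32]:

# illustrative sketch of the three-stage logical shape computation
transforms = [Padder(), Broadcaster(), Viewer()]
logical_input, logical_other, logical_output = _get_bmm_logical_shape([8, 16], [2, 4, 16, 32], transforms)
# 1. Padder: neither operand is 1D, so nothing is padded
# 2. Broadcaster: batch dims [] and [2, 4] broadcast to [2, 4] -> [2, 4, 8, 16] and [2, 4, 16, 32]
# 3. Viewer: batch dims flattened -> logical_input [8, 8, 16], logical_other [8, 16, 32], logical_output [8, 8, 32]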
@operator_registry.register(torch.matmul)
@operator_registry.register(torch.Tensor.matmul)
class MatMulHandler(NodeHandler):
    """
    The MatMulHandler is a node handler which handles the sharding strategy generation for the matmul operation.
    According to https://pytorch.org/docs/stable/generated/torch.matmul.html, the operations will vary depending on
    the operands.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        # check which type of operation this matmul will call
        self.input_meta_data = self.node.args[0]._meta_data
        self.other_meta_data = self.node.args[1]._meta_data
        self.output_meta_data = self.node._meta_data

        input_dim = self.input_meta_data.dim()
        other_dim = self.other_meta_data.dim()
        self.matmul_type = get_matmul_type(input_dim, other_dim)

        if self.matmul_type == MatMulType.BMM:
            # bmm operation can possibly involve padding, broadcasting and view
            # these transforms will be used to create logical shape and
            # recover physical sharding spec
            self.transforms = [Padder(), Broadcaster(), Viewer()]
        else:
            self.transforms = None

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        generators = []
        op_data_mapping = self.get_operation_data_mapping()
        if self.matmul_type == MatMulType.BMM:
            generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh))
        elif self.matmul_type == MatMulType.DOT:
            generators.append(DotProductStrategyGenerator(op_data_mapping, self.device_mesh))
        elif self.matmul_type == MatMulType.MV:
            generators.append(MatVecStrategyGenerator(op_data_mapping, self.device_mesh))
        elif self.matmul_type == MatMulType.MM:
            generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        logical_shape_func = {
            MatMulType.DOT: self._get_logical_shape_for_dot,
            MatMulType.MM: self._get_logical_shape_for_mm,
            MatMulType.MV: self._get_logical_shape_for_mv,
            MatMulType.BMM: self._get_logical_shape_for_bmm
        }
        logical_shapes = logical_shape_func[self.matmul_type]()
        op_data_mapping = self._get_op_data_mapping(*logical_shapes)
        return op_data_mapping

    def _get_op_data_mapping(self, input_logical_shape, other_logical_shape, output_logical_shape):
        # convert list to torch.Size
        if input_logical_shape:
            input_logical_shape = torch.Size(input_logical_shape)

        if other_logical_shape:
            other_logical_shape = torch.Size(other_logical_shape)

        if output_logical_shape:
            output_logical_shape = torch.Size(output_logical_shape)

        # create op data
        input_op_data = OperationData(name=str(self.node.args[0]),
                                      type=OperationDataType.ARG,
                                      data=self.input_meta_data,
                                      logical_shape=input_logical_shape)
        other_op_data = OperationData(name=str(self.node.args[1]),
                                      type=OperationDataType.ARG,
                                      data=self.other_meta_data,
                                      logical_shape=other_logical_shape)
        output_op_data = OperationData(name=str(self.node),
                                       type=OperationDataType.OUTPUT,
                                       data=self.output_meta_data,
                                       logical_shape=output_logical_shape)
        mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
        return mapping

    def _get_logical_shape_for_dot(self):
        """
        The operands for the dot operation have the same logical shape as the physical shape
        """
        return None, None, None

    def _get_logical_shape_for_mm(self):
        """
        We need to handle the input tensor for a matrix-matrix multiplication as the input
        tensor can be a 1D or 2D tensor. If it is a 1D tensor, 1 will be prepended to its shape
        (e.g. [4] -> [1, 4]).
        """
        if self.input_meta_data.dim() == 1:
            input_logical_shape = [1] + list(self.input_meta_data.shape)
            input_logical_shape = torch.Size(input_logical_shape)
        else:
            input_logical_shape = None
        return input_logical_shape, None, None

    def _get_logical_shape_for_mv(self):
        """
        No broadcasting or dim insertion occurs for matrix-vector operation.
        """
        return None, None, None

    def _get_logical_shape_for_bmm(self):
        input_physical_shape = list(self.input_meta_data.shape)
        other_physical_shape = list(self.other_meta_data.shape)
        return _get_bmm_logical_shape(input_physical_shape, other_physical_shape, self.transforms)

    def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
        if self.matmul_type in [MatMulType.DOT, MatMulType.MV]:
            return strategy
        elif self.matmul_type == MatMulType.MM:
            if self.input_meta_data.dim() == 1:
                # if a 1 is prepended to the input shape (this occurs when input is a 1D tensor)
                # we need to remove that dim
                input_sharding_spec = strategy.get_sharding_spec_by_name(str(self.node.args[0]))
                input_physical_shape = self.node.args[0]._meta_data.shape
                dim_partition_dict = input_sharding_spec.dim_partition_dict

                # remove the partitioning in the dim 0
                if 0 in dim_partition_dict:
                    dim_partition_dict.pop(0, None)

                # move the partitioning in dim 1 to dim 0
                if -1 in dim_partition_dict:
                    shard = dim_partition_dict.pop(-1)
                    dim_partition_dict[0] = shard

                # re-init the sharding spec
                input_sharding_spec.__init__(input_sharding_spec.device_mesh,
                                             entire_shape=input_physical_shape,
                                             dim_partition_dict=dim_partition_dict)
                return strategy
            else:
                return strategy
        elif self.matmul_type == MatMulType.BMM:
            op_data_mapping = self.get_operation_data_mapping()

            strategies = [strategy]
            # recover the physical sharding spec by applying the transforms in reverse order
            for transform in self.transforms[::-1]:
                recovered_strategies = []
                for strategy_ in strategies:
                    output = transform.recover(op_data_mapping, strategy_)
                    if isinstance(output, ShardingStrategy):
                        recovered_strategies.append(output)
                    elif isinstance(output, (list, tuple)):
                        recovered_strategies.extend(output)
                    else:
                        raise TypeError(
                            f"Found unexpected output type {type(output)} from the recover method of BmmTransform")
                strategies = recovered_strategies
            return strategies
colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py

@@ -60,12 +60,13 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
     def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
         sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
         fwd_compute_cost = sharded_input_shape[0]
-        bwd_compute_cost = sharded_input_shape * 2
+        bwd_compute_cost = fwd_compute_cost * 2
         compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
                                       bwd=bwd_compute_cost,
                                       total=fwd_compute_cost + bwd_compute_cost)
         return compute_cost
 
+    @ignore_sharding_exception
     def no_split(self):
         name = f'R = R dot R'
         dim_partition_dict = {"input": {}, "other": {}, "output": {}, 'bias': {}}
@@ -75,6 +76,7 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
 
+    @ignore_sharding_exception
     def split_one_dim(self, mesh_dim):
         name = f'R = S{mesh_dim} dot S{mesh_dim}'
 
@@ -93,7 +95,7 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
 
-    def generate(self) -> List[ShardingStrategy]:
+    def collate_strategies(self) -> List[ShardingStrategy]:
         strategy_list = []
 
         # do not split dimensions for dot product
@@ -113,24 +115,50 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
     def validate(self) -> bool:
         input_op_data = self.op_data['input']
         other_op_data = self.op_data['other']
-        assert input_op_data.data.dim() > 1 and other_op_data.data.dim() == 1
+        assert input_op_data.data.dim() == 2 and other_op_data.data.dim() == 1
 
+    def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
+        sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
+        fwd_compute_cost = sharded_input_shape[0]
+        bwd_compute_cost = fwd_compute_cost * 2
+        compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
+                                      bwd=bwd_compute_cost,
+                                      total=fwd_compute_cost + bwd_compute_cost)
+        return compute_cost
+
+    @ignore_sharding_exception
     def no_split(self):
         name = "R = R x R"
-        dim_partition_dict = {"input": {}, "other": {}, "output": {}, "bias": {}}
+        dim_partition_dict = {"input": {}, "other": {}, "output": {}}
+
+        if self.has_bias:
+            dim_partition_dict['bias'] = {}
+
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping={})
 
+    @ignore_sharding_exception
     def split_input_batch(self, mesh_dim):
         name = f'S{mesh_dim}R = S{mesh_dim}R x R'
 
         # get sharding spec
-        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}, "bias": {}}
+        dim_partition_dict = {
+            "input": {
+                0: [mesh_dim]
+            },
+            "other": {},
+            "output": {
+                0: [mesh_dim]
+            },
+        }
+        if self.has_bias:
+            dim_partition_dict['bias'] = {}
+
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
 
         # get communication action
+        communication_action_mapping = {}
         if self.is_param('other'):
             other_comm_action = self.get_communication_action(
                 sharding_spec=sharding_spec_mapping['other'],
@@ -144,6 +172,8 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
                 logical_process_axis=mesh_dim,
                 comm_type=CommType.BEFORE,
                 arg_index=1)
+            communication_action_mapping['other'] = other_comm_action
+
         if self.has_bias:
             if self.is_param('bias'):
                 bias_comm_action = self.get_communication_action(
@@ -158,13 +188,13 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
                     logical_process_axis=mesh_dim,
                     comm_type=CommType.BEFORE,
                     arg_index=2)
-        communication_action_mapping = {'other': other_comm_action, 'bias': bias_comm_action}
+                communication_action_mapping['bias'] = bias_comm_action
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
 
-    def generate(self) -> List[ShardingStrategy]:
+    def collate_strategies(self) -> List[ShardingStrategy]:
         strategy_list = []
 
         # no split
@@ -638,7 +668,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
     def validate(self) -> bool:
         input_op_data = self.op_data['input']
         other_op_data = self.op_data['other']
-        assert input_op_data.data.dim() == 3 or other_op_data.data.dim() == 3
+        assert len(input_op_data.logical_shape) == 3 or len(other_op_data.logical_shape) == 3
 
         if 'bias' in self.op_data:
             bias_op_data = self.op_data['bias']
@@ -816,11 +846,11 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
         dim_partition_dict = {
             "input": {
                 0: [mesh_dim_0],
-                -1: [mesh_dim_1]
+                2: [mesh_dim_1]
             },
             "other": {
                 0: [mesh_dim_0],
-                -2: [mesh_dim_1]
+                1: [mesh_dim_1]
             },
             "bias": {},
             "output": {
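The first hunk above also fixes a small bug in the dot-product compute cost: the old code multiplied the shape tuple itself by 2 (tuple repetition), whereas the new code uses twice the forward cost. A short sketch of the difference (not part of the diff):

# illustrative sketch of the compute-cost fix
sharded_input_shape = (8,)                   # sharded shape of a 1D operand
fwd_compute_cost = sharded_input_shape[0]    # 8
print(sharded_input_shape * 2)               # (8, 8)  -- tuple repetition, not a cost
print(fwd_compute_cost * 2)                  # 16      -- the backward cost used by the new code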
colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py

@@ -186,9 +186,14 @@ class StrategyGenerator(ABC):
         """
         op_data = self.op_data[key]
         sharded_shape = strategy.sharding_specs[op_data].get_sharded_shape_per_device()
+
+        if len(sharded_shape) == 0:
+            num_elements = 1
+        else:
+            num_elements = reduce(operator.mul, sharded_shape)
         dtype = self.op_data[key].data.dtype
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        return reduce(operator.mul, sharded_shape) * size_per_elem_bytes
+        return num_elements * size_per_elem_bytes
 
     def generate(self) -> List[ShardingStrategy]:
         """
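This hunk makes the per-device memory estimate robust to 0-dim (scalar) tensors, such as the output of a dot product, where reducing over an empty shape would raise. A sketch of the behaviour being guarded against (not part of the diff):

import operator
from functools import reduce

import torch

sharded_shape = torch.rand(4, 8).shape        # torch.Size([4, 8])
print(reduce(operator.mul, sharded_shape))    # 32 elements
scalar_shape = torch.tensor(1.0).shape        # torch.Size([]) for a 0-dim tensor
# reduce(operator.mul, scalar_shape) raises TypeError (empty sequence, no initial value),
# which is why the new code falls back to num_elements = 1 for 0-dim tensors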
colossalai/auto_parallel/tensor_shard/utils/broadcast.py

@@ -44,21 +44,7 @@ def get_broadcast_shape(shape1: torch.Size, shape2: torch.Size) -> List[int]:
     return dims[::-1]
 
-def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size,
-                                              physical_shape: torch.Size) -> ShardingSpec:
-    """
-    This function computes the sharding spec for the physical shape of a broadcast tensor.
-
-    Args:
-        logical_sharding_spec (ShardingSpec): the sharding spec for the broadcast tensor
-        logical_shape (torch.Size): logical shape is the broadcast shape of a tensor
-        physical_shape (torch.Size): the shape of the tensor before broadcasting
-    """
-    # if the two shapes are the same, no broadcast occurs
-    # we directly return the current sharding spec
-    if list(logical_shape) == list(physical_shape):
-        return logical_sharding_spec
-
+def get_broadcast_dim_info(logical_shape, physical_shape):
     # get the number of dimensions
     logical_num_dims = len(logical_shape)
     physical_num_dims = len(physical_shape)
@@ -85,6 +71,31 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec,
         else:
             logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.PADDDING
 
+    return logical_dim_broadcast_info
+
+
+def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size,
+                                              physical_shape: torch.Size) -> ShardingSpec:
+    """
+    This function computes the sharding spec for the physical shape of a broadcast tensor.
+
+    Args:
+        logical_sharding_spec (ShardingSpec): the sharding spec for the broadcast tensor
+        logical_shape (torch.Size): logical shape is the broadcast shape of a tensor
+        physical_shape (torch.Size): the shape of the tensor before broadcasting
+    """
+    # if the two shapes are the same, no broadcast occurs
+    # we directly return the current sharding spec
+    if list(logical_shape) == list(physical_shape):
+        return logical_sharding_spec
+
+    # get the number of dimensions
+    logical_num_dims = len(logical_shape)
+    physical_num_dims = len(physical_shape)
+
+    # get the broadcast info
+    logical_dim_broadcast_info = get_broadcast_dim_info(logical_shape, physical_shape)
+
     # generate the sharding spec for the physical shape
     physical_dim_partition = {}
     logical_dim_partition = logical_sharding_spec.dim_partition_dict
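The refactor extracts the per-dimension broadcast bookkeeping into get_broadcast_dim_info so that matmul_handler.py can reuse it when recovering physical sharding specs. Roughly, the helper reports how each logical batch dimension relates to the physical shape (a hedged sketch; the full return value is not shown in this hunk, but the MULTIPLE/PADDDING labels match the BroadcastType usage above):

from colossalai.auto_parallel.tensor_shard.utils.broadcast import BroadcastType, get_broadcast_dim_info

# logical [2, 4] vs physical [1]: logical dim 0 has no physical counterpart (padded during broadcast),
# logical dim 1 corresponds to a physical 1 that was multiplied up to 4
dim_info = get_broadcast_dim_info([2, 4], [1])
# expected to be along the lines of {0: BroadcastType.PADDDING, 1: BroadcastType.MULTIPLE}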
colossalai/tensor/sharding_spec.py

 import operator
 from copy import deepcopy
-from enum import Enum
 from functools import reduce
 
 import torch
 
@@ -175,6 +174,9 @@ class ShardingSpec:
                  dim_partition_dict=None,
                  sharding_sequence=None):
         self.device_mesh = device_mesh
+
+        if isinstance(entire_shape, (list, tuple)):
+            entire_shape = torch.Size(entire_shape)
         self.entire_shape = entire_shape
         self.dim_partition_dict = dim_partition_dict
         self.sharding_sequence = sharding_sequence
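The ShardingSpec change lets callers, such as the BMM transforms above that rebuild specs from plain Python lists, pass entire_shape as a list or tuple rather than a torch.Size. A minimal sketch (not part of the diff), reusing the device mesh setup from the test file below:

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec

physical_mesh_id = torch.arange(0, 4)
device_mesh = DeviceMesh(physical_mesh_id, (2, 2))

# entire_shape given as a plain list is now converted to torch.Size internally
spec = ShardingSpec(device_mesh, entire_shape=[4, 8], dim_partition_dict={})
assert isinstance(spec.entire_shape, torch.Size)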
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py (new file, 0 → 100644)

import torch
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.node_handler.matmul_handler import (
    MatMulHandler,
    MatMulType,
    _get_bmm_logical_shape,
    get_matmul_type,
)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    OperationData,
    OperationDataType,
    ShardingStrategy,
    StrategiesVector,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.utils import parameterize


class MatMulModule(nn.Module):

    def forward(self, x1, x2):
        return torch.matmul(x1, x2)


@parameterize('tensor_shapes', [
    [[8], [8]],    # dot product
    [[4, 8], [8]],    # mat-vec product
    [[4, 8], [8, 16]],    # mat-mat product
    [[8], [8, 16]],    # mat-mat product
    [[8], [4, 8, 16]],    # batched mat-mat product with padding + broadcasting
    [[4, 8, 16], [16]],    # batched mat-mat product with padding + broadcasting
    [[4, 8, 16], [16, 32]],    # batched mat-mat product with broadcasting
    [[4, 8, 16], [1, 16, 32]],    # batched mat-mat product with broadcasting
    [[8, 16], [2, 4, 16, 32]],    # batched mat-mat product with broadcasting
    [[4, 8, 16], [2, 4, 16, 32]],    # batched mat-mat product with broadcasting
    [[1, 8, 16], [2, 4, 16, 32]],    # batched mat-mat product with broadcasting
    [[1, 4, 8, 16], [2, 4, 16, 32]],    # batched mat-mat product with broadcasting
    [[2, 1, 8, 16], [2, 4, 16, 32]],    # batched mat-mat product with broadcasting
    [[2, 4, 8, 16], [2, 4, 16, 32]],    # batched mat-mat product without broadcasting
])
def test_matmul_node_handler(tensor_shapes):
    input_shape, other_shape = tensor_shapes

    # get output shape
    x1 = torch.rand(*input_shape)
    x2 = torch.rand(*other_shape)
    output_shape = list(torch.matmul(x1, x2).shape)

    # get matmul type
    matmul_type = get_matmul_type(x1.dim(), x2.dim())

    model = MatMulModule()

    tracer = ColoTracer()
    graph = tracer.trace(model, meta_args={"x1": x1.to('meta'), 'x2': x2.to('meta')})
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)

    print(graph)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    mod_node = list(graph.nodes)[2]
    strategies_vector = StrategiesVector(mod_node)

    # build handler
    handler = MatMulHandler(node=mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)

    # check operation data mapping
    mapping = handler.get_operation_data_mapping()

    for name, op_data in mapping.items():
        op_data: OperationData
        # make sure they have valid values
        assert op_data.logical_shape is not None
        assert op_data.data is not None

    logical_input_shape = input_shape
    logical_other_shape = other_shape
    logical_output_shape = output_shape
    if matmul_type == MatMulType.MM and len(input_shape) == 1:
        logical_input_shape = [1] + input_shape
    elif matmul_type == MatMulType.BMM:
        logical_input_shape, logical_other_shape, logical_output_shape = _get_bmm_logical_shape(
            input_shape, other_shape, handler.transforms)
    else:
        logical_input_shape = input_shape

    # check input operation data
    assert mapping['input'].name == "x1"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size(input_shape)
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size(logical_input_shape)

    # check other operation data
    assert mapping['other'].name == "x2"
    assert mapping['other'].data.is_meta
    assert mapping['other'].data.shape == torch.Size(other_shape)
    assert mapping['other'].type == OperationDataType.ARG
    assert mapping['other'].logical_shape == torch.Size(logical_other_shape)

    # check output
    assert mapping['output'].name == "matmul"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size(output_shape)
    assert mapping['output'].type == OperationDataType.OUTPUT
    assert mapping['output'].logical_shape == torch.Size(logical_output_shape)

    strategies_vector = handler.register_strategy(compute_resharding_cost=False)
    strategy_name_list = [val.name for val in strategies_vector]

    # ensure there is no duplicate strategy
    if matmul_type != MatMulType.BMM:
        assert len(set(strategy_name_list)) == len(strategy_name_list), strategy_name_list

    for strategy in strategies_vector:
        strategy: ShardingStrategy
        input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
        other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
        output_sharding_spec = strategy.get_sharding_spec_by_name('matmul')

        if matmul_type == MatMulType.DOT:
            # dot product will produce a scalar
            # results should fulfill:
            # 1. the input and other operands have the same sharding spec
            # 2. the output has no sharding
            assert input_sharding_spec.sharding_sequence == other_sharding_spec.sharding_sequence
            assert len(output_sharding_spec.sharding_sequence) == 0
        elif matmul_type == MatMulType.MV:
            # matrix-vector product should fulfill
            # 1. the last dim of the input and other operands should have the same sharding
            # 2. the first dim of the input and output should have the same sharding
            # 3. the output should have only 1 dim
            assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1]
            assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
            assert len(output_sharding_spec.sharding_sequence) == 1
        elif matmul_type == MatMulType.MM:
            # matrix-matrix multiplication should fulfill
            # 1. if input is a 2D tensor, the 1st dim of input and output should have the same sharding
            # 2. the input's last dim and the first dim of the other should have the same sharding
            # 3. the last dim of the output and other should have the same sharding
            # 4. the input and output should have the same number of dims
            if len(input_shape) == 2:
                assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
            assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[0]
            assert output_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1]
            assert len(input_sharding_spec.sharding_sequence) == len(output_sharding_spec.sharding_sequence)
        elif matmul_type == MatMulType.BMM:
            # bmm should fulfill
            # 1. if the other tensor is not a 1D tensor, the last dim of other and output have the same sharding
            # 2. if the input has more than 2 dims, the second last dim of input and output have the same sharding
            # 3. if the other has more than 2 dims, the second last dim of other and the last dim of input should have the same sharding
            if len(other_shape) > 1:
                assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
            if len(input_shape) > 1:
                assert input_sharding_spec.sharding_sequence[-2] == output_sharding_spec.sharding_sequence[-2]
            if len(other_shape) > 2:
                assert other_sharding_spec.sharding_sequence[-2] == input_sharding_spec.sharding_sequence[-1]


if __name__ == '__main__':
    test_matmul_node_handler()