Unverified Commit f3f19a5c authored by Frank Lee, committed by GitHub

[autoparallel] added matmul handler (#1763)

* [autoparallel] added matmul handler

* polish code
parent 4df01949
......@@ -4,6 +4,7 @@ from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
from .conv_handler import ConvFunctionHandler, ConvModuleHandler
from .layer_norm_handler import LayerNormModuleHandler
from .linear_handler import LinearFunctionHandler, LinearModuleHandler
from .matmul_handler import MatMulHandler
from .normal_pooling_handler import NormPoolingHandler
from .output_handler import OuputHandler
from .placeholder_handler import PlacehodlerHandler
......@@ -16,5 +17,5 @@ __all__ = [
'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
'NormPoolingHandler', 'BinaryElementwiseHandler', 'operator_registry'
'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry'
]
......@@ -60,12 +60,13 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
fwd_compute_cost = sharded_input_shape[0]
bwd_compute_cost = sharded_input_shape * 2
bwd_compute_cost = fwd_compute_cost * 2
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
return compute_cost
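The corrected line matters because the old expression multiplied the shape tuple itself, not the scalar cost. A minimal plain-Python sketch with a hypothetical per-device shape:

# Why the old line was a bug: multiplying a shape tuple by 2 repeats the
# tuple instead of doubling the FLOP count.
sharded_input_shape = (4,)                   # hypothetical per-device operand shape
fwd_compute_cost = sharded_input_shape[0]    # 4 multiply-accumulates
assert sharded_input_shape * 2 == (4, 4)     # tuple repetition, not arithmetic
bwd_compute_cost = fwd_compute_cost * 2      # gradients w.r.t. both dot-product operands
assert bwd_compute_cost == 8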
@ignore_sharding_exception
def no_split(self):
name = f'R = R dot R'
dim_partition_dict = {"input": {}, "other": {}, "output": {}, 'bias': {}}
......@@ -75,6 +76,7 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_one_dim(self, mesh_dim):
name = f'R = S{mesh_dim} dot S{mesh_dim}'
......@@ -93,7 +95,7 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
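The @ignore_sharding_exception decorator newly applied to each strategy method in this hunk plausibly turns an infeasible sharding attempt into a skipped strategy rather than a crash. A hedged sketch of that behavior, not the library's actual implementation:

import functools

def ignore_sharding_exception_sketch(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception:   # the real decorator presumably catches a specific sharding error type
            return None     # a None strategy can be filtered out by the caller
    return wrapper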
def generate(self) -> List[ShardingStrategy]:
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
# do not split dimensions for dot product
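The generate -> collate_strategies rename (here and again in MatVecStrategyGenerator below) suggests the base StrategyGenerator now owns generate() as a template method. The base-class body is outside this diff, so the following is a hedged sketch with hypothetical helper names:

class StrategyGeneratorSketch:
    def collate_strategies(self):
        # subclasses enumerate candidate ShardingStrategy objects
        raise NotImplementedError

    def generate(self):
        # hypothetical shared post-processing applied to every candidate
        strategies = [s for s in self.collate_strategies() if s is not None]
        for strategy in strategies:
            self.update_compute_cost(strategy)   # cost hooks like the ones shown in this diff
        return strategies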
......@@ -113,24 +115,50 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
def validate(self) -> bool:
input_op_data = self.op_data['input']
other_op_data = self.op_data['other']
assert input_op_data.data.dim() > 1 and other_op_data.data.dim() == 1
assert input_op_data.data.dim() == 2 and other_op_data.data.dim() == 1
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
fwd_compute_cost = sharded_input_shape[0]
bwd_compute_cost = fwd_compute_cost * 2
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
return compute_cost
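The tightened assertion reflects torch.matmul's dispatch: with a 1-D second operand the input's last dim is contracted, and this generator now claims only the strict 2-D x 1-D case, leaving higher-rank inputs to the batched path. A small torch check of those semantics:

import torch

mat = torch.rand(4, 8)
vec = torch.rand(8)
assert torch.matmul(mat, vec).shape == torch.Size([4])          # MV: handled by this generator

batched = torch.rand(2, 4, 8)
assert torch.matmul(batched, vec).shape == torch.Size([2, 4])   # 3-D x 1-D: no longer claimed here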
@ignore_sharding_exception
def no_split(self):
name = "R = R x R"
dim_partition_dict = {"input": {}, "other": {}, "output": {}, "bias": {}}
dim_partition_dict = {"input": {}, "other": {}, "output": {}}
if self.has_bias:
dim_partition_dict['bias'] = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
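Making the 'bias' entry conditional avoids describing an operand that may not exist; a partition dict keyed by a missing operand would presumably fail when mapped to sharding specs. A minimal sketch of the fixed construction, with a hypothetical helper name:

def build_dim_partition(has_bias: bool) -> dict:
    dim_partition_dict = {"input": {}, "other": {}, "output": {}}
    if has_bias:                          # only describe operands that exist on the node
        dim_partition_dict["bias"] = {}
    return dim_partition_dict

assert "bias" not in build_dim_partition(has_bias=False)
assert "bias" in build_dim_partition(has_bias=True)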
@ignore_sharding_exception
def split_input_batch(self, mesh_dim):
name = f'S{mesh_dim}R = S{mesh_dim}R x R'
# get sharding spec
dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}, "bias": {}}
dim_partition_dict = {
"input": {
0: [mesh_dim]
},
"other": {},
"output": {
0: [mesh_dim]
},
}
if self.has_bias:
dim_partition_dict['bias'] = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
# get communication action
communication_action_mapping = {}
if self.is_param('other'):
other_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['other'],
......@@ -144,6 +172,8 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping['other'] = other_comm_action
if self.has_bias:
if self.is_param('bias'):
bias_comm_action = self.get_communication_action(
......@@ -158,13 +188,13 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
arg_index=2)
communication_action_mapping = {'other': other_comm_action, 'bias': bias_comm_action}
communication_action_mapping['bias'] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
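This restructuring fixes a dropped communication action: previously the mapping was only assigned inside the has_bias branch, so a bias-free mat-vec silently lost the action computed for 'other'. A sketch of the fixed control flow, with hypothetical names:

def collect_comm_actions(other_action, bias_action=None) -> dict:
    mapping = {}
    mapping['other'] = other_action        # record unconditionally (the fix)
    if bias_action is not None:
        mapping['bias'] = bias_action      # append to the dict rather than overwrite it
    return mapping

assert collect_comm_actions("allreduce-other") == {'other': 'allreduce-other'}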
def generate(self) -> List[ShardingStrategy]:
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
# no split
......@@ -638,7 +668,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
def validate(self) -> bool:
input_op_data = self.op_data['input']
other_op_data = self.op_data['other']
assert input_op_data.data.dim() == 3 or other_op_data.data.dim() == 3
assert len(input_op_data.logical_shape) == 3 or len(other_op_data.logical_shape) == 3
if 'bias' in self.op_data:
bias_op_data = self.op_data['bias']
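Validating against the logical shape rather than the physical rank matters because, under the handler's padding and broadcasting transforms, a physically 1-D or 2-D operand can participate in a batched matmul as a logically 3-D operand. A torch example of that situation:

import torch

inp = torch.rand(2, 4, 8)       # physically 3-D
other = torch.rand(8, 16)       # physically 2-D, logically broadcast to (2, 8, 16)
out = torch.matmul(inp, other)
assert out.shape == torch.Size([2, 4, 16])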
......@@ -816,11 +846,11 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
dim_partition_dict = {
"input": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
2: [mesh_dim_1]
},
"other": {
0: [mesh_dim_0],
-2: [mesh_dim_1]
1: [mesh_dim_1]
},
"bias": {},
"output": {
......
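The switch from negative to non-negative dim indices is presumably defensive: for a rank-3 tensor, dim -1 and dim 2 name the same axis, so a dim_partition_dict keyed with negative indices risks two entries for one physical dimension. A small illustration of the normalization:

rank = 3
neg_keyed = {0: ['mesh0'], -1: ['mesh1']}                       # hypothetical partition dict
normalized = {(d % rank): axes for d, axes in neg_keyed.items()}
assert normalized == {0: ['mesh0'], 2: ['mesh1']}               # one key per physical axis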
......@@ -186,9 +186,14 @@ class StrategyGenerator(ABC):
"""
op_data = self.op_data[key]
sharded_shape = strategy.sharding_specs[op_data].get_sharded_shape_per_device()
if len(sharded_shape) == 0:
num_elements = 1
else:
num_elements = reduce(operator.mul, sharded_shape)
dtype = self.op_data[key].data.dtype
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
return reduce(operator.mul, sharded_shape) * size_per_elem_bytes
return num_elements * size_per_elem_bytes
def generate(self) -> List[ShardingStrategy]:
"""
......
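The num_elements guard above covers the scalar output of a dot product: a 0-dim tensor has an empty sharded shape, and reduce() over an empty sequence with no initial value raises. A runnable check of the edge case:

import operator
from functools import reduce
import torch

scalar = torch.tensor(1.0)                   # 0-dim output of a dot product
sharded_shape = tuple(scalar.shape)          # ()
try:
    num_elements = reduce(operator.mul, sharded_shape)   # TypeError: empty sequence
except TypeError:
    num_elements = 1
assert num_elements * scalar.element_size() == 4         # an fp32 scalar occupies 4 bytes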
......@@ -44,21 +44,7 @@ def get_broadcast_shape(shape1: torch.Size, shape2: torch.Size) -> List[int]:
return dims[::-1]
def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size,
physical_shape: torch.Size) -> ShardingSpec:
"""
This function computes the sharding spec for the physical shape of a broadcast tensor.
Args:
logical_sharding_spec (ShardingSpec): the sharding spec for the broadcast tensor
logical_shape (torch.Size): logical shape is the broadcast shape of a tensor
physical_shape (torch.Size): the shape of the tensor before broadcasting
"""
# if the two shapes are the same, no broadcast occurs
# we directly return the current sharding spec
if list(logical_shape) == list(physical_shape):
return logical_sharding_spec
def get_broadcast_dim_info(logical_shape, physical_shape):
# get the number of dimensions
logical_num_dims = len(logical_shape)
physical_num_dims = len(physical_shape)
......@@ -85,6 +71,31 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe
else:
logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.PADDDING
return logical_dim_broadcast_info
def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size,
physical_shape: torch.Size) -> ShardingSpec:
"""
This function computes the sharding spec for the physical shape of a broadcast tensor.
Args:
logical_sharding_spec (ShardingSpec): the sharding spec for the broadcast tensor
logical_shape (torch.Size): logical shape is the broadcast shape of a tensor
physical_shape (torch.Size): the shape of the tensor before broadcasting
"""
# if the two shapes are the same, no broadcast occurs
# we directly return the current sharding spec
if list(logical_shape) == list(physical_shape):
return logical_sharding_spec
# get the number of dimensions
logical_num_dims = len(logical_shape)
physical_num_dims = len(physical_shape)
# get the broadcast info
logical_dim_broadcast_info = get_broadcast_dim_info(logical_shape, physical_shape)
# generate the sharding spec for the physical shape
physical_dim_partition = {}
logical_dim_partition = logical_sharding_spec.dim_partition_dict
......
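The refactor extracts the per-dimension broadcast classification into a standalone get_broadcast_dim_info helper that recover_sharding_spec_for_broadcast_shape now calls. A hedged, self-contained re-derivation of the helper's contract from the logic visible in this diff (enum member names follow the BroadcastType values referenced above, including the PADDDING spelling):

from enum import Enum, auto

class BroadcastTypeSketch(Enum):
    EQUAL = auto()      # physical dim matches the logical dim
    MULTIPLE = auto()   # physical dim is 1 and was expanded by broadcasting
    PADDDING = auto()   # logical dim has no physical counterpart

def broadcast_dim_info_sketch(logical_shape, physical_shape):
    info = {}
    offset = len(logical_shape) - len(physical_shape)
    for i, ldim in enumerate(logical_shape):
        if i < offset:
            info[i] = BroadcastTypeSketch.PADDDING
        elif physical_shape[i - offset] == ldim:
            info[i] = BroadcastTypeSketch.EQUAL
        else:
            info[i] = BroadcastTypeSketch.MULTIPLE
    return info

assert broadcast_dim_info_sketch((2, 4, 8), (1, 8))[0] == BroadcastTypeSketch.PADDDING

Only EQUAL dims can carry their sharding back to the physical tensor, which is why the recovery function needs this classification before building the physical dim partition.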
import operator
from copy import deepcopy
from enum import Enum
from functools import reduce
import torch
......@@ -175,6 +174,9 @@ class ShardingSpec:
dim_partition_dict=None,
sharding_sequence=None):
self.device_mesh = device_mesh
if isinstance(entire_shape, (list, tuple)):
entire_shape = torch.Size(entire_shape)
self.entire_shape = entire_shape
self.dim_partition_dict = dim_partition_dict
self.sharding_sequence = sharding_sequence
......
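Coercing entire_shape to torch.Size makes downstream shape comparisons uniform: torch.Size is a tuple subclass, so a plain list never compares equal to a tensor's shape. A quick illustration:

import torch

entire_shape = [4, 8]                            # caller passed a plain list
assert entire_shape != torch.rand(4, 8).shape    # list vs torch.Size: never equal
if isinstance(entire_shape, (list, tuple)):
    entire_shape = torch.Size(entire_shape)
assert entire_shape == torch.rand(4, 8).shape    # now comparable to tensor shapes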
import torch
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler.matmul_handler import (
MatMulHandler,
MatMulType,
_get_bmm_logical_shape,
get_matmul_type,
)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.utils import parameterize
class MatMulModule(nn.Module):
def forward(self, x1, x2):
return torch.matmul(x1, x2)
@parameterize(
'tensor_shapes',
[
[[8], [8]], # dot product
[[4, 8], [8]], # mat-vec product
[[4, 8], [8, 16]], # mat-mat product
[[8], [8, 16]], # mat-mat product
[[8], [4, 8, 16]], # batched mat-mat product with padding + broadcasting
[[4, 8, 16], [16]], # batched mat-mat product with padding + broadcasting
[[4, 8, 16], [16, 32]], # batched mat-mat product with broadcasting
[[4, 8, 16], [1, 16, 32]], # batched mat-mat product with broadcasting
[[8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting
[[4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting
[[1, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting
[[1, 4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting
[[2, 1, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting
[[2, 4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product without broadcasting
])
def test_matmul_node_handler(tensor_shapes):
input_shape, other_shape = tensor_shapes
# get output shape
x1 = torch.rand(*input_shape)
x2 = torch.rand(*other_shape)
output_shape = list(torch.matmul(x1, x2).shape)
# get matmul type
matmul_type = get_matmul_type(x1.dim(), x2.dim())
model = MatMulModule()
tracer = ColoTracer()
graph = tracer.trace(model, meta_args={"x1": x1.to('meta'), 'x2': x2.to('meta')})
gm = ColoGraphModule(model, graph)
physical_mesh_id = torch.arange(0, 4)
print(graph)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
mod_node = list(graph.nodes)[2]
strategies_vector = StrategiesVector(mod_node)
# build handler
handler = MatMulHandler(node=mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
# check operation data mapping
mapping = handler.get_operation_data_mapping()
for name, op_data in mapping.items():
op_data: OperationData
# make sure they have valid values
assert op_data.logical_shape is not None
assert op_data.data is not None
logical_input_shape = input_shape
logical_other_shape = other_shape
logical_output_shape = output_shape
if matmul_type == MatMulType.MM and len(input_shape) == 1:
logical_input_shape = [1] + input_shape
elif matmul_type == MatMulType.BMM:
logical_input_shape, logical_other_shape, logical_output_shape = _get_bmm_logical_shape(
input_shape, other_shape, handler.transforms)
else:
logical_input_shape = input_shape
# check input operation data
assert mapping['input'].name == "x1"
assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size(input_shape)
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size(logical_input_shape)
# check other operation data
assert mapping['other'].name == "x2"
assert mapping['other'].data.is_meta
assert mapping['other'].data.shape == torch.Size(other_shape)
assert mapping['other'].type == OperationDataType.ARG
assert mapping['other'].logical_shape == torch.Size(logical_other_shape)
# check output
assert mapping['output'].name == "matmul"
assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size(output_shape)
assert mapping['output'].type == OperationDataType.OUTPUT
assert mapping['output'].logical_shape == torch.Size(logical_output_shape)
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
strategy_name_list = [val.name for val in strategies_vector]
# ensure there is no duplicate strategy
if matmul_type != MatMulType.BMM:
assert len(set(strategy_name_list)) == len(strategy_name_list), strategy_name_list
for strategy in strategies_vector:
strategy: ShardingStrategy
input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
output_sharding_spec = strategy.get_sharding_spec_by_name('matmul')
if matmul_type == MatMulType.DOT:
# dot product will produce a scalar
# results should fulfill:
# 1. the input and other operands have the same sharding spec
# 2. the output has no sharding
assert input_sharding_spec.sharding_sequence == other_sharding_spec.sharding_sequence
assert len(output_sharding_spec.sharding_sequence) == 0
elif matmul_type == MatMulType.MV:
# matrix-vector product should fulfill
# 1. the last dim of the input and other operands should have the same sharding
# 2. the first dim of the input and output should have the same sharding
# 3. the output should have only 1 dim
assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1]
assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
assert len(output_sharding_spec.sharding_sequence) == 1
elif matmul_type == MatMulType.MM:
# matrix-matrix multiplication should fulfill
# 1. if input is a 2D tensor, the 1st dim of input and output should have the same sharding
# 2. the input's last dim and the first dim of the other should have the same sharding
# 3. the last dim of the output and other should have the same sharding
# 4. the input and output should have the same number of dims
if len(input_shape) == 2:
assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[0]
assert output_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1]
assert len(input_sharding_spec.sharding_sequence) == len(output_sharding_spec.sharding_sequence)
elif matmul_type == MatMulType.BMM:
# bmm should fulfill
# 1. if the other tensor is not a 1-D tensor, the last dim of other and output have the same sharding
# 2. if the input has more than 2 dims, the second last dim of input and output have the same sharding
# 3. if the other has more than 2 dims, the second last dim of other and the last dim of input should have the same sharding
if len(other_shape) > 1:
assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
if len(input_shape) > 1:
assert input_sharding_spec.sharding_sequence[-2] == output_sharding_spec.sharding_sequence[-2]
if len(other_shape) > 2:
assert other_sharding_spec.sharding_sequence[-2] == input_sharding_spec.sharding_sequence[-1]
if __name__ == '__main__':
test_matmul_node_handler()
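For reference, the rank-based classification this test exercises via get_matmul_type can be sketched as follows; the real function lives in matmul_handler.py and this is only a hedged reconstruction from the shape cases and comments above:

from enum import Enum, auto

class MatMulTypeSketch(Enum):
    DOT = auto()   # 1-D x 1-D -> scalar
    MV = auto()    # 2-D x 1-D -> vector
    MM = auto()    # 1-D/2-D x 2-D -> matrix
    BMM = auto()   # any operand above 2-D -> batched matmul

def get_matmul_type_sketch(dim1: int, dim2: int) -> MatMulTypeSketch:
    if dim1 == 1 and dim2 == 1:
        return MatMulTypeSketch.DOT
    if dim1 == 2 and dim2 == 1:
        return MatMulTypeSketch.MV
    if dim1 in (1, 2) and dim2 == 2:
        return MatMulTypeSketch.MM
    return MatMulTypeSketch.BMM

assert get_matmul_type_sketch(1, 1) is MatMulTypeSketch.DOT    # [[8], [8]]
assert get_matmul_type_sketch(2, 1) is MatMulTypeSketch.MV     # [[4, 8], [8]]
assert get_matmul_type_sketch(1, 2) is MatMulTypeSketch.MM     # [[8], [8, 16]]
assert get_matmul_type_sketch(3, 4) is MatMulTypeSketch.BMM    # broadcast batched cases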