Unverified commit f3f19a5c authored by Frank Lee, committed by GitHub

[autoparallel] added matmul handler (#1763)

* [autoparallel] added matmul handler

* polish code
parent 4df01949
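
For orientation, the sketch below condenses the usage pattern from the new unit test at the end of this diff: trace a module containing torch.matmul with meta tensors, build a 2x2 device mesh, and ask the new MatMulHandler to enumerate sharding strategies for the matmul node. The shapes are illustrative, and the snippet omits the ColoGraphModule construction that the full test performs.

import torch
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.node_handler.matmul_handler import MatMulHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoTracer


class MatMulModule(nn.Module):

    def forward(self, x1, x2):
        return torch.matmul(x1, x2)


# trace with meta tensors so no real memory is allocated
tracer = ColoTracer()
graph = tracer.trace(MatMulModule(),
                     meta_args={"x1": torch.rand(4, 8).to('meta'),
                                "x2": torch.rand(8, 16).to('meta')})

# a 2x2 logical device mesh over 4 physical devices
device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2))

# the matmul call_function node sits after the two placeholder nodes
matmul_node = list(graph.nodes)[2]
handler = MatMulHandler(node=matmul_node,
                        device_mesh=device_mesh,
                        strategies_vector=StrategiesVector(matmul_node))

# enumerate the sharding strategies produced for this node
strategies = handler.register_strategy(compute_resharding_cost=False)
print([s.name for s in strategies])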
@@ -4,6 +4,7 @@ from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
 from .conv_handler import ConvFunctionHandler, ConvModuleHandler
 from .layer_norm_handler import LayerNormModuleHandler
 from .linear_handler import LinearFunctionHandler, LinearModuleHandler
+from .matmul_handler import MatMulHandler
 from .normal_pooling_handler import NormPoolingHandler
 from .output_handler import OuputHandler
 from .placeholder_handler import PlacehodlerHandler
@@ -16,5 +17,5 @@ __all__ = [
     'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
     'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
     'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
-    'NormPoolingHandler', 'BinaryElementwiseHandler', 'operator_registry'
+    'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry'
 ]
@@ -60,12 +60,13 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
     def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
         sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
         fwd_compute_cost = sharded_input_shape[0]
-        bwd_compute_cost = sharded_input_shape * 2
+        bwd_compute_cost = fwd_compute_cost * 2
         compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
                                       bwd=bwd_compute_cost,
                                       total=fwd_compute_cost + bwd_compute_cost)
         return compute_cost

+    @ignore_sharding_exception
     def no_split(self):
         name = f'R = R dot R'
         dim_partition_dict = {"input": {}, "other": {}, "output": {}, 'bias': {}}
@@ -75,6 +76,7 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)

+    @ignore_sharding_exception
     def split_one_dim(self, mesh_dim):
         name = f'R = S{mesh_dim} dot S{mesh_dim}'
@@ -93,7 +95,7 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)

-    def generate(self) -> List[ShardingStrategy]:
+    def collate_strategies(self) -> List[ShardingStrategy]:
         strategy_list = []

         # do not split dimensions for dot product
@@ -113,24 +115,50 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
     def validate(self) -> bool:
         input_op_data = self.op_data['input']
         other_op_data = self.op_data['other']
-        assert input_op_data.data.dim() > 1 and other_op_data.data.dim() == 1
+        assert input_op_data.data.dim() == 2 and other_op_data.data.dim() == 1

+    def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
+        sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
+        fwd_compute_cost = sharded_input_shape[0]
+        bwd_compute_cost = fwd_compute_cost * 2
+        compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
+                                      bwd=bwd_compute_cost,
+                                      total=fwd_compute_cost + bwd_compute_cost)
+        return compute_cost
+
+    @ignore_sharding_exception
     def no_split(self):
         name = "R = R x R"
-        dim_partition_dict = {"input": {}, "other": {}, "output": {}, "bias": {}}
+        dim_partition_dict = {"input": {}, "other": {}, "output": {}}
+
+        if self.has_bias:
+            dim_partition_dict['bias'] = {}
+
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping={})

+    @ignore_sharding_exception
     def split_input_batch(self, mesh_dim):
         name = f'S{mesh_dim}R = S{mesh_dim}R x R'

         # get sharding spec
-        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}, "bias": {}}
+        dim_partition_dict = {
+            "input": {
+                0: [mesh_dim]
+            },
+            "other": {},
+            "output": {
+                0: [mesh_dim]
+            },
+        }
+
+        if self.has_bias:
+            dim_partition_dict['bias'] = {}
+
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

         # get communication action
+        communication_action_mapping = {}
         if self.is_param('other'):
             other_comm_action = self.get_communication_action(
                 sharding_spec=sharding_spec_mapping['other'],
@@ -144,6 +172,8 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
                 logical_process_axis=mesh_dim,
                 comm_type=CommType.BEFORE,
                 arg_index=1)
+        communication_action_mapping['other'] = other_comm_action
+
         if self.has_bias:
             if self.is_param('bias'):
                 bias_comm_action = self.get_communication_action(
@@ -158,13 +188,13 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
                     logical_process_axis=mesh_dim,
                     comm_type=CommType.BEFORE,
                     arg_index=2)
-        communication_action_mapping = {'other': other_comm_action, 'bias': bias_comm_action}
+            communication_action_mapping['bias'] = bias_comm_action

         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)

-    def generate(self) -> List[ShardingStrategy]:
+    def collate_strategies(self) -> List[ShardingStrategy]:
         strategy_list = []

         # no split
@@ -638,7 +668,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
     def validate(self) -> bool:
         input_op_data = self.op_data['input']
         other_op_data = self.op_data['other']
-        assert input_op_data.data.dim() == 3 or other_op_data.data.dim() == 3
+        assert len(input_op_data.logical_shape) == 3 or len(other_op_data.logical_shape) == 3

         if 'bias' in self.op_data:
             bias_op_data = self.op_data['bias']
@@ -816,11 +846,11 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
         dim_partition_dict = {
             "input": {
                 0: [mesh_dim_0],
-                -1: [mesh_dim_1]
+                2: [mesh_dim_1]
             },
             "other": {
                 0: [mesh_dim_0],
-                -2: [mesh_dim_1]
+                1: [mesh_dim_1]
             },
             "bias": {},
             "output": {
...
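
The first hunk above fixes a genuine bug in DotProductStrategyGenerator.update_compute_cost: get_sharded_shape_per_device() returns a shape (a tuple-like torch.Size), and multiplying a sequence by 2 repeats it instead of doubling a number, so the backward cost was never a scalar. A quick illustration of the Python semantics involved, using a plain tuple as a stand-in for the sharded shape:

sharded_input_shape = (4, 8)    # stand-in for the sharded shape of the input tensor

# old expression: sequence repetition, not arithmetic
print(sharded_input_shape * 2)               # (4, 8, 4, 8)

# patched expression: forward cost is the sharded length, backward is twice that
fwd_compute_cost = sharded_input_shape[0]    # 4
bwd_compute_cost = fwd_compute_cost * 2      # 8
print(fwd_compute_cost, bwd_compute_cost)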
@@ -186,9 +186,14 @@ class StrategyGenerator(ABC):
         """
         op_data = self.op_data[key]
         sharded_shape = strategy.sharding_specs[op_data].get_sharded_shape_per_device()
+
+        if len(sharded_shape) == 0:
+            num_elements = 1
+        else:
+            num_elements = reduce(operator.mul, sharded_shape)
         dtype = self.op_data[key].data.dtype
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        return reduce(operator.mul, sharded_shape) * size_per_elem_bytes
+        return num_elements * size_per_elem_bytes

     def generate(self) -> List[ShardingStrategy]:
         """
...
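
The zero-dimension guard added here matters because functools.reduce raises a TypeError on an empty sequence when no initial value is given, and an empty sharded shape is exactly what a 0-dim output (for example the scalar produced by a dot product) yields. A minimal reproduction of the failure and of the patched behaviour:

import operator
from functools import reduce

scalar_shape = ()    # sharded shape of a 0-dim tensor, e.g. a dot-product output

try:
    reduce(operator.mul, scalar_shape)
except TypeError as e:
    print(e)    # reduce() cannot fold an empty sequence without an initial value

# the patched logic counts a 0-dim tensor as a single element
num_elements = 1 if len(scalar_shape) == 0 else reduce(operator.mul, scalar_shape)
print(num_elements)    # 1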
@@ -44,21 +44,7 @@ def get_broadcast_shape(shape1: torch.Size, shape2: torch.Size) -> List[int]:
     return dims[::-1]


-def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size,
-                                              physical_shape: torch.Size) -> ShardingSpec:
-    """
-    This function computes the sharding spec for the physical shape of a broadcast tensor.
-
-    Args:
-        logical_sharding_spec (ShardingSpec): the sharding spec for the broadcast tensor
-        logical_shape (torch.Size): logical shape is the broadcast shape of a tensor
-        physical_shape (torch.Size): the shape of the tensor before broadcasting
-    """
-    # if the two shapes are the same, no broadcast occurs
-    # we directly return the current sharding spec
-    if list(logical_shape) == list(physical_shape):
-        return logical_sharding_spec
-
+def get_broadcast_dim_info(logical_shape, physical_shape):
     # get the number of dimensions
     logical_num_dims = len(logical_shape)
     physical_num_dims = len(physical_shape)
@@ -85,6 +71,31 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size,
         else:
             logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.PADDDING

+    return logical_dim_broadcast_info
+
+
+def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size,
+                                              physical_shape: torch.Size) -> ShardingSpec:
+    """
+    This function computes the sharding spec for the physical shape of a broadcast tensor.
+
+    Args:
+        logical_sharding_spec (ShardingSpec): the sharding spec for the broadcast tensor
+        logical_shape (torch.Size): logical shape is the broadcast shape of a tensor
+        physical_shape (torch.Size): the shape of the tensor before broadcasting
+    """
+    # if the two shapes are the same, no broadcast occurs
+    # we directly return the current sharding spec
+    if list(logical_shape) == list(physical_shape):
+        return logical_sharding_spec
+
+    # get the number of dimensions
+    logical_num_dims = len(logical_shape)
+    physical_num_dims = len(physical_shape)
+
+    # get the broadcast info
+    logical_dim_broadcast_info = get_broadcast_dim_info(logical_shape, physical_shape)
+
     # generate the sharding spec for the physical shape
     physical_dim_partition = {}
     logical_dim_partition = logical_sharding_spec.dim_partition_dict
...
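
This refactor only moves code: the dimension classification that used to live inside recover_sharding_spec_for_broadcast_shape is extracted into get_broadcast_dim_info so the matmul handler can reuse it when mapping logical (broadcast) shapes back to physical ones. The standalone sketch below illustrates the idea with simplified names (BroadcastKind and broadcast_dim_info are illustrative stand-ins, not the library API): dims that exist only in the broadcast shape are padding, dims expanded from size 1 are multiples, and the rest are equal, which is what decides whether a sharding placed on a logical dim can be carried back to the physical tensor.

from enum import Enum


class BroadcastKind(Enum):    # hypothetical stand-in for BroadcastType
    EQUAL = 0        # dim exists in both shapes with the same size
    MULTIPLE = 1     # physical dim of size 1 expanded to a larger logical dim
    PADDING = 2      # dim only exists in the logical (broadcast) shape


def broadcast_dim_info(logical_shape, physical_shape):
    info = {}
    offset = len(logical_shape) - len(physical_shape)
    for logical_idx, logical_size in enumerate(logical_shape):
        physical_idx = logical_idx - offset
        if physical_idx < 0:
            info[logical_idx] = BroadcastKind.PADDING
        elif physical_shape[physical_idx] == logical_size:
            info[logical_idx] = BroadcastKind.EQUAL
        else:
            info[logical_idx] = BroadcastKind.MULTIPLE
    return info


# physical [16, 32] broadcast against a batched operand to logical [2, 4, 16, 32]:
# dims 0 and 1 are padding, so shardings on them cannot be mapped back
print(broadcast_dim_info([2, 4, 16, 32], [16, 32]))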
 import operator
 from copy import deepcopy
-from enum import Enum
 from functools import reduce

 import torch
@@ -175,6 +174,9 @@ class ShardingSpec:
                  dim_partition_dict=None,
                  sharding_sequence=None):
         self.device_mesh = device_mesh
+        if isinstance(entire_shape, (list, tuple)):
+            entire_shape = torch.Size(entire_shape)
+
         self.entire_shape = entire_shape
         self.dim_partition_dict = dim_partition_dict
         self.sharding_sequence = sharding_sequence
...
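
A small quality-of-life change: ShardingSpec now normalizes a plain list or tuple passed as entire_shape into a torch.Size, which is convenient because the matmul handler computes logical shapes as ordinary lists. The coercion itself is just:

import torch

entire_shape = [4, 8, 16]                  # logical shape computed as a plain list
if isinstance(entire_shape, (list, tuple)):
    entire_shape = torch.Size(entire_shape)

print(entire_shape)                            # torch.Size([4, 8, 16])
print(entire_shape == torch.Size([4, 8, 16]))  # True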
import torch
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler.matmul_handler import (
MatMulHandler,
MatMulType,
_get_bmm_logical_shape,
get_matmul_type,
)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.utils import parameterize
class MatMulModule(nn.Module):
def forward(self, x1, x2):
return torch.matmul(x1, x2)
@parameterize(
'tensor_shapes',
[
[[8], [8]], # dot product
[[4, 8], [8]], # mat-vec product
[[4, 8], [8, 16]], # mat-mat product
[[8], [8, 16]], # mat-mat product
[[8], [4, 8, 16]], # batched mat-mat product with padding + broadcasting
[[4, 8, 16], [16]], # batched mat-mat product with padding + broadcasting
[[4, 8, 16], [16, 32]], # batched mat-mat product with broadcasting
[[4, 8, 16], [1, 16, 32]], # batched mat-mat product with broadcasting
[[8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting
[[4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting
[[1, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting
[[1, 4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting
[[2, 1, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting
[[2, 4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product without broadcasting
])
def test_matmul_node_handler(tensor_shapes):
input_shape, other_shape = tensor_shapes
# get output shape
x1 = torch.rand(*input_shape)
x2 = torch.rand(*other_shape)
output_shape = list(torch.matmul(x1, x2).shape)
# get matmul type
matmul_type = get_matmul_type(x1.dim(), x2.dim())
model = MatMulModule()
tracer = ColoTracer()
graph = tracer.trace(model, meta_args={"x1": x1.to('meta'), 'x2': x2.to('meta')})
gm = ColoGraphModule(model, graph)
physical_mesh_id = torch.arange(0, 4)
print(graph)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
mod_node = list(graph.nodes)[2]
strategies_vector = StrategiesVector(mod_node)
# build handler
handler = MatMulHandler(node=mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
# check operation data mapping
mapping = handler.get_operation_data_mapping()
for name, op_data in mapping.items():
op_data: OperationData
# make sure they have valid values
assert op_data.logical_shape is not None
assert op_data.data is not None
logical_input_shape = input_shape
logical_other_shape = other_shape
logical_output_shape = output_shape
if matmul_type == MatMulType.MM and len(input_shape) == 1:
logical_input_shape = [1] + input_shape
elif matmul_type == MatMulType.BMM:
logical_input_shape, logical_other_shape, logical_output_shape = _get_bmm_logical_shape(
input_shape, other_shape, handler.transforms)
else:
logical_input_shape = input_shape
# check input operation data
assert mapping['input'].name == "x1"
assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size(input_shape)
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size(logical_input_shape)
# check other operation data
assert mapping['other'].name == "x2"
assert mapping['other'].data.is_meta
assert mapping['other'].data.shape == torch.Size(other_shape)
assert mapping['other'].type == OperationDataType.ARG
assert mapping['other'].logical_shape == torch.Size(logical_other_shape)
# check output
assert mapping['output'].name == "matmul"
assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size(output_shape)
assert mapping['output'].type == OperationDataType.OUTPUT
assert mapping['output'].logical_shape == torch.Size(logical_output_shape)
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
strategy_name_list = [val.name for val in strategies_vector]
# ensure there is no duplicate strategy
if matmul_type != MatMulType.BMM:
assert len(set(strategy_name_list)) == len(strategy_name_list), strategy_name_list
for strategy in strategies_vector:
strategy: ShardingStrategy
input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
output_sharding_spec = strategy.get_sharding_spec_by_name('matmul')
if matmul_type == MatMulType.DOT:
# dot product will produce a scalar
# results should fulfill:
# 1. the input and other operands have the same sharding spec
# 2. the output has no sharding
assert input_sharding_spec.sharding_sequence == other_sharding_spec.sharding_sequence
assert len(output_sharding_spec.sharding_sequence) == 0
elif matmul_type == MatMulType.MV:
# matrix-vector product should fulfill
# 1. the last dim of the input and other operands should have the same sharding
# 2. the first dim of the input and output should have the same sharding
# 3. the output should have only 1 dim
assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1]
assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
assert len(output_sharding_spec.sharding_sequence) == 1
elif matmul_type == MatMulType.MM:
# matrix-matrix multiplication should fulfill
# 1. if input is a 2D tensor, the 1st dim of input and output should have the same sharding
# 2. the input's last dim and the first dim of the other should have the same sharding
# 3. the last dim of the output and other should have the same sharding
# 4. the input and output should have the same number of dims
if len(input_shape) == 2:
assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[0]
assert output_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1]
assert len(input_sharding_spec.sharding_sequence) == len(output_sharding_spec.sharding_sequence)
elif matmul_type == MatMulType.BMM:
# bmm should fulfill
# 1. if the other tensor is not a 1D tensor, the last dim of other and output have the same sharding
# 2. if the input has more than 1 dim, the second last dim of input and output have the same sharding
# 3. if the other has more than 2 dims, the second last dim of other and the last dim of input should have the same sharding
if len(other_shape) > 1:
assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
if len(input_shape) > 1:
assert input_sharding_spec.sharding_sequence[-2] == output_sharding_spec.sharding_sequence[-2]
if len(other_shape) > 2:
assert other_sharding_spec.sharding_sequence[-2] == input_sharding_spec.sharding_sequence[-1]
if __name__ == '__main__':
test_matmul_node_handler()
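
The parameterized shapes above exercise the four matmul flavours the handler distinguishes (DOT, MV, MM, BMM). The sketch below is a rough reconstruction of that classification inferred from the test shapes and their comments; it is illustrative only and does not reproduce the actual get_matmul_type source.

from enum import Enum


class Kind(Enum):    # hypothetical mirror of MatMulType, for illustration only
    DOT = 0
    MV = 1
    MM = 2
    BMM = 3


def classify_matmul(input_dim: int, other_dim: int) -> Kind:
    # inferred from the parameterized shapes and their comments above
    if input_dim == 1 and other_dim == 1:
        return Kind.DOT                 # [8] @ [8]
    if input_dim == 2 and other_dim == 1:
        return Kind.MV                  # [4, 8] @ [8]
    if input_dim <= 2 and other_dim == 2:
        return Kind.MM                  # [4, 8] @ [8, 16] or [8] @ [8, 16]
    return Kind.BMM                     # any operand with 3+ dims


assert classify_matmul(1, 1) == Kind.DOT
assert classify_matmul(2, 1) == Kind.MV
assert classify_matmul(1, 2) == Kind.MM
assert classify_matmul(3, 4) == Kind.BMM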