Unverified Commit f3f19a5c authored by Frank Lee, committed by GitHub

[autoparallel] added matmul handler (#1763)

* [autoparallel] added matmul handler

* polish code
parent 4df01949
@@ -4,6 +4,7 @@ from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
 from .conv_handler import ConvFunctionHandler, ConvModuleHandler
 from .layer_norm_handler import LayerNormModuleHandler
 from .linear_handler import LinearFunctionHandler, LinearModuleHandler
+from .matmul_handler import MatMulHandler
 from .normal_pooling_handler import NormPoolingHandler
 from .output_handler import OuputHandler
 from .placeholder_handler import PlacehodlerHandler
@@ -16,5 +17,5 @@ __all__ = [
     'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
     'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
     'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
-    'NormPoolingHandler', 'BinaryElementwiseHandler', 'operator_registry'
+    'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry'
 ]
import operator
from abc import ABC, abstractmethod
from copy import deepcopy
from enum import Enum
from functools import reduce
from typing import Dict, List, Union
import torch
from colossalai.auto_parallel.tensor_shard.utils.broadcast import (
BroadcastType,
get_broadcast_dim_info,
get_broadcast_shape,
)
from colossalai.tensor.sharding_spec import ShardingSpecException
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import (
BatchedMatMulStrategyGenerator,
DotProductStrategyGenerator,
LinearProjectionStrategyGenerator,
MatVecStrategyGenerator,
StrategyGenerator,
)
class MatMulType(Enum):
"""
The MatMulType is categorized into 4 types based on the reference of torch.matmul
in https://pytorch.org/docs/stable/generated/torch.matmul.html.
DOT: dot product, both tensors are 1D and must have the same number of elements
MM: matrix-matrix product, both tensors are 2D, or the 1st tensor is 1D and the 2nd tensor is 2D
MV: matrix-vector product, the 1st tensor is 2D and the 2nd tensor is 1D
BMM: batched matrix-matrix multiplication, at least one tensor has more than 2 dimensions and the other has at least 1 dimension
"""
DOT = 0
MM = 1
MV = 2
BMM = 3
def get_matmul_type(input_dim: int, other_dim: int):
"""
Determine which type of matmul operation should be executed for the given tensor dimensions.
Args:
input_dim (int): the number of dimensions for the input tensor
other_dim (int): the number of dimensions for the other tensor
"""
if input_dim == 1 and other_dim == 1:
matmul_type = MatMulType.DOT
elif input_dim in [1, 2] and other_dim == 2:
matmul_type = MatMulType.MM
elif input_dim == 2 and other_dim == 1:
matmul_type = MatMulType.MV
elif input_dim >= 1 and other_dim >= 1 and (input_dim > 2 or other_dim > 2):
matmul_type = MatMulType.BMM
else:
raise ValueError(
f"The input and other tensors are of {input_dim} and {other_dim} which cannot used to execute matmul operation"
)
return matmul_type
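# A minimal sketch of how operand ranks map to MatMulType under the rules above
# (illustrative values only):
#   get_matmul_type(1, 1) -> MatMulType.DOT   # 1D x 1D
#   get_matmul_type(1, 2) -> MatMulType.MM    # 1D x 2D, a 1 is prepended to the input
#   get_matmul_type(2, 1) -> MatMulType.MV    # 2D x 1D
#   get_matmul_type(3, 2) -> MatMulType.BMM   # at least one operand has more than 2 dims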
class BmmTransform(ABC):
"""
BmmTransform is an abstraction of the shape conversion between logical and physical operation data
during the strategy generation.
"""
@abstractmethod
def apply(self, shape_mapping: Dict[str, List[int]]):
pass
@abstractmethod
def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
pass
class Padder(BmmTransform):
"""
Add padding to the matrix dimensions for batched matrix multiplication.
"""
def __init__(self) -> None:
# keep the padding dim, op_name -> padded_dim
self.padded_dim_mapping = {}
def apply(self, shape_mapping: Dict[str, List[int]]):
mapping_copy = deepcopy(shape_mapping)
input_shape = mapping_copy['input']
other_shape = mapping_copy['other']
if len(input_shape) == 1:
# if the input is a 1D tensor, 1 is prepended to its shape
# and it will be removed afterwards
input_shape.insert(0, 1)
self.padded_dim_mapping['input'] = -2
self.padded_dim_mapping['output'] = -2
elif len(other_shape) == 1:
# if the other is a 1D tensor, 1 is appended to its shape
# and it will be removed afterwards
other_shape.append(1)
self.padded_dim_mapping['other'] = -1
self.padded_dim_mapping['output'] = -1
return mapping_copy
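# For example (illustrative shapes): with input [8] and other [4, 8, 16], apply() returns
# {'input': [1, 8], 'other': [4, 8, 16]} and records that dim -2 of the input (and of the
# output) was padded, so recover() can drop that dim from the sharding spec afterwards.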
def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
input_op_data = op_data_mapping['input']
other_op_data = op_data_mapping['other']
def _remove_padded_dim(key, strategy):
op_data = op_data_mapping[key]
sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
tensor_shape = list(sharding_spec.entire_shape)
dim_partition_list = [None] * len(tensor_shape)
# padded dim is a negative number as the padded dim must be a matrix dim
padded_dim = self.padded_dim_mapping[key]
# compute the new dim partition
for tensor_dim, mesh_dims in sharding_spec.dim_partition_dict.items():
dim_partition_list[tensor_dim] = mesh_dims
dim_partition_list.pop(padded_dim)
unpadded_dim_partition_list = {k: v for k, v in enumerate(dim_partition_list) if v is not None}
# compute unpadded tensor shape
tensor_shape.pop(padded_dim)
assert tensor_shape == list(op_data.data.shape), f'{tensor_shape} vs {list(op_data.data.shape)}'
# update sharding spec
sharding_spec.__init__(sharding_spec.device_mesh, tensor_shape, unpadded_dim_partition_list)
# enumerate all sharding strategies
strategies = []
try:
strategy_copy = strategy.clone()
# only one of input and other will be padded
if 'input' in self.padded_dim_mapping:
_remove_padded_dim('input', strategy_copy)
_remove_padded_dim('output', strategy_copy)
elif 'other' in self.padded_dim_mapping:
_remove_padded_dim('other', strategy_copy)
_remove_padded_dim('output', strategy_copy)
strategies.append(strategy_copy)
except ShardingSpecException as e:
pass
return strategies
class Broadcaster(BmmTransform):
"""
Broadcast the non-matrix dimensions for batched matrix multiplication.
"""
def __init__(self) -> None:
self.broadcast_dim_info = {}
def apply(self, shape_mapping: Dict[str, List[int]]):
mapping_copy = shape_mapping.copy()
# get shapes
input_shape = mapping_copy['input']
other_shape = mapping_copy['other']
# sanity check
assert len(input_shape) > 1 and len(other_shape) > 1
# broadcast the batch dim and record
bcast_non_matrix_dims = get_broadcast_shape(input_shape[:-2], other_shape[:-2])
# store the broadcast dim info
input_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, input_shape[:-2])
other_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, other_shape[:-2])
self.broadcast_dim_info['input'] = input_broadcast_dim_info
self.broadcast_dim_info['other'] = other_broadcast_dim_info
# create the full logical shape
input_shape = bcast_non_matrix_dims + input_shape[-2:]
other_shape = bcast_non_matrix_dims + other_shape[-2:]
assert len(input_shape) == len(other_shape)
mapping_copy['input'] = input_shape
mapping_copy['other'] = other_shape
return mapping_copy
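# For example (illustrative shapes): with input [1, 8, 16] and other [2, 4, 16, 32], the
# batch dims [1] and [2, 4] broadcast to [2, 4], so apply() returns
# {'input': [2, 4, 8, 16], 'other': [2, 4, 16, 32]}; the recorded broadcast dim info lets
# recover() strip sharding from dims that only exist because of broadcasting.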
def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
# remove sharding on the broadcast dim
def _remove_sharding_on_broadcast_dim(key, strategy):
op_data = op_data_mapping[key]
sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
tensor_shape = list(sharding_spec.entire_shape)
for dim_idx, broadcast_type in self.broadcast_dim_info[key].items():
if broadcast_type == BroadcastType.MULTIPLE:
# if the dim is originally 1 and multiplied during broadcast
# we set its sharding to R
# e.g. [1, 2, 4] x [4, 4, 8] -> [4, 2, 8]
# the dim 0 of [1, 2, 4] is multiplied to 4
tensor_shape[dim_idx] = 1
elif broadcast_type == BroadcastType.PADDDING:
# if the dim is padded
# we remove its sharding
tensor_shape[dim_idx] = None
tensor_shape_before_broadcast = [dim for dim in tensor_shape if dim is not None]
physical_sharding_spec = recover_sharding_spec_for_broadcast_shape(
logical_sharding_spec=sharding_spec,
logical_shape=sharding_spec.entire_shape,
physical_shape=tensor_shape_before_broadcast)
strategy.sharding_specs[op_data] = physical_sharding_spec
# enumerate all sharding strategies
strategies = []
try:
strategy_copy = strategy.clone()
_remove_sharding_on_broadcast_dim('input', strategy_copy)
_remove_sharding_on_broadcast_dim('other', strategy_copy)
strategies.append(strategy_copy)
except ShardingSpecException as e:
pass
return strategies
class Viewer(BmmTransform):
"""
Change the shape of the tensor from N-D to 3D
"""
def __init__(self) -> None:
self.batch_dims_before_view = None
def apply(self, shape_mapping: Dict[str, List[int]]):
mapping_copy = shape_mapping.copy()
self.batch_dims_before_view = list(mapping_copy['input'][:-2])
# get shapes
input_shape = shape_mapping['input']
other_shape = shape_mapping['other']
# view to 3d tensor
assert len(input_shape) >= 3 and len(other_shape) >= 3
input_shape = [reduce(operator.mul, input_shape[:-2])] + input_shape[-2:]
other_shape = [reduce(operator.mul, other_shape[:-2])] + other_shape[-2:]
output_shape = input_shape[:2] + other_shape[2:]
mapping_copy['input'] = input_shape
mapping_copy['other'] = other_shape
mapping_copy['output'] = output_shape
return mapping_copy
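# For example (illustrative shapes): with input [2, 4, 8, 16] and other [2, 4, 16, 32], the
# two batch dims are flattened into one, so apply() returns {'input': [8, 8, 16],
# 'other': [8, 16, 32], 'output': [8, 8, 32]}; recover() later maps the single logical
# batch dim back onto one of the original batch dims.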
def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
# get operation data
def _update_sharding_spec(key, strategy, physical_batch_dim):
"""
Map the logical batch dim to the physical batch dim
"""
op_data = op_data_mapping[key]
sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
dim_partition_dict = sharding_spec.dim_partition_dict
entire_shape = sharding_spec.entire_shape
# update the dimension index for the matrix dimensions
if 2 in dim_partition_dict:
dim_partition_dict[len(self.batch_dims_before_view) + 1] = dim_partition_dict.pop(2)
if 1 in dim_partition_dict:
dim_partition_dict[len(self.batch_dims_before_view)] = dim_partition_dict.pop(1)
# map the logical batch dim to the physical batch dim
if 0 in dim_partition_dict:
batch_dim_shard = dim_partition_dict.pop(0)
dim_partition_dict[physical_batch_dim] = batch_dim_shard
# the new shape will be the batch dims + the last 2 matrix dims
shape_before_view = self.batch_dims_before_view + list(entire_shape[-2:])
sharding_spec.__init__(sharding_spec.device_mesh, shape_before_view, dim_partition_dict)
num_batch_dim_before_view = len(self.batch_dims_before_view)
# enumerate all sharding strategies
strategies = []
for i in range(num_batch_dim_before_view):
# create a new strategy
strategy_copy = strategy.clone()
try:
_update_sharding_spec('input', strategy_copy, i)
_update_sharding_spec('other', strategy_copy, i)
_update_sharding_spec('output', strategy_copy, i)
strategies.append(strategy_copy)
except ShardingSpecException as e:
continue
return strategies
def _get_bmm_logical_shape(input_shape, other_shape, transforms):
"""
Compute the logical shapes for BMM operation. BMM has a general representation
[b, i, k] = [b, i, j] x [b, j, k]
The dimension b is called non-matrix (batch) dimension and the remaining dimensions are called matrix dimensions
The logical shapes of the bmm operands are derived in three stages:
1. prepend/append a 1 to the shape of any 1D operand (it is removed again afterwards)
2. broadcast the non-matrix (batch) dimensions
3. view the operands as 3D tensors
"""
shape_mapping = {'input': input_shape, 'other': other_shape}
for transform in transforms:
shape_mapping = transform.apply(shape_mapping)
input_shape = shape_mapping.get('input', None)
other_shape = shape_mapping.get('other', None)
output_shape = shape_mapping.get('output', None)
return input_shape, other_shape, output_shape
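# A small worked example of the three-stage pipeline above (illustrative shapes):
# with input [8] and other [2, 4, 8, 16]:
#   1. Padder prepends a 1 to the 1D input     -> input [1, 8],       other [2, 4, 8, 16]
#   2. Broadcaster aligns the batch dimensions -> input [2, 4, 1, 8], other [2, 4, 8, 16]
#   3. Viewer flattens the batch dims into one -> input [8, 1, 8],    other [8, 8, 16], output [8, 1, 16]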
@operator_registry.register(torch.matmul)
@operator_registry.register(torch.Tensor.matmul)
class MatMulHandler(NodeHandler):
"""
The MatMulHandler is a node handler which handles the sharding strategy generation for the matmul operation.
According to https://pytorch.org/docs/stable/generated/torch.matmul.html, the operations will vary depending on
the operands.
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
# check which type of operation this matmul will call
self.input_meta_data = self.node.args[0]._meta_data
self.other_meta_data = self.node.args[1]._meta_data
self.output_meta_data = self.node._meta_data
input_dim = self.input_meta_data.dim()
other_dim = self.other_meta_data.dim()
self.matmul_type = get_matmul_type(input_dim, other_dim)
if self.matmul_type == MatMulType.BMM:
# bmm operation can possibly involve padding, broadcasting and view
# these transforms will be used to create logical shape and
# recover physical sharding spec
self.transforms = [Padder(), Broadcaster(), Viewer()]
else:
self.transforms = None
def get_strategy_generator(self) -> List[StrategyGenerator]:
generators = []
op_data_mapping = self.get_operation_data_mapping()
if self.matmul_type == MatMulType.BMM:
generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh))
elif self.matmul_type == MatMulType.DOT:
generators.append(DotProductStrategyGenerator(op_data_mapping, self.device_mesh))
elif self.matmul_type == MatMulType.MV:
generators.append(MatVecStrategyGenerator(op_data_mapping, self.device_mesh))
elif self.matmul_type == MatMulType.MM:
generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
logical_shape_func = {
MatMulType.DOT: self._get_logical_shape_for_dot,
MatMulType.MM: self._get_logical_shape_for_mm,
MatMulType.MV: self._get_logical_shape_for_mv,
MatMulType.BMM: self._get_logical_shape_for_bmm
}
logical_shapes = logical_shape_func[self.matmul_type]()
op_data_mapping = self._get_op_data_mapping(*logical_shapes)
return op_data_mapping
def _get_op_data_mapping(self, input_logical_shape, other_logical_shape, output_logical_shape):
# convert list to torch.Size
if input_logical_shape:
input_logical_shape = torch.Size(input_logical_shape)
if other_logical_shape:
other_logical_shape = torch.Size(other_logical_shape)
if output_logical_shape:
output_logical_shape = torch.Size(output_logical_shape)
# create op data
input_op_data = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.input_meta_data,
logical_shape=input_logical_shape)
other_op_data = OperationData(name=str(self.node.args[1]),
type=OperationDataType.ARG,
data=self.other_meta_data,
logical_shape=other_logical_shape)
output_op_data = OperationData(name=str(self.node),
type=OperationDataType.OUTPUT,
data=self.output_meta_data,
logical_shape=output_logical_shape)
mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
return mapping
def _get_logical_shape_for_dot(self):
"""
The operands for the dot operation have the same logical shape as the physical shape
"""
return None, None, None
def _get_logical_shape_for_mm(self):
"""
We need to handle the input tensor for a matrix-matrix multiplication as the input
tensor can be a 1D or 2D tensor. If it is a 1D tensor, 1 will be prepended to its shape
(e.g. [4] -> [1, 4]).
"""
if self.input_meta_data.dim() == 1:
input_logical_shape = [1] + list(self.input_meta_data.shape)
input_logical_shape = torch.Size(input_logical_shape)
else:
input_logical_shape = None
return input_logical_shape, None, None
def _get_logical_shape_for_mv(self):
"""
No broadcasting or dim insertion occurs for matrix-vector operation.
"""
return None, None, None
def _get_logical_shape_for_bmm(self):
input_physical_shape = list(self.input_meta_data.shape)
other_physical_shape = list(self.other_meta_data.shape)
return _get_bmm_logical_shape(input_physical_shape, other_physical_shape, self.transforms)
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
if self.matmul_type in [MatMulType.DOT, MatMulType.MV]:
return strategy
elif self.matmul_type == MatMulType.MM:
if self.input_meta_data.dim() == 1:
# if a 1 is prepended to the input shape (this occurs when input is a 1D tensor)
# we need to remove that dim
input_sharding_spec = strategy.get_sharding_spec_by_name(str(self.node.args[0]))
input_physical_shape = self.node.args[0]._meta_data.shape
dim_partition_dict = input_sharding_spec.dim_partition_dict
# remove the partitioning in the dim 0
if 0 in dim_partition_dict:
dim_partition_dict.pop(0, None)
# move the partitioning in dim 1 to dim 0
if -1 in dim_partition_dict:
shard = dim_partition_dict.pop(-1)
dim_partition_dict[0] = shard
# re-init the sharding spec
input_sharding_spec.__init__(input_sharding_spec.device_mesh,
entire_shape=input_physical_shape,
dim_partition_dict=dim_partition_dict)
return strategy
else:
return strategy
elif self.matmul_type == MatMulType.BMM:
op_data_mapping = self.get_operation_data_mapping()
strategies = [strategy]
# recover the physical sharding spec
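# the transforms are replayed in reverse order (Viewer -> Broadcaster -> Padder):
# the strategy built on the 3D logical shapes is first mapped back to the broadcast
# N-D shapes, then sharding on broadcast-only dims is removed, and finally any padded
# matrix dim is dropped; each stage may yield several physical strategies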
for transform in self.transforms[::-1]:
recovered_strategies = []
for strategy_ in strategies:
output = transform.recover(op_data_mapping, strategy_)
if isinstance(output, ShardingStrategy):
recovered_strategies.append(output)
elif isinstance(output, (list, tuple)):
recovered_strategies.extend(output)
else:
raise TypeError(
f"Found unexpected output type {type(output)} from the recover method of BmmTransform")
strategies = recovered_strategies
return strategies
@@ -60,12 +60,13 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
     def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
         sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
         fwd_compute_cost = sharded_input_shape[0]
-        bwd_compute_cost = sharded_input_shape * 2
+        bwd_compute_cost = fwd_compute_cost * 2
         compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
                                       bwd=bwd_compute_cost,
                                       total=fwd_compute_cost + bwd_compute_cost)
         return compute_cost

+    @ignore_sharding_exception
     def no_split(self):
         name = f'R = R dot R'
         dim_partition_dict = {"input": {}, "other": {}, "output": {}, 'bias': {}}
@@ -75,6 +76,7 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)

+    @ignore_sharding_exception
     def split_one_dim(self, mesh_dim):
         name = f'R = S{mesh_dim} dot S{mesh_dim}'
@@ -93,7 +95,7 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)

-    def generate(self) -> List[ShardingStrategy]:
+    def collate_strategies(self) -> List[ShardingStrategy]:
         strategy_list = []

         # do not split dimensions for dot product
@@ -113,24 +115,50 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
     def validate(self) -> bool:
         input_op_data = self.op_data['input']
         other_op_data = self.op_data['other']
-        assert input_op_data.data.dim() > 1 and other_op_data.data.dim() == 1
+        assert input_op_data.data.dim() == 2 and other_op_data.data.dim() == 1

+    def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
+        sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
+        fwd_compute_cost = sharded_input_shape[0]
+        bwd_compute_cost = fwd_compute_cost * 2
+        compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
+                                      bwd=bwd_compute_cost,
+                                      total=fwd_compute_cost + bwd_compute_cost)
+        return compute_cost
+
+    @ignore_sharding_exception
     def no_split(self):
         name = "R = R x R"
-        dim_partition_dict = {"input": {}, "other": {}, "output": {}, "bias": {}}
+        dim_partition_dict = {"input": {}, "other": {}, "output": {}}
+        if self.has_bias:
+            dim_partition_dict['bias'] = {}
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping={})

+    @ignore_sharding_exception
     def split_input_batch(self, mesh_dim):
         name = f'S{mesh_dim}R = S{mesh_dim}R x R'

         # get sharding spec
-        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}, "bias": {}}
+        dim_partition_dict = {
+            "input": {
+                0: [mesh_dim]
+            },
+            "other": {},
+            "output": {
+                0: [mesh_dim]
+            },
+        }
+        if self.has_bias:
+            dim_partition_dict['bias'] = {}
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

         # get communication action
+        communication_action_mapping = {}
         if self.is_param('other'):
             other_comm_action = self.get_communication_action(
                 sharding_spec=sharding_spec_mapping['other'],
@@ -144,6 +172,8 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
                 logical_process_axis=mesh_dim,
                 comm_type=CommType.BEFORE,
                 arg_index=1)
+        communication_action_mapping['other'] = other_comm_action
+
         if self.has_bias:
             if self.is_param('bias'):
                 bias_comm_action = self.get_communication_action(
@@ -158,13 +188,13 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
                     logical_process_axis=mesh_dim,
                     comm_type=CommType.BEFORE,
                     arg_index=2)
-        communication_action_mapping = {'other': other_comm_action, 'bias': bias_comm_action}
+            communication_action_mapping['bias'] = bias_comm_action

         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)

-    def generate(self) -> List[ShardingStrategy]:
+    def collate_strategies(self) -> List[ShardingStrategy]:
         strategy_list = []

         # no split
@@ -638,7 +668,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
     def validate(self) -> bool:
         input_op_data = self.op_data['input']
         other_op_data = self.op_data['other']
-        assert input_op_data.data.dim() == 3 or other_op_data.data.dim() == 3
+        assert len(input_op_data.logical_shape) == 3 or len(other_op_data.logical_shape) == 3
         if 'bias' in self.op_data:
             bias_op_data = self.op_data['bias']
@@ -816,11 +846,11 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
         dim_partition_dict = {
             "input": {
                 0: [mesh_dim_0],
-                -1: [mesh_dim_1]
+                2: [mesh_dim_1]
             },
             "other": {
                 0: [mesh_dim_0],
-                -2: [mesh_dim_1]
+                1: [mesh_dim_1]
             },
             "bias": {},
             "output": {
...
@@ -186,9 +186,14 @@ class StrategyGenerator(ABC):
         """
         op_data = self.op_data[key]
         sharded_shape = strategy.sharding_specs[op_data].get_sharded_shape_per_device()
+
+        if len(sharded_shape) == 0:
+            num_elements = 1
+        else:
+            num_elements = reduce(operator.mul, sharded_shape)
         dtype = self.op_data[key].data.dtype
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        return reduce(operator.mul, sharded_shape) * size_per_elem_bytes
+        return num_elements * size_per_elem_bytes

     def generate(self) -> List[ShardingStrategy]:
         """
...
@@ -44,21 +44,7 @@ def get_broadcast_shape(shape1: torch.Size, shape2: torch.Size) -> List[int]:
     return dims[::-1]


-def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size,
-                                              physical_shape: torch.Size) -> ShardingSpec:
-    """
-    This function computes the sharding spec for the physical shape of a broadcast tensor.
-
-    Args:
-        logical_sharding_spec (ShardingSpec): the sharding spec for the broadcast tensor
-        logical_shape (torch.Size): logical shape is the broadcast shape of a tensor
-        physical_shape (torch.Size): the shape of the tensor before broadcasting
-    """
-    # if the two shapes are the same, no broadcast occurs
-    # we directly return the current sharding spec
-    if list(logical_shape) == list(physical_shape):
-        return logical_sharding_spec
-
+def get_broadcast_dim_info(logical_shape, physical_shape):
     # get the number of dimensions
     logical_num_dims = len(logical_shape)
     physical_num_dims = len(physical_shape)
@@ -85,6 +71,31 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec,
         else:
             logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.PADDDING

+    return logical_dim_broadcast_info
+
+
+def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size,
+                                              physical_shape: torch.Size) -> ShardingSpec:
+    """
+    This function computes the sharding spec for the physical shape of a broadcast tensor.
+
+    Args:
+        logical_sharding_spec (ShardingSpec): the sharding spec for the broadcast tensor
+        logical_shape (torch.Size): logical shape is the broadcast shape of a tensor
+        physical_shape (torch.Size): the shape of the tensor before broadcasting
+    """
+    # if the two shapes are the same, no broadcast occurs
+    # we directly return the current sharding spec
+    if list(logical_shape) == list(physical_shape):
+        return logical_sharding_spec
+
+    # get the number of dimensions
+    logical_num_dims = len(logical_shape)
+    physical_num_dims = len(physical_shape)
+
+    # get the broadcast info
+    logical_dim_broadcast_info = get_broadcast_dim_info(logical_shape, physical_shape)
+
     # generate the sharding spec for the physical shape
     physical_dim_partition = {}
     logical_dim_partition = logical_sharding_spec.dim_partition_dict
...
 import operator
 from copy import deepcopy
-from enum import Enum
 from functools import reduce

 import torch
@@ -175,6 +174,9 @@ class ShardingSpec:
                  dim_partition_dict=None,
                  sharding_sequence=None):
         self.device_mesh = device_mesh
+
+        if isinstance(entire_shape, (list, tuple)):
+            entire_shape = torch.Size(entire_shape)
         self.entire_shape = entire_shape
         self.dim_partition_dict = dim_partition_dict
         self.sharding_sequence = sharding_sequence
...
import torch
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler.matmul_handler import (
MatMulHandler,
MatMulType,
_get_bmm_logical_shape,
get_matmul_type,
)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.utils import parameterize
class MatMulModule(nn.Module):
def forward(self, x1, x2):
return torch.matmul(x1, x2)
@parameterize(
'tensor_shapes',
[
[[8], [8]], # dot product
[[4, 8], [8]], # mat-vec product
[[4, 8], [8, 16]], # mat-mat product
[[8], [8, 16]], # mat-mat product
[[8], [4, 8, 16]], # batched mat-mat product with padding + broadcasting
[[4, 8, 16], [16]], # batched mat-mat product with padding + broadcasting
[[4, 8, 16], [16, 32]], # batched mat-mat product with broadcasting
[[4, 8, 16], [1, 16, 32]], # batched mat-mat product with broadcasting
[[8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting
[[4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting
[[1, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting
[[1, 4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting
[[2, 1, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting
[[2, 4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product without broadcasting
])
def test_matmul_node_handler(tensor_shapes):
input_shape, other_shape = tensor_shapes
# get output shape
x1 = torch.rand(*input_shape)
x2 = torch.rand(*other_shape)
output_shape = list(torch.matmul(x1, x2).shape)
# get matmul type
matmul_type = get_matmul_type(x1.dim(), x2.dim())
model = MatMulModule()
tracer = ColoTracer()
graph = tracer.trace(model, meta_args={"x1": x1.to('meta'), 'x2': x2.to('meta')})
gm = ColoGraphModule(model, graph)
physical_mesh_id = torch.arange(0, 4)
print(graph)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
mod_node = list(graph.nodes)[2]
strategies_vector = StrategiesVector(mod_node)
# build handler
handler = MatMulHandler(node=mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
# check operation data mapping
mapping = handler.get_operation_data_mapping()
for name, op_data in mapping.items():
op_data: OperationData
# make sure they have valid values
assert op_data.logical_shape is not None
assert op_data.data is not None
logical_input_shape = input_shape
logical_other_shape = other_shape
logical_output_shape = output_shape
if matmul_type == MatMulType.MM and len(input_shape) == 1:
logical_input_shape = [1] + input_shape
elif matmul_type == MatMulType.BMM:
logical_input_shape, logical_other_shape, logical_output_shape = _get_bmm_logical_shape(
input_shape, other_shape, handler.transforms)
else:
logical_input_shape = input_shape
# check input operation data
assert mapping['input'].name == "x1"
assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size(input_shape)
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size(logical_input_shape)
# check other operation data
assert mapping['other'].name == "x2"
assert mapping['other'].data.is_meta
assert mapping['other'].data.shape == torch.Size(other_shape)
assert mapping['other'].type == OperationDataType.ARG
assert mapping['other'].logical_shape == torch.Size(logical_other_shape)
# check output
assert mapping['output'].name == "matmul"
assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size(output_shape)
assert mapping['output'].type == OperationDataType.OUTPUT
assert mapping['output'].logical_shape == torch.Size(logical_output_shape)
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
strategy_name_list = [val.name for val in strategies_vector]
# ensure there is no duplicate strategy
if matmul_type != MatMulType.BMM:
assert len(set(strategy_name_list)) == len(strategy_name_list), strategy_name_list
for strategy in strategies_vector:
strategy: ShardingStrategy
input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
output_sharding_spec = strategy.get_sharding_spec_by_name('matmul')
if matmul_type == MatMulType.DOT:
# dot product will produce a scalar
# results should fulfill:
# 1. the input and other operands have the same sharding spec
# 2. the output has no sharding
assert input_sharding_spec.sharding_sequence == other_sharding_spec.sharding_sequence
assert len(output_sharding_spec.sharding_sequence) == 0
elif matmul_type == MatMulType.MV:
# matrix-vector product should fulfill
# 1. the last dim of the input and other operands should have the same sharding
# 2. the first dim of the input and other should have the same sharding
# 3. the output should have only 1 dim
assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1]
assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
assert len(output_sharding_spec.sharding_sequence) == 1
elif matmul_type == MatMulType.MM:
# matrix-matrix multiplication should fulfil
# 1. if input is a 2D tensor, the 1st dim of input and output should have the same sharding
# 2. the input's last dim and the first dim of the other should have the same sharding
# 3. the last dim of the output and other should have the same sharding
# 4. the input and output should have the same number of dims
if len(input_shape) == 2:
assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[0]
assert output_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1]
assert len(input_sharding_spec.sharding_sequence) == len(output_sharding_spec.sharding_sequence)
elif matmul_type == MatMulType.BMM:
# bmm should fulfill
# 1. if the other tensor is not a 1D tensor, the last dim of other and output should have the same sharding
# 2. if the input has at least 2 dims, the second last dim of input and output should have the same sharding
# 3. if the other has more than 2 dims, the second last dim of other and the last dim of input should have the same sharding
if len(other_shape) > 1:
assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
if len(input_shape) > 1:
assert input_sharding_spec.sharding_sequence[-2] == output_sharding_spec.sharding_sequence[-2]
if len(other_shape) > 2:
assert other_sharding_spec.sharding_sequence[-2] == input_sharding_spec.sharding_sequence[-1]
if __name__ == '__main__':
test_matmul_node_handler()