Commit 08f2920e authored by zhuwenwen's avatar zhuwenwen
Browse files

init colossalai, support dtk2304

parent da3f0934
Pipeline #237 failed with stages
in 0 seconds
import operator
from abc import ABC, abstractmethod
from copy import deepcopy
from enum import Enum
from functools import reduce
from typing import Dict, List, Union
import torch
from colossalai.auto_parallel.tensor_shard.utils.broadcast import (
BroadcastType,
get_broadcast_dim_info,
get_broadcast_shape,
)
from colossalai.tensor.sharding_spec import ShardingSpecException
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import (
BatchedMatMulStrategyGenerator,
DotProductStrategyGenerator,
LinearProjectionStrategyGenerator,
MatVecStrategyGenerator,
StrategyGenerator,
)
class MatMulType(Enum):
    """
    Kinds of computation that torch.matmul dispatches to, following the rules described in
    https://pytorch.org/docs/stable/generated/torch.matmul.html.

    DOT: dot product; both tensors are 1D and must hold the same number of elements
    MM: matrix-matrix product; both tensors are 2D, or the 1st tensor is 1D and the 2nd tensor is 2D
    MV: matrix-vector product; the 1st tensor is 2D and the 2nd tensor is 1D
    BMM: batched matrix-matrix multiplication; one tensor is at least 1D and the other is at least 3D
    """

    DOT = 0
    MM = 1
    MV = 2
    BMM = 3
def get_matmul_type(input_dim: int, other_dim: int):
    """
    Classify a matmul call by the number of dimensions of its two operands.

    Args:
        input_dim (int): the number of dimensions of the input tensor
        other_dim (int): the number of dimensions of the other tensor

    Returns:
        The MatMulType member describing which kind of product will be executed.

    Raises:
        ValueError: if the pair of dimensions matches none of the four categories.
    """
    dims = (input_dim, other_dim)

    if dims == (1, 1):
        return MatMulType.DOT
    if other_dim == 2 and input_dim in (1, 2):
        return MatMulType.MM
    if dims == (2, 1):
        return MatMulType.MV
    # both operands have at least one dim and at least one of them has more than two
    if min(dims) >= 1 and max(dims) > 2:
        return MatMulType.BMM

    raise ValueError(
        f"The input and other tensors are of {input_dim} and {other_dim} which cannot used to execute matmul operation"
    )
class BmmTransform(ABC):
    """
    Interface for one step of the shape conversion between the logical and the physical
    operation data used during strategy generation.

    ``apply`` receives a mapping from operand name to shape and returns the transformed mapping;
    ``recover`` receives the operation data mapping together with a sharding strategy and is
    expected to undo the effect of ``apply`` on that strategy.
    """

    @abstractmethod
    def apply(self, shape_mapping: Dict[str, List[int]]):
        ...

    @abstractmethod
    def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
        ...
class Padder(BmmTransform):
    """
    Add padding to the matrix dimensions for batched matrix multiplication.

    A 1D operand is turned into a matrix by inserting a dimension of size 1. The position of
    the inserted dimension is remembered so that ``recover`` can drop it again from the
    sharding specs of the padded operand and of the output.
    """

    def __init__(self) -> None:
        # keep the padding dim, op_name -> padded_dim
        self.padded_dim_mapping = {}

    def apply(self, shape_mapping: Dict[str, List[int]]):
        """
        Return a deep copy of ``shape_mapping`` in which a 1D 'input' or 'other' shape is padded to 2D.

        The caller's mapping and its shape lists are left untouched.
        """
        mapping_copy = deepcopy(shape_mapping)
        input_shape = mapping_copy['input']
        other_shape = mapping_copy['other']

        if len(input_shape) == 1:
            # if the input is a 1D tensor, 1 is prepended to its shape
            # and it will be removed afterwards
            input_shape.insert(0, 1)
            self.padded_dim_mapping['input'] = -2
            self.padded_dim_mapping['output'] = -2
        elif len(other_shape) == 1:
            # if the other is a 1D tensor, 1 is appended to its shape
            # and it will be removed afterwards
            # list.append mutates in place and returns None, so its result must not be re-bound
            other_shape.append(1)
            self.padded_dim_mapping['other'] = -1
            self.padded_dim_mapping['output'] = -1
        return mapping_copy

    def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
        """
        Remove the padded dimension from the sharding specs of a clone of ``strategy``.

        Returns:
            A list holding the recovered clone, or an empty list if re-initialising a sharding
            spec raised ShardingSpecException.
        """

        def _remove_padded_dim(key, strategy):
            op_data = op_data_mapping[key]
            sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
            tensor_shape = list(sharding_spec.entire_shape)
            dim_partition_list = [None] * len(tensor_shape)

            # padded dim is a negative number as the padded dim must be a matrix dim
            padded_dim = self.padded_dim_mapping[key]

            # compute the new dim partition
            for tensor_dim, mesh_dims in sharding_spec.dim_partition_dict.items():
                dim_partition_list[tensor_dim] = mesh_dims
            dim_partition_list.pop(padded_dim)
            unpadded_dim_partition_list = {k: v for k, v in enumerate(dim_partition_list) if v is not None}

            # compute unpadded tensor shape
            tensor_shape.pop(padded_dim)

            assert tensor_shape == list(op_data.data.shape), f'{tensor_shape} vs {list(op_data.data.shape)}'

            # update sharding spec
            sharding_spec.__init__(sharding_spec.device_mesh, tensor_shape, unpadded_dim_partition_list)

        # enumerate all sharding strategies
        strategies = []
        try:
            strategy_copy = strategy.clone()

            # only one of input and other will be padded
            if 'input' in self.padded_dim_mapping:
                _remove_padded_dim('input', strategy_copy)
                _remove_padded_dim('output', strategy_copy)
            elif 'other' in self.padded_dim_mapping:
                _remove_padded_dim('other', strategy_copy)
                _remove_padded_dim('output', strategy_copy)

            strategies.append(strategy_copy)
        except ShardingSpecException:
            # deliberately best-effort: a strategy whose unpadded spec is rejected is dropped
            pass
        return strategies
class Broadcaster(BmmTransform):
    """
    Broadcast the non-matrix dimensions for batched matrix multiplication.
    """

    def __init__(self) -> None:
        # operand name -> broadcast info of its non-matrix dims, filled in by apply()
        self.broadcast_dim_info = {}

    def apply(self, shape_mapping: Dict[str, List[int]]):
        """
        Return a shallow copy of ``shape_mapping`` whose 'input' and 'other' shapes share the
        same broadcast non-matrix dimensions, and record how each operand was broadcast.
        """
        result = shape_mapping.copy()
        lhs_shape = result['input']
        rhs_shape = result['other']

        # sanity check
        assert len(lhs_shape) > 1 and len(rhs_shape) > 1

        # everything in front of the last two dims takes part in the broadcast
        lhs_batch, rhs_batch = lhs_shape[:-2], rhs_shape[:-2]
        bcast_non_matrix_dims = get_broadcast_shape(lhs_batch, rhs_batch)

        # remember how each operand was broadcast, recover() needs it
        lhs_info = get_broadcast_dim_info(bcast_non_matrix_dims, lhs_batch)
        rhs_info = get_broadcast_dim_info(bcast_non_matrix_dims, rhs_batch)
        self.broadcast_dim_info['input'] = lhs_info
        self.broadcast_dim_info['other'] = rhs_info

        # full logical shape = broadcast batch dims + the operand's own matrix dims
        full_lhs = bcast_non_matrix_dims + lhs_shape[-2:]
        full_rhs = bcast_non_matrix_dims + rhs_shape[-2:]
        assert len(full_lhs) == len(full_rhs)

        result['input'] = full_lhs
        result['other'] = full_rhs
        return result

    def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
        """
        Convert the sharding specs of 'input' and 'other' in a clone of ``strategy`` back to
        their shapes before broadcasting.

        Returns:
            A list holding the recovered clone, or an empty list if ShardingSpecException was raised.
        """

        def _restore_operand(key, target_strategy):
            op_data = op_data_mapping[key]
            logical_spec = target_strategy.get_sharding_spec_by_name(op_data.name)
            shape = list(logical_spec.entire_shape)

            for dim_idx, broadcast_type in self.broadcast_dim_info[key].items():
                if broadcast_type == BroadcastType.MULTIPLE:
                    # the dim was originally 1 and got multiplied during broadcast,
                    # e.g. [1, 2, 4] x [4, 4, 8] -> [4, 2, 8] multiplies dim 0 of [1, 2, 4] to 4
                    shape[dim_idx] = 1
                elif broadcast_type == BroadcastType.PADDDING:
                    # the dim did not exist before broadcast, mark it for removal
                    shape[dim_idx] = None

            shape_before_broadcast = [size for size in shape if size is not None]

            physical_spec, _ = recover_sharding_spec_for_broadcast_shape(
                logical_sharding_spec=logical_spec,
                logical_shape=logical_spec.entire_shape,
                physical_shape=shape_before_broadcast)
            target_strategy.sharding_specs[op_data] = physical_spec

        # enumerate all sharding strategies
        strategies = []
        try:
            strategy_copy = strategy.clone()
            for operand in ('input', 'other'):
                _restore_operand(operand, strategy_copy)
            strategies.append(strategy_copy)
        except ShardingSpecException:
            pass
        return strategies
class Viewer(BmmTransform):
    """
    Change the shape of the tensor from N-D to 3D
    """

    def __init__(self) -> None:
        # batch dims of the 'input' shape seen by the latest apply() call
        self.batch_dims_before_view = None

    def apply(self, shape_mapping: Dict[str, List[int]]):
        """
        Collapse all non-matrix dims of 'input' and 'other' into a single batch dim and add the
        resulting 3D 'output' shape to a shallow copy of ``shape_mapping``.
        """
        viewed = shape_mapping.copy()
        self.batch_dims_before_view = list(viewed['input'][:-2])

        lhs_shape = shape_mapping['input']
        rhs_shape = shape_mapping['other']

        # view to 3d tensor
        assert len(lhs_shape) >= 3 and len(rhs_shape) >= 3
        lhs_3d = [reduce(operator.mul, lhs_shape[:-2])] + lhs_shape[-2:]
        rhs_3d = [reduce(operator.mul, rhs_shape[:-2])] + rhs_shape[-2:]

        viewed['input'] = lhs_3d
        viewed['other'] = rhs_3d
        # [b, i, j] x [b, j, k] -> [b, i, k]
        viewed['output'] = lhs_3d[:2] + rhs_3d[2:]
        return viewed

    def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
        """
        Expand the 3D sharding specs of ``strategy`` back to the shape before the view.

        One clone of ``strategy`` is produced per batch dim that existed before the view, with the
        sharding of the single logical batch dim placed on that dim. Clones for which
        ShardingSpecException is raised are skipped.
        """
        num_batch_dim_before_view = len(self.batch_dims_before_view)

        def _update_sharding_spec(key, target_strategy, physical_batch_dim):
            """
            Map the logical batch dim to the physical batch dim
            """
            op_data = op_data_mapping[key]
            spec = target_strategy.get_sharding_spec_by_name(op_data.name)
            partition = spec.dim_partition_dict
            entire_shape = spec.entire_shape

            # update the dimension index for the matrix dimensions,
            # the higher index is moved first so that it is not overwritten
            if 2 in partition:
                partition[len(self.batch_dims_before_view) + 1] = partition.pop(2)
            if 1 in partition:
                partition[len(self.batch_dims_before_view)] = partition.pop(1)

            # map the logical batch dim to physical batch dim
            if 0 in partition:
                partition[physical_batch_dim] = partition.pop(0)

            # the new shape will be the batch dims + the last 2 matrix dims
            shape_before_view = self.batch_dims_before_view + list(entire_shape[-2:])
            spec.__init__(spec.device_mesh, shape_before_view, partition)

        # enumerate all sharding strategies
        strategies = []
        for batch_dim in range(num_batch_dim_before_view):
            # create a new strategy
            candidate = strategy.clone()
            try:
                for operand in ('input', 'other', 'output'):
                    _update_sharding_spec(operand, candidate, batch_dim)
            except ShardingSpecException:
                continue
            strategies.append(candidate)
        return strategies
def _get_bmm_logical_shape(input_shape, other_shape, transforms):
"""
Compute the logical shapes for BMM operation. BMM has a general representation
[b, i, k] = [b, i, j] x [b, j, k]
The dimension b is called non-matrix (batch) dimension and the remaining dimensions are called matrix dimensions
The logical shape for the bmm operands will undergo three stages
1. append/prepend the 1 to the 1D tensor if there is any
2. broadcast the non-matrix dimensions
3. reshape to 3 dimensions
"""
shape_mapping = {'input': input_shape, 'other': other_shape}
for transform in transforms:
shape_mapping = transform.apply(shape_mapping)
input_shape = shape_mapping.get('input', None)
other_shape = shape_mapping.get('other', None)
output_shape = shape_mapping.get('output', None)
return input_shape, other_shape, output_shape
@operator_registry.register(torch.matmul)
@operator_registry.register(torch.Tensor.matmul)
class MatMulHandler(NodeHandler):
    """
    The MatMulHandler is a node handler which handles the sharding strategy generation for the matmul operation.
    According to https://pytorch.org/docs/stable/generated/torch.matmul.html, the operations will vary depending on
    the operands.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        # meta data of the two operands and of the result
        self.input_meta_data = self.node.args[0]._meta_data
        self.other_meta_data = self.node.args[1]._meta_data
        self.output_meta_data = self.node._meta_data

        # check which type of operation this matmul will call
        self.matmul_type = get_matmul_type(self.input_meta_data.dim(), self.other_meta_data.dim())

        # bmm operation can possibly involve padding, broadcasting and view
        # these transforms will be used to create logical shape and
        # recover physical sharding spec
        self.transforms = [Padder(), Broadcaster(), Viewer()] if self.matmul_type == MatMulType.BMM else None

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        """
        Return a list holding the single strategy generator that matches the matmul type.
        """
        op_data_mapping = self.get_operation_data_mapping()

        if self.matmul_type == MatMulType.MM:
            # the matrix-matrix case is the only one that needs an extra keyword argument
            return [
                LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear')
            ]

        plain_generators = {
            MatMulType.BMM: BatchedMatMulStrategyGenerator,
            MatMulType.DOT: DotProductStrategyGenerator,
            MatMulType.MV: MatVecStrategyGenerator,
        }
        generator_cls = plain_generators.get(self.matmul_type)
        if generator_cls is None:
            return []
        return [generator_cls(op_data_mapping, self.device_mesh)]

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        """
        Build the operation data for 'input', 'other' and 'output' using the logical shapes
        that belong to the matmul type.
        """
        shape_getters = {
            MatMulType.DOT: self._get_logical_shape_for_dot,
            MatMulType.MM: self._get_logical_shape_for_mm,
            MatMulType.MV: self._get_logical_shape_for_mv,
            MatMulType.BMM: self._get_logical_shape_for_bmm
        }
        input_shape, other_shape, output_shape = shape_getters[self.matmul_type]()
        return self._get_op_data_mapping(input_shape, other_shape, output_shape)

    def _get_op_data_mapping(self, input_logical_shape, other_logical_shape, output_logical_shape):
        """
        Wrap the meta data of both operands and the result into OperationData objects.
        """

        def _to_size(shape):
            # convert list to torch.Size, falsy values (e.g. None) are passed through unchanged
            return torch.Size(shape) if shape else shape

        input_logical_shape = _to_size(input_logical_shape)
        other_logical_shape = _to_size(other_logical_shape)
        output_logical_shape = _to_size(output_logical_shape)

        # create op data
        return {
            'input':
                OperationData(name=str(self.node.args[0]),
                              type=OperationDataType.ARG,
                              data=self.input_meta_data,
                              logical_shape=input_logical_shape),
            'other':
                OperationData(name=str(self.node.args[1]),
                              type=OperationDataType.ARG,
                              data=self.other_meta_data,
                              logical_shape=other_logical_shape),
            'output':
                OperationData(name=str(self.node),
                              type=OperationDataType.OUTPUT,
                              data=self.output_meta_data,
                              logical_shape=output_logical_shape),
        }

    def _get_logical_shape_for_dot(self):
        """
        The operands for the dot operation have the same logical shape as the physical shape
        """
        return None, None, None

    def _get_logical_shape_for_mm(self):
        """
        The input tensor of a matrix-matrix multiplication can be a 1D or a 2D tensor.
        If it is a 1D tensor, 1 is prepended to its shape (e.g. [4] -> [1, 4]);
        otherwise no logical shape is needed.
        """
        if self.input_meta_data.dim() != 1:
            return None, None, None
        return torch.Size([1, *self.input_meta_data.shape]), None, None

    def _get_logical_shape_for_mv(self):
        """
        No broadcasting or dim insertion occurs for matrix-vector operation.
        """
        return None, None, None

    def _get_logical_shape_for_bmm(self):
        """
        Run the physical shapes of both operands through the registered transforms.
        """
        return _get_bmm_logical_shape(list(self.input_meta_data.shape), list(self.other_meta_data.shape),
                                      self.transforms)

    def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
        """
        Convert a strategy generated on the logical shapes back to the physical shapes.

        DOT and MV strategies are returned unchanged. For MM the prepended dim of a 1D input is
        removed from the input sharding spec in place. For BMM the transforms are undone in
        reverse order, which yields a list of strategies.
        """
        if self.matmul_type in (MatMulType.DOT, MatMulType.MV):
            return strategy

        if self.matmul_type == MatMulType.MM:
            if self.input_meta_data.dim() != 1:
                return strategy

            # a 1 was prepended to the input shape (this occurs when input is a 1D tensor),
            # that dim needs to be removed again
            input_sharding_spec = strategy.get_sharding_spec_by_name(str(self.node.args[0]))
            input_physical_shape = self.node.args[0]._meta_data.shape
            partition = input_sharding_spec.dim_partition_dict

            # remove the partitioning in the dim 0
            partition.pop(0, None)

            # move the partitioning in dim 1 to dim 0
            if -1 in partition:
                partition[0] = partition.pop(-1)
            if 1 in partition:
                partition[0] = partition.pop(1)

            # re-init the sharding spec
            input_sharding_spec.__init__(input_sharding_spec.device_mesh,
                                         entire_shape=input_physical_shape,
                                         dim_partition_dict=partition)
            return strategy

        if self.matmul_type == MatMulType.BMM:
            op_data_mapping = self.get_operation_data_mapping()
            strategies = [strategy]

            # recover the physical sharding spec, undoing the last transform first
            for transform in reversed(self.transforms):
                recovered = []
                for candidate in strategies:
                    output = transform.recover(op_data_mapping, candidate)
                    if isinstance(output, ShardingStrategy):
                        recovered.append(output)
                    elif isinstance(output, (list, tuple)):
                        recovered.extend(output)
                    else:
                        raise TypeError(
                            f"Found unexpected output type {type(output)} from the recover method of BmmTransform")
                strategies = recovered
            return strategies
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple, Union
import torch
from torch.fx.node import Node
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
OperationDataType,
ShardingSpec,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import check_sharding_spec_validity
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from .strategy import StrategyGenerator
class NodeHandler(ABC):
    '''
    The NodeHandler is an abstract class used to generate every possible strategies for an operator node.

    Args:
        node (Node): the input node in node argument list.
        device_mesh (DeviceMesh): A logical view of a physical mesh.
        strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
    '''

    def __init__(
        self,
        node: Node,
        device_mesh: DeviceMesh,
        strategies_vector: StrategiesVector,
    ) -> None:
        self.node = node
        self.predecessor_node = list(node._input_nodes.keys())
        self.successor_node = list(node.users.keys())
        self.device_mesh = device_mesh
        self.strategies_vector = strategies_vector

    def update_resharding_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
        """
        Compute the resharding costs and save the costs in the ShardingStrategy object.

        For every predecessor node that appears in ``strategy``, one cost is computed per strategy
        in that predecessor's strategies vector.

        Returns:
            The same ``strategy`` object, with its ``resharding_costs`` attribute set.
        """
        # TODO: test this function when other handlers are ready
        resharding_costs = {}
        shape_consistency_manager = ShapeConsistencyManager()

        def _compute_resharding_cost(
                prev_sharding_spec: Union[ShardingSpec, List[ShardingSpec]],
                current_sharding_spec: Union[ShardingSpec, List[ShardingSpec]],
                data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]) -> TrainCycleItem:
            """
            This is a helper function to compute the resharding cost for a specific strategy of a node.
            """
            if prev_sharding_spec is None:
                return TrainCycleItem(fwd=0, bwd=0, total=0)

            if isinstance(prev_sharding_spec, ShardingSpec):
                if isinstance(data, torch.nn.parameter.Parameter):
                    # we won't compute the resharding cost for the parameters,
                    # since the parameters will be sharded before runtime and
                    # not converted during runtime.
                    return TrainCycleItem(fwd=0, bwd=0, total=0)
                if isinstance(data, torch.Tensor):
                    dtype = data.dtype
                    size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
                    _, _, consistency_cost = shape_consistency_manager.shape_consistency(
                        prev_sharding_spec, current_sharding_spec)

                    return TrainCycleItem(fwd=consistency_cost["forward"] * size_per_elem_bytes,
                                          bwd=consistency_cost["backward"] * size_per_elem_bytes,
                                          total=consistency_cost["total"] * size_per_elem_bytes)
                # This raise is used to check if we have missed any type of data.
                # It could be merged into Parameter branch, which means we won't handle
                # non-tensor arguments.
                raise ValueError(f'Unsupported data type {type(data)}')

            assert isinstance(prev_sharding_spec, (tuple, list)), (
                'prev_sharding_spec should be in type of ShardingSpec, List[ShardingSpec], '
                f'or Tuple[ShardingSpec], but got {type(prev_sharding_spec)}')

            # a sequence of specs: sum up the cost of every (previous, current) pair
            fwd_cost = 0
            bwd_cost = 0
            total_cost = 0
            for index, (prev_sharding_spec_item,
                        current_sharding_spec_item) in enumerate(zip(prev_sharding_spec, current_sharding_spec)):
                item_cost = _compute_resharding_cost(prev_sharding_spec_item, current_sharding_spec_item,
                                                     data[index])
                fwd_cost += item_cost.fwd
                bwd_cost += item_cost.bwd
                total_cost += item_cost.total
            return TrainCycleItem(fwd=fwd_cost, bwd=bwd_cost, total=total_cost)

        # the strategy is not modified inside the loop, so the names are collected only once
        names_in_strategy = {op_data.name for op_data in strategy.sharding_specs.keys()}

        for node in self.predecessor_node:
            node_name = str(node)

            # we will not compute the resharding costs for the node not counted in the strategy.
            if node_name not in names_in_strategy:
                continue

            # get the current sharding spec generated by this node handler
            op_data = strategy.get_op_data_by_name(node_name)
            current_sharding_spec = strategy.sharding_specs[op_data]

            # get the sharding specs for this node generated
            # in its own node handler
            assert hasattr(node, 'strategies_vector'), \
                f'The predecessor node {node_name} has no strategy vector to compute the resharding cost.'
            prev_strategy_vector = node.strategies_vector
            prev_sharding_specs = [
                prev_strategy.get_sharding_spec_by_name(node_name) for prev_strategy in prev_strategy_vector
            ]

            # create data structure to store costs
            if node not in resharding_costs:
                resharding_costs[node] = []

            # for each sharding spec generated by the predecessor's node handler
            # compute the resharding cost to switch to the sharding spec generated
            # by the current node handler
            for prev_sharding_spec in prev_sharding_specs:
                resharding_cost = _compute_resharding_cost(prev_sharding_spec, current_sharding_spec, op_data.data)
                resharding_costs[node].append(resharding_cost)

        strategy.resharding_costs = resharding_costs
        return strategy

    def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
        """
        Register different sharding strategies for the current node.

        Args:
            compute_resharding_cost (bool): whether the resharding costs with respect to the
                predecessor nodes are computed for every generated strategy.

        Returns:
            The strategies vector of this handler, extended with the generated strategies.
        """
        for generator in self.get_strategy_generator():
            # postprocess a strategy
            # postprocess can produce one strategy or multiple strategies
            post_processed_strategies = []
            for strategy in generator.generate():
                processed = self.post_process(strategy)
                if isinstance(processed, (list, tuple)):
                    post_processed_strategies.extend(processed)
                else:
                    post_processed_strategies.append(processed)

            # compute the resharding costs based on the previous node
            # strategies if specified
            if compute_resharding_cost:
                post_processed_strategies = [
                    self.update_resharding_cost(strategy) for strategy in post_processed_strategies
                ]

            self.strategies_vector.extend(post_processed_strategies)

        # validating the correctness of the sharding strategy
        for strategy in self.strategies_vector:
            for op_data, sharding_spec in strategy.sharding_specs.items():
                if op_data.data is not None and isinstance(op_data.data, torch.Tensor):
                    check_sharding_spec_validity(sharding_spec, op_data.data)

        return self.strategies_vector

    def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
        """
        Hook to transform a generated strategy, e.g. to process the sharding strategy for the
        transposed weights. The default implementation returns the strategy unchanged.
        """
        return strategy

    @abstractmethod
    def get_strategy_generator(self) -> List[StrategyGenerator]:
        """
        Define which generators should be used by this NodeHandler object.
        """
        pass

    @abstractmethod
    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        """
        Returns the mapping between the logical operation data to its physical data.
        A logical operation data is a data associated with an operation, which can be input and output. It is
        defined by the strategy generator, for example, a matrix multiplication operation has two operands "input"
        and "other" and one result "output". For a nn.Linear module, the physical operand for "input" is
        the module input, the physical operand for "other" is the module weight, and the physical result for "output"
        is the module output.
        Note that the operand name is specified by the StrategyGenerator object.

        For example:

            # for a linear layer
            mapping = {
                "input": Operand(name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data),
                "other": Operand(name="weight", type=OperationDataType.PARAM, data=self.named_parameters['weight']),
                "bias": Operand(name="bias", type=OperationDataType.PARAM, data=self.named_parameters['bias']),
                "output": Operand(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data),
            }
        """
        pass
class ModuleHandler(NodeHandler):
    """
    A NodeHandler for nodes whose target names a submodule of the module owning the graph.

    The submodule and its direct (non-recursive) parameters and buffers are exposed as
    ``self.module``, ``self.named_parameters`` and ``self.named_buffers``.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        # set attributes to access module parameters for convenience
        owning_module = self.node.graph.owning_module
        assert owning_module is not None, \
            'The graph is not associated with a module, please make sure it can be used to instantiate a GraphModule object.'
        submodule = owning_module.get_submodule(self.node.target)

        # name -> tensor dictionaries built from the (name, tensor) pairs
        parameters_by_name = dict(submodule.named_parameters(recurse=False))
        buffers_by_name = dict(submodule.named_buffers(recurse=False))

        self.module = submodule
        self.named_parameters = parameters_by_name
        self.named_buffers = buffers_by_name
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import ModuleHandler
from .registry import operator_registry
from .strategy import NormalPoolStrategyGenerator, StrategyGenerator
__all__ = ['NormPoolingHandler']
@operator_registry.register(torch.nn.MaxPool1d)
@operator_registry.register(torch.nn.MaxPool2d)
@operator_registry.register(torch.nn.MaxPool3d)
@operator_registry.register(torch.nn.AvgPool1d)
@operator_registry.register(torch.nn.AvgPool2d)
@operator_registry.register(torch.nn.AvgPool3d)
class NormPoolingHandler(ModuleHandler):
    """
    A NormPoolingHandler which deals with the sharding strategies for the nn.MaxPoolxd and
    nn.AvgPoolxd modules (x = 1, 2, 3).
    """

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        """
        Return a list holding the NormalPoolStrategyGenerator built for this node.
        """
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(NormalPoolStrategyGenerator(op_data_mapping, self.device_mesh))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        """
        Map the operands "input", "other" and "output" to the node input meta data, the kernel
        size of the pooling module and the node output meta data respectively.
        """
        physical_input_operand = OperationData(name=str(self.node.args[0]),
                                               type=OperationDataType.ARG,
                                               data=self.node.args[0]._meta_data)
        # the kernel size is not a tensor, it is carried along as the "other" operand
        physical_weight_operand = OperationData(name="kernel", type=OperationDataType.ARG, data=self.module.kernel_size)
        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)

        mapping = {"input": physical_input_operand, "other": physical_weight_operand, "output": physical_output}

        return mapping
from typing import Dict, List
import torch
from colossalai.device.device_mesh import DeviceMesh
from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector
from .node_handler import NodeHandler
from .strategy import OutputGenerator, StrategyGenerator
__all__ = ['OuputHandler']
class OuputHandler(NodeHandler):
    """
    A OuputHandler which deals with the sharding strategies for Output Node.
    """

    def __init__(self, node: torch.fx.node.Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
                 output_option: str) -> None:
        super().__init__(node, device_mesh, strategies_vector)
        # forwarded unchanged to the OutputGenerator
        self.output_option = output_option

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        """
        Return a list holding the OutputGenerator built for this node.
        """
        op_data_mapping = self.get_operation_data_mapping()
        return [OutputGenerator(op_data_mapping, self.device_mesh, self.predecessor_node, self.output_option)]

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        """
        Create one "input_<index>" operand per predecessor node plus the "output" operand.

        As a side effect the meta data of the output node is set to the meta data of its single
        predecessor, or to a tuple of the predecessors' meta data when there are several.
        """
        mapping = {}
        collected_meta_data = []

        for index, input_node in enumerate(self.predecessor_node):
            meta_data = input_node._meta_data
            mapping[f'input_{index}'] = OperationData(name=str(input_node),
                                                      type=OperationDataType.ARG,
                                                      data=meta_data)
            collected_meta_data.append(meta_data)

        assert len(collected_meta_data) > 0, f'Output node {self.node} has no input node.'

        # a single predecessor is stored as is, several ones are packed into a tuple
        if len(collected_meta_data) == 1:
            self.node._meta_data = collected_meta_data[0]
        else:
            self.node._meta_data = tuple(collected_meta_data)

        mapping["output"] = OperationData(name=str(self.node),
                                          type=OperationDataType.OUTPUT,
                                          data=self.node._meta_data)
        return mapping
from typing import Dict, List
from torch.fx.node import Node
from colossalai.device.device_mesh import DeviceMesh
from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector
from .node_handler import NodeHandler
from .strategy import PlaceholderGenerator, StrategyGenerator
__all__ = ['PlacehodlerHandler']
class PlacehodlerHandler(NodeHandler):
    """
    A PlacehodlerHandler which deals with the sharding strategies for Placeholder Node.
    """

    def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
                 placeholder_option: str) -> None:
        super().__init__(node, device_mesh, strategies_vector)
        # forwarded unchanged to the PlaceholderGenerator
        self.placeholder_option = placeholder_option

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        """
        Return a list holding the PlaceholderGenerator built for this node.
        """
        op_data_mapping = self.get_operation_data_mapping()
        generator = PlaceholderGenerator(op_data_mapping,
                                         self.device_mesh,
                                         placeholder_option=self.placeholder_option)
        return [generator]

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        """
        A placeholder has no operands, so the mapping only contains the "output" entry built
        from the meta data of the node itself.
        """
        return {
            "output": OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
        }
class Registry:
    """
    A named lookup table from a source object (e.g. a torch function or module class) to the
    object registered for it.
    """

    # TODO: refactor the registry classes used in colossalai.registry, colossalai.fx and here

    def __init__(self, name):
        self.name = name
        self.store = {}

    def register(self, source):
        """
        Return a decorator that stores the decorated object under ``source`` and hands it back
        unchanged. If ``source`` is a list or tuple, the object is stored under every element.
        """

        def wrapper(func):
            # support register a list of items for this func
            keys = source if isinstance(source, (list, tuple)) else (source,)
            for key in keys:
                self.store[key] = func
            return func

        return wrapper

    def get(self, source):
        """
        Return the object registered under ``source``; asserts that it has been registered.
        """
        assert source in self.store, f'{source} not found in the {self.name} registry'
        return self.store[source]

    def has(self, source):
        """
        Tell whether something has been registered under ``source``.
        """
        return source in self.store


operator_registry = Registry('operator')
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import ReshapeGenerator, StrategyGenerator
__all__ = ['ReshapeHandler']
@operator_registry.register(torch.flatten)
@operator_registry.register(torch.Tensor.unsqueeze)
@operator_registry.register(torch.nn.AdaptiveAvgPool2d)
class ReshapeHandler(NodeHandler):
    """
    A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        """
        Return a list holding the ReshapeGenerator built for this node and its first argument.
        """
        op_data_mapping = self.get_operation_data_mapping()
        return [ReshapeGenerator(op_data_mapping, self.device_mesh, self.node.args[0])]

    def infer_logical_shape(self, data):
        """
        This function is used to infer logical shape for operands.

        Notes: This function is only used for the operands whose data are not only in type of tensor,
                such as tuple of tensor.

        Returns:
            The shape of ``data`` if it is a tensor, otherwise a tuple with the shape of every
            tensor contained in ``data``.
        """
        if isinstance(data, torch.Tensor):
            return data.shape

        assert isinstance(data, tuple), "input_data should be a tuple of tensor or a tensor."
        shapes = []
        for item in data:
            assert isinstance(item, torch.Tensor), "input_data should be a tuple of tensor or a tensor."
            shapes.append(item.shape)
        return tuple(shapes)

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        """
        Map "input" to the first argument of the node and "output" to the node itself, each
        with the logical shape inferred from its meta data.
        """
        input_data = self.node.args[0]._meta_data
        output_data = self.node._meta_data

        # check if the input operand is a parameter
        is_parameter = isinstance(input_data, torch.nn.parameter.Parameter)
        data_type = OperationDataType.PARAM if is_parameter else OperationDataType.ARG

        physical_input_operand = OperationData(name=str(self.node.args[0]),
                                               type=data_type,
                                               data=input_data,
                                               logical_shape=self.infer_logical_shape(input_data))
        physical_output = OperationData(name=str(self.node),
                                        type=OperationDataType.OUTPUT,
                                        data=output_data,
                                        logical_shape=self.infer_logical_shape(output_data))

        return {"input": physical_input_operand, "output": physical_output}
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import SoftmaxGenerator, StrategyGenerator
__all__ = ['SoftmaxHandler']
@operator_registry.register(torch.nn.Softmax)
@operator_registry.register(torch.nn.functional.softmax)
class SoftmaxHandler(NodeHandler):
    """
    A SoftmaxHandler which deals with the sharding strategies for
    torch.nn.Softmax or torch.nn.functional.softmax.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        """Return the single SoftmaxGenerator built for this node."""
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(SoftmaxGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        """
        Describe the operands of the softmax node as OperationData.

        Returns:
            A dict with the keys "input", "softmax_dim" (always a non-negative dimension index) and "output".

        Raises:
            KeyError: if the softmax dimension is given neither as the ``dim`` keyword
                nor as the second positional argument of the node.
        """
        input_node = self.node.args[0]
        input_data = input_node._meta_data

        # check if the input operand is a parameter
        if isinstance(input_data, torch.nn.parameter.Parameter):
            data_type = OperationDataType.PARAM
        else:
            data_type = OperationDataType.ARG
        physical_input_operand = OperationData(name=str(input_node), type=data_type, data=input_data)

        # `dim` is usually a keyword argument, but torch.nn.functional.softmax(input, dim)
        # also accepts it as the second positional argument.
        if 'dim' in self.node.kwargs:
            softmax_dim = self.node.kwargs['dim']
        elif len(self.node.args) > 1:
            softmax_dim = self.node.args[1]
        else:
            # NOTE(review): a node targeting torch.nn.Softmax presumably keeps `dim` on the module
            # rather than on the node; that case still fails with a KeyError, as before.
            raise KeyError(f"'dim' was found neither in the kwargs nor in the args of node {self.node}")

        # recover negative value to positive
        num_dims = input_data.dim()
        if softmax_dim < 0:
            softmax_dim += num_dims
        physical_dim_operand = OperationData(name='softmax_dim', type=OperationDataType.ARG, data=softmax_dim)

        output_data = self.node._meta_data
        physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)

        mapping = {
            "input": physical_input_operand,
            "softmax_dim": physical_dim_operand,
            "output": physical_output_operand
        }
        return mapping
from .batch_norm_generator import BatchNormStrategyGenerator
from .binary_elementwise_generator import BinaryElementwiseStrategyGenerator
from .conv_strategy_generator import ConvStrategyGenerator
from .embedding_generator import EmbeddingStrategyGenerator
from .getattr_generator import GetattrGenerator
from .getitem_generator import GetItemStrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator
from .layer_norm_generator import LayerNormGenerator
from .matmul_strategy_generator import (
BatchedMatMulStrategyGenerator,
DotProductStrategyGenerator,
LinearProjectionStrategyGenerator,
MatVecStrategyGenerator,
)
from .normal_pooling_generator import NormalPoolStrategyGenerator
from .output_generator import OutputGenerator
from .placeholder_generator import PlaceholderGenerator
from .reshape_generator import ReshapeGenerator
from .softmax_generator import SoftmaxGenerator
from .strategy_generator import StrategyGenerator
from .sum_generator import SumGenerator
from .tensor_constructor_generator import TensorConstructorGenerator
from .unary_elementwise_generator import UnaryElementwiseGenerator
from .where_generator import WhereGenerator
# Public names of the strategy package; every entry is imported above and listed exactly once.
__all__ = [
    'StrategyGenerator', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator', 'LinearProjectionStrategyGenerator',
    'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator', 'UnaryElementwiseGenerator',
    'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator',
    'LayerNormGenerator', 'ReshapeGenerator', 'PlaceholderGenerator', 'OutputGenerator', 'WhereGenerator',
    'NormalPoolStrategyGenerator', 'BinaryElementwiseStrategyGenerator', 'GetattrGenerator',
    'TensorConstructorGenerator', 'EmbeddingStrategyGenerator', 'SumGenerator', 'SoftmaxGenerator'
]
import copy
import operator
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
__all__ = ['BatchNormStrategyGenerator']
class BatchNormStrategyGenerator(StrategyGenerator):
    """
    A StrategyGenerator which deals with the sharding strategies of batch normalization.

    To keep the math consistent, there are two ways to do BatchNorm if the input
    is sharded on the batch dimension:
    1. Gather the input partitions along the batch dimension, then do the normal BatchNorm.
    2. Do SyncBatchNorm on each input partition separately; the SyncBN op keeps the
       computation correct.
    In this generator, both methods will be considered.
    """

    def validate(self) -> bool:
        '''
        Sanity check: make sure the input data has a supported number of dimensions.
        For BatchNorm1d, the dim of input data should be 3([N, C, L]).
        For BatchNorm2d, the dim of input data should be 4([N, C, H, W]).
        For BatchNorm3d, the dim of input data should be 5([N, C, H, W, D]).

        Raises:
            AssertionError: if the input has fewer than 3 or more than 5 dimensions.
        '''
        input_op_data = self.op_data['input']
        assert input_op_data.data.dim() in (
            3, 4, 5), 'We suppose the dim of input fed into batch norm op should in range of [3, 5].'

    def update_compute_cost(self, strategy: ShardingStrategy):
        '''
        Compute the computation cost per device with this specific strategy and store it in strategy.compute_cost.

        Note: compute_cost need to be divided by TFLOPS, now it just shows the computation size.
        '''
        # TODO: a constant coefficient need to be added.
        # 1D: (L) * N * Cin
        # 2D: (H * W) * N * Cin
        # 3D: (H * W * D) * N * Cin
        sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
        sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()

        # forward, activation-gradient and weight-gradient costs are all approximated by the input size
        input_product = reduce(operator.mul, sharded_input_shape, 1)
        forward_compute_cost = input_product
        backward_activation_compute_cost = input_product
        backward_weight_compute_cost = input_product
        backward_compute_cost = backward_weight_compute_cost + backward_activation_compute_cost

        if self.has_bias:
            # bias add is an element wise operation, so the cost is equal to product of output shape.
            bias_compute_cost = reduce(operator.mul, sharded_output_shape, 1)
            forward_compute_cost += bias_compute_cost
            backward_compute_cost += bias_compute_cost

        total_compute_cost = forward_compute_cost + backward_compute_cost
        compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
        strategy.compute_cost = compute_cost

    def update_memory_cost(self, strategy: ShardingStrategy):
        '''
        Compute the per-device memory cost of this strategy and store it in strategy.memory_cost.

        Each operand size is attributed to the activation, parameter or buffer bucket
        according to self.is_param / self.is_buffer.
        '''
        forward_size_mapping = {
            'input': self._compute_size_in_bytes(strategy, "input"),
            'other': self._compute_size_in_bytes(strategy, "other"),
            'output': self._compute_size_in_bytes(strategy, "output"),
            'running_mean': self._compute_size_in_bytes(strategy, "running_mean"),
            'running_var': self._compute_size_in_bytes(strategy, "running_var"),
        }

        if self.has_bias:
            bias_size = self._compute_size_in_bytes(strategy, "bias")
            forward_size_mapping['bias'] = bias_size

        # the backward pass accounts for every forward operand except the output
        backward_size_mapping = copy.deepcopy(forward_size_mapping)
        backward_size_mapping.pop("output")

        # compute fwd cost incurred
        # fwd_cost = input + other + bias + output
        fwd_activation_cost = sum(
            [v for k, v in forward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)])
        fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
        fwd_buffer_cost = sum([v for k, v in forward_size_mapping.items() if self.is_buffer(k)])
        fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost, buffer=fwd_buffer_cost)

        # compute bwd cost incurred
        # bwd_cost = input_grad + other_grad + bias_grad
        bwd_activation_cost = sum(
            [v for k, v in backward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)])
        bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
        bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)

        # compute total cost
        total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
                                    parameter=fwd_parameter_cost + bwd_parameter_cost,
                                    buffer=fwd_buffer_cost)
        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
        strategy.memory_cost = memory_cost

    def _bn_sharding_spec_mapping(self, batch_axes=None, channel_axes=None):
        '''
        Build the sharding spec mapping shared by all BatchNorm strategies.

        Args:
            batch_axes: mesh axes which shard dim 0 of input and output; None or empty keeps dim 0 replicated.
            channel_axes: mesh axes which shard dim 1 of input and output and dim 0 of other,
                running_mean, running_var and bias; None or empty keeps them replicated.

        Returns:
            The result of self.to_sharding_spec_mapping for the assembled dim partition dicts.
            num_batches_tracked is always replicated.
        '''

        def activation_partition():
            partition = {}
            if batch_axes:
                partition[0] = list(batch_axes)
            if channel_axes:
                partition[1] = list(channel_axes)
            return partition

        def channel_partition():
            return {0: list(channel_axes)} if channel_axes else {}

        # every operand gets its own fresh dict/list so that no partition object is shared
        dim_partition_dict_mapping = {
            "input": activation_partition(),
            "other": channel_partition(),
            "output": activation_partition(),
            "running_mean": channel_partition(),
            "running_var": channel_partition(),
            "num_batches_tracked": {},
        }
        if self.has_bias:
            dim_partition_dict_mapping["bias"] = channel_partition()

        return self.to_sharding_spec_mapping(dim_partition_dict_mapping)

    @ignore_sharding_exception
    def split_input_channel(self, mesh_dim_0):
        '''Shard the channel dimension over one mesh axis; no communication action is attached.'''
        name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
        sharding_spec_mapping = self._bn_sharding_spec_mapping(channel_axes=[mesh_dim_0])
        communication_action_mapping = {}

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    @ignore_sharding_exception
    def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
        '''Shard the channel dimension over both mesh axes; no communication action is attached.'''
        name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'
        sharding_spec_mapping = self._bn_sharding_spec_mapping(channel_axes=[mesh_dim_0, mesh_dim_1])
        communication_action_mapping = {}

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    @ignore_sharding_exception
    def non_split(self):
        '''Keep every operand fully replicated.'''
        name = 'RR = RR x R'
        sharding_spec_mapping = self._bn_sharding_spec_mapping()
        communication_action_mapping = {}

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    @ignore_sharding_exception
    def split_input_batch(self, mesh_dim_0):
        '''Shard the batch dimension of input/output over one mesh axis (SyncBN case).'''
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'
        sharding_spec_mapping = self._bn_sharding_spec_mapping(batch_axes=[mesh_dim_0])

        # set communication action
        # For SyncBN case, we don't need to do communication for weight and bias.
        # TODO: the communication happens internally at SyncBN operation. We need to replace the BN operation
        # to SyncBN operation instead of inserting a communication node.
        # The action is built but intentionally not attached to the strategy yet.
        output_comm_action = self.get_communication_action(
            sharding_spec=sharding_spec_mapping["output"],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=mesh_dim_0,
            comm_type=CommType.IMPLICIT)

        # TODO: Temporary solution has no communication cost,
        # above action should be added after the SyncBN replace pass completed.
        communication_action_mapping = {}

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    @ignore_sharding_exception
    def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
        '''Shard the batch dimension of input/output over both mesh axes (SyncBN case).'''
        name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'
        sharding_spec_mapping = self._bn_sharding_spec_mapping(batch_axes=[mesh_dim_0, mesh_dim_1])

        # set communication action
        # For SyncBN case, we don't need to do communication for gradients of weight and bias.
        # TODO: the communication happens internally at SyncBN operation. We need to replace the BN operation
        # to SyncBN operation instead of inserting a communication node.
        # The action is built but intentionally not attached to the strategy yet.
        output_comm_action = self.get_communication_action(
            sharding_spec=sharding_spec_mapping["output"],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=[mesh_dim_0, mesh_dim_1],
            comm_type=CommType.IMPLICIT)

        # TODO: Temporary solution has no communication cost,
        # above action should be added after the SyncBN replace pass completed.
        communication_action_mapping = {}

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    @ignore_sharding_exception
    def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
        '''Shard the batch dimension over mesh_dim_0 and the channel dimension over mesh_dim_1 (SyncBN case).'''
        name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'
        sharding_spec_mapping = self._bn_sharding_spec_mapping(batch_axes=[mesh_dim_0], channel_axes=[mesh_dim_1])

        # set communication action
        # For SyncBN case, we don't need to do communication for gradients of weight and bias.
        # TODO: the communication happens internally at SyncBN operation. We need to replace the BN operation
        # to SyncBN operation instead of inserting a communication node.
        # The action is built but intentionally not attached to the strategy yet.
        output_comm_action = self.get_communication_action(
            sharding_spec=sharding_spec_mapping["output"],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=[mesh_dim_0],
            comm_type=CommType.IMPLICIT)

        # TODO: Temporary solution has no communication cost,
        # above action should be added after the SyncBN replace pass completed.
        communication_action_mapping = {}

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def collate_strategies(self) -> List[ShardingStrategy]:
        '''
        Generate every possible strategies for a BatchNorm node, and return them as a list.

        Entries may be None when a strategy method is skipped by ignore_sharding_exception
        (NOTE(review): assumed from the decorator name; confirm against its implementation).
        '''
        strategy_list = []

        # RS = RS x S
        strategy_list.append(self.split_input_channel(0))
        strategy_list.append(self.split_input_channel(1))

        # RR = RR x R
        strategy_list.append(self.non_split())

        # RS01 = RS01 x S01
        strategy_list.append(self.split_input_channel_1d(0, 1))

        # NOTE: the SYNC_BN strategies below ARE generated, but without any communication action attached.
        # TODO: attach the communication actions (or replace BN by SyncBN) once the runtime
        # passes which keep the computation correct are ready.

        # SR = SR x R WITH SYNC_BN
        strategy_list.append(self.split_input_batch(0))
        strategy_list.append(self.split_input_batch(1))

        # SS = SS x S WITH SYNC_BN
        strategy_list.append(self.split_input_both_dim(0, 1))
        strategy_list.append(self.split_input_both_dim(1, 0))

        # S01R = S01R x R WITH SYNC_BN
        strategy_list.append(self.split_input_batch_1d(0, 1))

        return strategy_list
import operator
from functools import reduce
from typing import List
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from colossalai.auto_parallel.tensor_shard.utils import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
ignore_sharding_exception,
)
from colossalai.tensor.sharding_spec import ShardingSpecException
from .strategy_generator import StrategyGenerator
__all__ = ['BinaryElementwiseStrategyGenerator']
class BinaryElementwiseStrategyGenerator(StrategyGenerator):
    """
    A BinaryElementwiseStrategyGenerator is a strategy generator which deals with elementwise operations
    which have two operands and where broadcasting may occur, such as torch.add.
    The logical shape for this operation will be `input <op> other`.
    """

    def validate(self) -> bool:
        """
        Check that exactly three operation data are present and that each holds a tensor, an int or a float.

        Raises:
            AssertionError: if the number of operation data is not three.
            TypeError: if an operation data holds anything else than torch.Tensor/int/float.
        """
        assert len(self.op_data) == 3, \
            f'BinaryElementwiseStrategyGenerator only accepts three operation data (input, other and output), but got {len(self.op_data)}'
        for name, op_data in self.op_data.items():
            if not isinstance(op_data.data, (torch.Tensor, int, float)):
                raise TypeError(f'The operation data {name} is not a torch.Tensor/int/float.')

    def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
        """Store the per-device compute cost in strategy.compute_cost (the strategy is updated in place)."""
        shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()

        # since elementwise ops are not compute-intensive,
        # we approximate the backward compute cost
        # to be twice the fwd compute cost
        # the initial value 1 keeps the product defined for an empty (0-dim) shape
        fwd_compute_cost = reduce(operator.mul, shape, 1)
        bwd_compute_cost = fwd_compute_cost * 2
        compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
                                      bwd=bwd_compute_cost,
                                      total=fwd_compute_cost + bwd_compute_cost)
        strategy.compute_cost = compute_cost

    def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
        """Store the per-device memory cost in strategy.memory_cost (the strategy is updated in place)."""
        # compute fwd memory cost in bytes
        # as the elementwise ops are not memory-intensive
        # we approximate the fwd memory cost to be the output
        # and the backward memory cost to be grad of input and other
        input_bytes = self._compute_size_in_bytes(strategy, 'input')
        other_bytes = self._compute_size_in_bytes(strategy, 'other')
        output_bytes = self._compute_size_in_bytes(strategy, 'output')
        fwd_memory_cost = MemoryCost(activation=output_bytes)
        bwd_memory_cost = MemoryCost(activation=input_bytes + other_bytes)
        total_memory_cost = MemoryCost(activation=input_bytes + other_bytes + output_bytes)
        memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_memory_cost)
        strategy.memory_cost = memory_cost

    @ignore_sharding_exception
    def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
        """
        Enumerate the strategies in which input, other and output all share one dim partition.

        The candidates are all 2D shardings over both mesh axes, all 1D shardings over each
        single mesh axis, and the fully replicated case. Candidates whose conversion raises a
        ShardingSpecException are skipped.
        """
        # we check for the output logical shape to get the number of dimensions
        dim_partition_list = []
        dim_size = len(self.op_data['output'].logical_shape)

        # enumerate all the 2D sharding cases
        sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
        dim_partition_list.extend(sharding_list_2d)

        # enumerate all the 1D sharding cases
        sharding_list_1d_on_dim_0 = enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size)
        dim_partition_list.extend(sharding_list_1d_on_dim_0)
        sharding_list_1d_on_dim_1 = enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size)
        dim_partition_list.extend(sharding_list_1d_on_dim_1)

        # add empty dict for fully replicated case
        dim_partition_list.append({})

        # sharding strategy bookkeeping
        strategy_list = []

        # convert these dim partition dict to sharding strategy
        for dim_partition_dict in dim_partition_list:
            dim_partition_dict_mapping = dict(input=dim_partition_dict,
                                              other=dim_partition_dict,
                                              output=dim_partition_dict)

            try:
                sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
                communication_action_mapping = {}

                # get name
                sharding_seq = sharding_spec_mapping['input'].sharding_sequence
                name = f'{sharding_seq} = {sharding_seq} <binary-elementwise-op> {sharding_seq}'
                sharding_strategy = self.get_sharding_strategy(
                    name=name,
                    sharding_spec_mapping=sharding_spec_mapping,
                    communication_action_mapping=communication_action_mapping)
                strategy_list.append(sharding_strategy)
            except ShardingSpecException:
                continue
        return strategy_list

    def collate_strategies(self) -> List[ShardingStrategy]:
        """Return every strategy enumerated over mesh axes 0 and 1."""
        strategy_list = self.enumerate_all_possible_output(0, 1)
        return strategy_list
import copy
import operator
import warnings
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
class ConvStrategyGenerator(StrategyGenerator):
"""
ConvStrategyGenerator is a generic class to generate strategies.
The operation data is defined as `output = input x other + bias`.
"""
def validate(self) -> bool:
    '''
    Sanity check on the number of dimensions of the input data.
    Conv1d expects 3 dims ([N, C, L]), Conv2d expects 4 dims ([N, C, H, W])
    and Conv3d expects 5 dims ([N, C, H, W, D]).

    Raises:
        AssertionError: if the input has fewer than 3 or more than 5 dimensions.
    '''
    input_ndim = self.op_data['input'].data.dim()
    assert input_ndim in (3, 4, 5), 'We suppose the dim of input fed into conv op should in range of [3, 5].'
def update_compute_cost(self, strategy: ShardingStrategy):
    '''
    Compute the computation cost per device with this specific strategy and store it in strategy.compute_cost.

    Note: compute_cost need to be divided by TFLOPS, now it just shows the computation size.
    '''
    # TODO: compute_cost need to be divided by TFLOPS, now it just shows the computation size.
    # 1D: (L) * N * Cout * Cin * kernel
    # 2D: (H * W) * N * Cout * Cin * kernel
    # 3D: (H * W * D) * N * Cout * Cin * kernel
    sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
    sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device()
    sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()

    # everything after the first two dims is treated as the spatial/kernel extent;
    # all three products use the initial value 1 so that they are computed the same way
    output_size_product = reduce(operator.mul, sharded_output_shape[2:], 1)
    input_size_product = reduce(operator.mul, sharded_input_shape[2:], 1)
    kernel_size_product = reduce(operator.mul, sharded_other_shape[2:], 1)

    batch_size = sharded_input_shape[0]
    channel_in = sharded_input_shape[1]
    # the out channel is read from dim 1 of `other`
    channel_out = sharded_other_shape[1]

    forward_compute_cost = output_size_product * batch_size * channel_in * channel_out * kernel_size_product

    backward_activation_cost = input_size_product * batch_size * channel_in * channel_out * kernel_size_product
    backward_weight_cost = output_size_product * batch_size * channel_in * channel_out * kernel_size_product
    backward_compute_cost = backward_weight_cost + backward_activation_cost

    if self.has_bias:
        # bias add is an element wise operation, so the cost is equal to product of output shape.
        bias_compute_cost = reduce(operator.mul, sharded_output_shape, 1)
        forward_compute_cost += bias_compute_cost
        backward_compute_cost += bias_compute_cost

    total_compute_cost = forward_compute_cost + backward_compute_cost
    compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
    strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
    '''
    Compute the per-device memory cost of this strategy and store it in strategy.memory_cost.

    Operand sizes are split into a parameter part and an activation part according to self.is_param.
    '''
    operand_names = ['input', 'other', 'output']
    if self.has_bias:
        operand_names.append('bias')
    forward_size_mapping = {
        operand: self._compute_size_in_bytes(strategy, operand) for operand in operand_names
    }

    # the backward pass accounts for every forward operand except the output
    backward_size_mapping = copy.deepcopy(forward_size_mapping)
    del backward_size_mapping["output"]

    # forward cost: input + other + bias + output
    fwd_activation_cost = sum(size for key, size in forward_size_mapping.items() if not self.is_param(key))
    fwd_parameter_cost = sum(size for key, size in forward_size_mapping.items() if self.is_param(key))

    # backward cost: input_grad + other_grad + bias_grad
    bwd_activation_cost = sum(size for key, size in backward_size_mapping.items() if not self.is_param(key))
    bwd_parameter_cost = sum(size for key, size in backward_size_mapping.items() if self.is_param(key))

    strategy.memory_cost = TrainCycleItem(
        fwd=MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost),
        bwd=MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost),
        total=MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
                         parameter=fwd_parameter_cost + bwd_parameter_cost))
@ignore_sharding_exception
def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
    '''
    Build the strategy which shards dim 0 of the input over mesh_dim_0 and dim 1 of the
    weight ("other") over mesh_dim_1; the output is sharded on both dims.
    '''
    name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'

    dim_partition_dict_mapping = {
        "input": {0: [mesh_dim_0]},
        "other": {1: [mesh_dim_1]},
        "output": {0: [mesh_dim_0], 1: [mesh_dim_1]},
    }
    if self.has_bias:
        dim_partition_dict_mapping["bias"] = {0: [mesh_dim_1]}
    sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

    grad_pattern = CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD
    communication_action_mapping = {}

    # input: action over mesh_dim_1, applied before the op on positional argument 0
    communication_action_mapping["input"] = self.get_communication_action(
        sharding_spec=sharding_spec_mapping["input"],
        communication_pattern=grad_pattern,
        logical_process_axis=mesh_dim_1,
        comm_type=CommType.BEFORE,
        arg_index=0)

    # weight: a parameter uses a HOOK action, otherwise a BEFORE action on positional argument 1
    if self.is_param("other"):
        other_placement = {"comm_type": CommType.HOOK}
    else:
        other_placement = {"comm_type": CommType.BEFORE, "arg_index": 1}
    communication_action_mapping["other"] = self.get_communication_action(
        sharding_spec_mapping["other"],
        communication_pattern=grad_pattern,
        logical_process_axis=mesh_dim_0,
        **other_placement)

    if self.has_bias:
        # bias: same rule as the weight, but a non-parameter bias is addressed by its keyword
        if self.is_param('bias'):
            bias_placement = {"comm_type": CommType.HOOK}
        else:
            bias_placement = {"comm_type": CommType.BEFORE, "key_for_kwarg": 'bias'}
        communication_action_mapping["bias"] = self.get_communication_action(
            sharding_spec_mapping["bias"],
            communication_pattern=grad_pattern,
            logical_process_axis=mesh_dim_0,
            **bias_placement)

    return self.get_sharding_strategy(name=name,
                                      sharding_spec_mapping=sharding_spec_mapping,
                                      communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_input_batch(self, mesh_dim_0):
    '''
    Build the strategy which shards dim 0 of input and output over mesh_dim_0 while
    the weight ("other") and the bias stay replicated.
    '''
    name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'

    dim_partition_dict_mapping = {
        "input": {0: [mesh_dim_0]},
        "other": {},
        "output": {0: [mesh_dim_0]},
    }
    if self.has_bias:
        dim_partition_dict_mapping["bias"] = {}
    sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

    grad_pattern = CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD
    communication_action_mapping = {}

    # weight: a parameter uses a HOOK action, otherwise a BEFORE action on positional argument 1
    if self.is_param("other"):
        other_placement = {"comm_type": CommType.HOOK}
    else:
        other_placement = {"comm_type": CommType.BEFORE, "arg_index": 1}
    communication_action_mapping["other"] = self.get_communication_action(
        sharding_spec_mapping["other"],
        communication_pattern=grad_pattern,
        logical_process_axis=mesh_dim_0,
        **other_placement)

    if self.has_bias:
        # bias: same rule as the weight, but a non-parameter bias is addressed by its keyword
        if self.is_param('bias'):
            bias_placement = {"comm_type": CommType.HOOK}
        else:
            bias_placement = {"comm_type": CommType.BEFORE, "key_for_kwarg": 'bias'}
        communication_action_mapping["bias"] = self.get_communication_action(
            sharding_spec_mapping["bias"],
            communication_pattern=grad_pattern,
            logical_process_axis=mesh_dim_0,
            **bias_placement)

    return self.get_sharding_strategy(name=name,
                                      sharding_spec_mapping=sharding_spec_mapping,
                                      communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
    '''
    Build the strategy which shards the input on dim 0 over mesh_dim_0 and on dim 1 over
    mesh_dim_1, and the weight ("other") on dim 0 over mesh_dim_1; the output keeps only
    the dim 0 sharding.
    '''
    name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'

    dim_partition_dict_mapping = {
        "input": {0: [mesh_dim_0], 1: [mesh_dim_1]},
        "other": {0: [mesh_dim_1]},
        "output": {0: [mesh_dim_0]},
    }
    if self.has_bias:
        dim_partition_dict_mapping["bias"] = {}
    sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

    communication_action_mapping = {}

    # output: ALLREDUCE_FWD_IDENTITY_BWD action over mesh_dim_1, applied after the op
    communication_action_mapping["output"] = self.get_communication_action(
        sharding_spec_mapping["output"],
        communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
        logical_process_axis=mesh_dim_1,
        comm_type=CommType.AFTER)

    grad_pattern = CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD

    # weight: a parameter uses a HOOK action, otherwise a BEFORE action on positional argument 1
    if self.is_param("other"):
        other_placement = {"comm_type": CommType.HOOK}
    else:
        other_placement = {"comm_type": CommType.BEFORE, "arg_index": 1}
    communication_action_mapping["other"] = self.get_communication_action(
        sharding_spec_mapping["other"],
        communication_pattern=grad_pattern,
        logical_process_axis=mesh_dim_0,
        **other_placement)

    if self.has_bias:
        # bias: same rule as the weight, but a non-parameter bias is addressed by its keyword
        if self.is_param("bias"):
            bias_placement = {"comm_type": CommType.HOOK}
        else:
            bias_placement = {"comm_type": CommType.BEFORE, "key_for_kwarg": 'bias'}
        communication_action_mapping["bias"] = self.get_communication_action(
            sharding_spec_mapping["bias"],
            communication_pattern=grad_pattern,
            logical_process_axis=mesh_dim_0,
            **bias_placement)

    return self.get_sharding_strategy(name=name,
                                      sharding_spec_mapping=sharding_spec_mapping,
                                      communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
    '''
    Build the strategy which shards dim 1 of the input over mesh_dim_0 and the weight
    ("other") on dim 0 over mesh_dim_0 and on dim 1 over mesh_dim_1; the output is
    sharded on dim 1 over mesh_dim_1.
    '''
    name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'

    dim_partition_dict_mapping = {
        "input": {1: [mesh_dim_0]},
        "other": {0: [mesh_dim_0], 1: [mesh_dim_1]},
        "output": {1: [mesh_dim_1]},
    }
    if self.has_bias:
        dim_partition_dict_mapping["bias"] = {0: [mesh_dim_1]}
    sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

    communication_action_mapping = {}

    # output: ALLREDUCE_FWD_IDENTITY_BWD action over mesh_dim_0, applied after the op
    communication_action_mapping["output"] = self.get_communication_action(
        sharding_spec_mapping["output"],
        communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
        logical_process_axis=mesh_dim_0,
        comm_type=CommType.AFTER)

    # input: IDENTITY_FWD_ALLREDUCE_BWD action over mesh_dim_1, applied before the op on positional argument 0
    communication_action_mapping["input"] = self.get_communication_action(
        sharding_spec_mapping["input"],
        communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
        logical_process_axis=mesh_dim_1,
        comm_type=CommType.BEFORE,
        arg_index=0)

    return self.get_sharding_strategy(name=name,
                                      sharding_spec_mapping=sharding_spec_mapping,
                                      communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
    '''
    Build the strategy which shards dim 1 of the input and dim 0 of the weight ("other")
    over mesh_dim_0; the output and the bias stay replicated.
    '''
    name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'

    dim_partition_dict_mapping = {
        "input": {1: [mesh_dim_0]},
        "other": {0: [mesh_dim_0]},
        "output": {},
    }
    if self.has_bias:
        dim_partition_dict_mapping["bias"] = {}
    sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

    # output: ALLREDUCE_FWD_IDENTITY_BWD action over mesh_dim_0, applied after the op
    communication_action_mapping = {
        "output":
            self.get_communication_action(sharding_spec_mapping["output"],
                                          communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
                                          logical_process_axis=mesh_dim_0,
                                          comm_type=CommType.AFTER)
    }

    return self.get_sharding_strategy(name=name,
                                      sharding_spec_mapping=sharding_spec_mapping,
                                      communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_weight_out_channel(self, mesh_dim_0):
    """
    Build the strategy `RS{mesh_dim_0} = RR x RS{mesh_dim_0}`.

    Dim 1 of "other", dim 1 of "output" and (when present) dim 0 of "bias" are partitioned
    over `mesh_dim_0`; "input" has no partitioned dimension.

    Args:
        mesh_dim_0: logical mesh axis carrying the partition.

    Returns:
        The result of `self.get_sharding_strategy` for this layout.
    """
    partition = {
        "input": {},
        "other": {1: [mesh_dim_0]},
        "output": {1: [mesh_dim_0]},
    }
    if self.has_bias:
        partition["bias"] = {0: [mesh_dim_0]}
    specs = self.to_sharding_spec_mapping(partition)

    # input (positional argument 0): IDENTITY_FWD_ALLREDUCE_BWD over mesh_dim_0, applied before the node
    sync_input = self.get_communication_action(
        specs["input"],
        communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
        logical_process_axis=mesh_dim_0,
        comm_type=CommType.BEFORE,
        arg_index=0)

    return self.get_sharding_strategy(
        name=f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}',
        sharding_spec_mapping=specs,
        communication_action_mapping={"input": sync_input})
@ignore_sharding_exception
def non_split(self):
    """
    Build the strategy `RR = RR x RR`: no dimension of any operand is partitioned and no
    communication action is attached.

    Returns:
        The result of `self.get_sharding_strategy` for this layout.
    """
    partition = {key: {} for key in ("input", "other", "output")}
    if self.has_bias:
        partition["bias"] = {}

    return self.get_sharding_strategy(
        name='RR = RR x RR',
        sharding_spec_mapping=self.to_sharding_spec_mapping(partition),
        communication_action_mapping={})
@ignore_sharding_exception
def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
    """
    Build the strategy `S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR`.

    Dim 0 of "input" and "output" is partitioned over both mesh axes at once; "other" and
    (when present) "bias" have no partitioned dimension and each receive an
    IDENTITY_FWD_ALLREDUCE_BWD action over both axes.

    Args:
        mesh_dim_0: first logical mesh axis of the flattened partition.
        mesh_dim_1: second logical mesh axis of the flattened partition.

    Returns:
        The result of `self.get_sharding_strategy` for this layout.
    """
    partition = {
        "input": {0: [mesh_dim_0, mesh_dim_1]},
        "other": {},
        "output": {0: [mesh_dim_0, mesh_dim_1]},
    }
    if self.has_bias:
        partition["bias"] = {}
    specs = self.to_sharding_spec_mapping(partition)

    comm_actions = {}

    # parameters use a HOOK action; non-parameter operands get the action BEFORE the node,
    # addressed as positional argument 1
    if self.is_param("other"):
        other_placement = {"comm_type": CommType.HOOK}
    else:
        other_placement = {"comm_type": CommType.BEFORE, "arg_index": 1}
    comm_actions["other"] = self.get_communication_action(
        specs["other"],
        communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
        logical_process_axis=[mesh_dim_0, mesh_dim_1],
        **other_placement)

    if self.has_bias:
        # same rule for the bias, except a non-parameter bias is addressed by keyword 'bias'
        if self.is_param("bias"):
            bias_placement = {"comm_type": CommType.HOOK}
        else:
            bias_placement = {"comm_type": CommType.BEFORE, "key_for_kwarg": 'bias'}
        comm_actions["bias"] = self.get_communication_action(
            specs["bias"],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=[mesh_dim_0, mesh_dim_1],
            **bias_placement)

    return self.get_sharding_strategy(
        name=f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR',
        sharding_spec_mapping=specs,
        communication_action_mapping=comm_actions)
@ignore_sharding_exception
def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
    """
    Build the strategy `RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R`.

    Dim 1 of "input" and dim 0 of "other" are partitioned over both mesh axes at once;
    "output" and (when present) "bias" have no partitioned dimension.

    Args:
        mesh_dim_0: first logical mesh axis of the flattened partition.
        mesh_dim_1: second logical mesh axis of the flattened partition.

    Returns:
        The result of `self.get_sharding_strategy` for this layout.
    """
    partition = {
        "input": {1: [mesh_dim_0, mesh_dim_1]},
        "other": {0: [mesh_dim_0, mesh_dim_1]},
        "output": {},
    }
    if self.has_bias:
        partition["bias"] = {}
    specs = self.to_sharding_spec_mapping(partition)

    # output: ALLREDUCE_FWD_IDENTITY_BWD over both mesh axes, applied after the node
    reduce_output = self.get_communication_action(
        specs["output"],
        communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
        logical_process_axis=[mesh_dim_0, mesh_dim_1],
        comm_type=CommType.AFTER)

    return self.get_sharding_strategy(
        name=f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R',
        sharding_spec_mapping=specs,
        communication_action_mapping={"output": reduce_output})
@ignore_sharding_exception
def split_1d_parallel_on_out_channel(self, mesh_dim_0, mesh_dim_1):
    """
    Build the strategy `RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}`.

    Dim 1 of "other", dim 1 of "output" and (when present) dim 0 of "bias" are partitioned
    over both mesh axes at once; "input" has no partitioned dimension.

    Args:
        mesh_dim_0: first logical mesh axis of the flattened partition.
        mesh_dim_1: second logical mesh axis of the flattened partition.

    Returns:
        The result of `self.get_sharding_strategy` for this layout.
    """
    partition = {
        "input": {},
        "other": {1: [mesh_dim_0, mesh_dim_1]},
        "output": {1: [mesh_dim_0, mesh_dim_1]},
    }
    if self.has_bias:
        partition["bias"] = {0: [mesh_dim_0, mesh_dim_1]}
    specs = self.to_sharding_spec_mapping(partition)

    # input (positional argument 0): IDENTITY_FWD_ALLREDUCE_BWD over both mesh axes, before the node
    sync_input = self.get_communication_action(
        specs["input"],
        communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
        logical_process_axis=[mesh_dim_0, mesh_dim_1],
        comm_type=CommType.BEFORE,
        arg_index=0)

    return self.get_sharding_strategy(
        name=f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}',
        sharding_spec_mapping=specs,
        communication_action_mapping={"input": sync_input})
def collate_strategies(self) -> List[ShardingStrategy]:
    """
    Collect every strategy this generator produces, in a fixed order.

    Two-axis strategies are generated for both (0, 1) and (1, 0) axis assignments,
    single-axis strategies for axis 0 and axis 1, and the flattened "01" strategies once.

    Returns:
        The list of values returned by the individual strategy builders.
    """
    return [
        # SS = SR x RS
        self.split_input_batch_weight_out_channel(0, 1),
        self.split_input_batch_weight_out_channel(1, 0),
        # SR = SR x RR
        self.split_input_batch(0),
        self.split_input_batch(1),
        # SR = SS x SR
        self.split_input_both_dim_weight_in_channel(0, 1),
        self.split_input_both_dim_weight_in_channel(1, 0),
        # RS = RS x SS
        self.split_input_in_channel_weight_both_channel(0, 1),
        self.split_input_in_channel_weight_both_channel(1, 0),
        # RR = RS x SR
        self.split_input_in_channel_weight_in_channel(0),
        self.split_input_in_channel_weight_in_channel(1),
        # RS = RR x RS
        self.split_weight_out_channel(0),
        self.split_weight_out_channel(1),
        # RR = RR x RR
        self.non_split(),
        # S01R = S01R x RR
        self.split_1d_parallel_on_input_batch(0, 1),
        # RR = RS01 x S01R
        self.split_1d_parallel_on_in_channel(0, 1),
        # RS01 = RR x RS01
        self.split_1d_parallel_on_out_channel(0, 1),
    ]
import copy
import operator
import warnings
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
class EmbeddingStrategyGenerator(StrategyGenerator):
    """
    Strategy generator for nn.Embedding / F.embedding.

    The operation data is laid out as `output = input x other`. Each strategy builder below
    describes which dimensions of "input", "other" and "output" are partitioned over which
    logical mesh axes, plus the communication actions attached to that layout.
    """

    def validate(self) -> bool:
        """Delegate validation to the parent generator."""
        return super().validate()

    def update_compute_cost(self, strategy: ShardingStrategy):
        '''
        Store the per-device computation cost of `strategy` in `strategy.compute_cost`.

        Note: the cost is currently estimated as if the lookup were a dense computation,
        so it may not be accurate.
        '''

        # TODO: estimate the embedding computation cost as sparse operation
        def sharded_shape(key):
            return strategy.sharding_specs[self.op_data[key]].get_sharded_shape_per_device()

        input_shape = sharded_shape('input')
        other_shape = sharded_shape('other')
        output_shape = sharded_shape('output')

        input_numel = reduce(operator.mul, input_shape)
        other_numel = reduce(operator.mul, other_shape)
        output_numel = reduce(operator.mul, output_shape)

        fwd_cost = input_numel * other_numel
        # true division: the backward cost (and hence the total) is a float
        bwd_activation_cost = other_numel * output_numel / output_shape[-1]
        bwd_weight_cost = input_numel * other_numel
        bwd_cost = bwd_weight_cost + bwd_activation_cost

        strategy.compute_cost = TrainCycleItem(fwd=fwd_cost, bwd=bwd_cost, total=fwd_cost + bwd_cost)

    def update_memory_cost(self, strategy: ShardingStrategy):
        """
        Store the per-device memory cost of `strategy` in `strategy.memory_cost`.

        Sizes come from `self._compute_size_in_bytes`; each operand is booked as parameter or
        activation memory depending on `self.is_param`.
        """
        fwd_sizes = {key: self._compute_size_in_bytes(strategy, key) for key in ('input', 'other', 'output')}
        # bwd_cost = input_grad + other_grad, the output is not counted in the backward pass
        bwd_sizes = {key: size for key, size in fwd_sizes.items() if key != "output"}

        def split_by_kind(sizes):
            activation = sum(size for key, size in sizes.items() if not self.is_param(key))
            parameter = sum(size for key, size in sizes.items() if self.is_param(key))
            return activation, parameter

        # fwd_cost = input + other + output
        fwd_activation, fwd_parameter = split_by_kind(fwd_sizes)
        bwd_activation, bwd_parameter = split_by_kind(bwd_sizes)

        strategy.memory_cost = TrainCycleItem(
            fwd=MemoryCost(activation=fwd_activation, parameter=fwd_parameter),
            bwd=MemoryCost(activation=bwd_activation, parameter=bwd_parameter),
            total=MemoryCost(activation=fwd_activation + bwd_activation,
                             parameter=fwd_parameter + bwd_parameter))

    @ignore_sharding_exception
    def non_split(self):
        """Build `RR = R x RR`: nothing is partitioned and no communication action is attached."""
        partition = {key: {} for key in ("input", "other", "output")}
        return self.get_sharding_strategy(
            name='RR = R x RR',
            sharding_spec_mapping=self.to_sharding_spec_mapping(partition),
            communication_action_mapping={})

    @ignore_sharding_exception
    def split_input(self, mesh_dim_0):
        """
        Build `S{mesh_dim_0}R = S{mesh_dim_0} x RR`: dim 0 of "input" and "output" is
        partitioned over `mesh_dim_0`, "other" is not partitioned.
        """
        partition = {
            "input": {0: [mesh_dim_0]},
            "other": {},
            "output": {0: [mesh_dim_0]},
        }
        specs = self.to_sharding_spec_mapping(partition)

        # parameters use a HOOK action; otherwise the action runs BEFORE the node on positional argument 1
        if self.is_param("other"):
            placement = {"comm_type": CommType.HOOK}
        else:
            placement = {"comm_type": CommType.BEFORE, "arg_index": 1}
        comm_actions = {
            "other":
                self.get_communication_action(
                    specs["other"],
                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                    logical_process_axis=mesh_dim_0,
                    **placement)
        }

        return self.get_sharding_strategy(
            name=f'S{mesh_dim_0}R = S{mesh_dim_0} x RR',
            sharding_spec_mapping=specs,
            communication_action_mapping=comm_actions)

    @ignore_sharding_exception
    def split_input_and_embedding_dim(self, mesh_dim_0, mesh_dim_1):
        """
        Build `S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0} x RS{mesh_dim_1}`: dim 0 of "input"
        and "output" is partitioned over `mesh_dim_0`, dim 1 of "other" and "output" over
        `mesh_dim_1`.
        """
        partition = {
            "input": {0: [mesh_dim_0]},
            "other": {1: [mesh_dim_1]},
            "output": {0: [mesh_dim_0], 1: [mesh_dim_1]},
        }
        specs = self.to_sharding_spec_mapping(partition)

        comm_actions = {}
        # input (positional argument 0): IDENTITY_FWD_ALLREDUCE_BWD over mesh_dim_1, before the node
        comm_actions["input"] = self.get_communication_action(
            specs["input"],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=mesh_dim_1,
            comm_type=CommType.BEFORE,
            arg_index=0)

        # other: same pattern over mesh_dim_0; HOOK for parameters, BEFORE on argument 1 otherwise
        if self.is_param("other"):
            placement = {"comm_type": CommType.HOOK}
        else:
            placement = {"comm_type": CommType.BEFORE, "arg_index": 1}
        comm_actions["other"] = self.get_communication_action(
            specs["other"],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=mesh_dim_0,
            **placement)

        return self.get_sharding_strategy(
            name=f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0} x RS{mesh_dim_1}',
            sharding_spec_mapping=specs,
            communication_action_mapping=comm_actions)

    @ignore_sharding_exception
    def split_1d_parallel_on_input(self, mesh_dim_0, mesh_dim_1):
        """
        Build `S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1} x RR`: dim 0 of "input"
        and "output" is partitioned over both mesh axes at once, "other" is not partitioned.
        """
        partition = {
            "input": {0: [mesh_dim_0, mesh_dim_1]},
            "other": {},
            "output": {0: [mesh_dim_0, mesh_dim_1]},
        }
        specs = self.to_sharding_spec_mapping(partition)

        # other: HOOK for parameters, BEFORE on positional argument 1 otherwise
        if self.is_param("other"):
            placement = {"comm_type": CommType.HOOK}
        else:
            placement = {"comm_type": CommType.BEFORE, "arg_index": 1}
        comm_actions = {
            "other":
                self.get_communication_action(
                    specs["other"],
                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                    logical_process_axis=[mesh_dim_0, mesh_dim_1],
                    **placement)
        }

        return self.get_sharding_strategy(
            name=f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1} x RR',
            sharding_spec_mapping=specs,
            communication_action_mapping=comm_actions)

    @ignore_sharding_exception
    def split_embedding_dim(self, mesh_dim_0):
        """
        Build `RS{mesh_dim_0} = R x RS{mesh_dim_0}`: dim 1 of "other" and "output" is
        partitioned over `mesh_dim_0`, "input" is not partitioned.
        """
        partition = {
            "input": {},
            "other": {1: [mesh_dim_0]},
            "output": {1: [mesh_dim_0]},
        }
        specs = self.to_sharding_spec_mapping(partition)

        # input (positional argument 0): IDENTITY_FWD_ALLREDUCE_BWD over mesh_dim_0, before the node
        sync_input = self.get_communication_action(
            specs["input"],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=mesh_dim_0,
            comm_type=CommType.BEFORE,
            arg_index=0)

        return self.get_sharding_strategy(
            name=f'RS{mesh_dim_0} = R x RS{mesh_dim_0}',
            sharding_spec_mapping=specs,
            communication_action_mapping={"input": sync_input})

    @ignore_sharding_exception
    def split_1d_parallel_on_embedding_dim(self, mesh_dim_0, mesh_dim_1):
        """
        Build `RS{mesh_dim_0}{mesh_dim_1} = R x RS{mesh_dim_0}{mesh_dim_1}`: dim 1 of "other"
        and "output" is partitioned over both mesh axes at once, "input" is not partitioned.
        """
        partition = {
            "input": {},
            "other": {1: [mesh_dim_0, mesh_dim_1]},
            "output": {1: [mesh_dim_0, mesh_dim_1]},
        }
        specs = self.to_sharding_spec_mapping(partition)

        # input (positional argument 0): IDENTITY_FWD_ALLREDUCE_BWD over both mesh axes, before the node
        sync_input = self.get_communication_action(
            specs["input"],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=[mesh_dim_0, mesh_dim_1],
            comm_type=CommType.BEFORE,
            arg_index=0)

        return self.get_sharding_strategy(
            name=f'RS{mesh_dim_0}{mesh_dim_1} = R x RS{mesh_dim_0}{mesh_dim_1}',
            sharding_spec_mapping=specs,
            communication_action_mapping={"input": sync_input})

    def collate_strategies(self) -> List[ShardingStrategy]:
        """Collect every strategy this generator produces, in a fixed order."""
        return [
            # RR= R x RR
            self.non_split(),
            # SR = S x RR
            self.split_input(0),
            self.split_input(1),
            # SS = S x RS
            self.split_input_and_embedding_dim(0, 1),
            self.split_input_and_embedding_dim(1, 0),
            # S01R = S01 x RR
            self.split_1d_parallel_on_input(0, 1),
            # RS = R x RS
            self.split_embedding_dim(0),
            self.split_embedding_dim(1),
            # RS01 = R x RS01
            self.split_1d_parallel_on_embedding_dim(0, 1),
        ]
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from .strategy_generator import StrategyGenerator
__all__ = ['GetattrGenerator']
class GetattrGenerator(StrategyGenerator):
    """
    GetattrGenerator produces the single strategy used for this node kind: an output
    with no partitioned dimension and no communication action.
    """

    def validate(self) -> bool:
        """Delegate validation to the parent generator."""
        return super().validate()

    def update_compute_cost(self, strategy: ShardingStrategy):
        """Assign a fixed compute cost (fwd=10, bwd=10, total=20) to `strategy`."""
        strategy.compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)

    def update_memory_cost(self, strategy: ShardingStrategy):
        '''
        Compute the memory cost per device with this specific strategy.

        Only the output is counted, as forward activation memory; backward and parameter
        costs are zero.
        '''
        output_sizes = {'output': self._compute_size_in_bytes(strategy, "output")}

        # fwd_cost = output
        fwd_activation = sum(output_sizes.values())

        strategy.memory_cost = TrainCycleItem(
            fwd=MemoryCost(activation=fwd_activation, parameter=0),
            bwd=MemoryCost(activation=0, parameter=0),
            total=MemoryCost(activation=fwd_activation, parameter=0))

    def collate_strategies(self) -> List[ShardingStrategy]:
        """Return a one-element list holding the 'Replica Attribute' strategy."""
        specs = self.to_sharding_spec_mapping({"output": {}})
        replica = self.get_sharding_strategy(name='Replica Attribute',
                                             sharding_spec_mapping=specs,
                                             communication_action_mapping={})
        return [replica]
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import FollowingStrategyGenerator
__all__ = ['GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator']
class GetItemStrategyGenerator(FollowingStrategyGenerator):
    """
    GetItemStrategyGenerator is the common base for strategies of operator.getitem.
    The operation data is defined as `output = input[other]`.

    There are mainly three use cases:
        1. args_0._meta_data: torch.Tensor, args_1._meta_data: int
        2. args_0._meta_data: torch.Tensor, args_1._meta_data: slice
        3. args_0._meta_data: Tuple[torch.Tensor], args_1._meta_data: int
    """

    def validate(self) -> bool:
        """Delegate validation to the parent generator."""
        return super().validate()

    def update_compute_cost(self, strategy: ShardingStrategy):
        """Assign a fixed compute cost (fwd=10, bwd=10, total=20) to `strategy`."""
        strategy.compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)

    def update_memory_cost(self, strategy: ShardingStrategy):
        '''
        Compute the memory cost per device with this specific strategy.

        Sizes come from `self._compute_size_in_bytes`; each operand is booked as parameter or
        activation memory depending on `self.is_param`.
        '''
        fwd_sizes = {key: self._compute_size_in_bytes(strategy, key) for key in ('input', 'output')}
        # bwd_cost = input_grad, the output is not counted in the backward pass
        bwd_sizes = {key: size for key, size in fwd_sizes.items() if key != "output"}

        def split_by_kind(sizes):
            activation = sum(size for key, size in sizes.items() if not self.is_param(key))
            parameter = sum(size for key, size in sizes.items() if self.is_param(key))
            return activation, parameter

        # fwd_cost = input + output
        fwd_activation, fwd_parameter = split_by_kind(fwd_sizes)
        bwd_activation, bwd_parameter = split_by_kind(bwd_sizes)

        strategy.memory_cost = TrainCycleItem(
            fwd=MemoryCost(activation=fwd_activation, parameter=fwd_parameter),
            bwd=MemoryCost(activation=bwd_activation, parameter=bwd_parameter),
            total=MemoryCost(activation=fwd_activation + bwd_activation,
                             parameter=fwd_parameter + bwd_parameter))
class TensorStrategyGenerator(GetItemStrategyGenerator):
    '''
    Deal with case 1 and 2 of operator.getitem, i.e. the input is a single tensor.

    One strategy is produced for every strategy of the predecessor node.
    '''

    def collate_strategies(self) -> List[ShardingStrategy]:
        '''
        Derive one getitem strategy from each strategy in `self.predecessor_node.strategies_vector`.

        The output partition is the input partition with the entry for dim 0 removed and every
        remaining dim index decreased by one. When dim 0 of the input is partitioned, a
        GATHER_FWD_SPLIT_BWD action on the mesh axes of that dim is attached to the input.
        Communication, compute and memory costs are updated on every strategy before returning.
        '''
        strategy_list = []
        for index, strategy in enumerate(self.predecessor_node.strategies_vector):
            dim_partition_dict_mapping = {}
            communication_action_mapping = {}
            # the input keeps the partition the predecessor strategy produced for it
            dim_partition_dict_for_input = strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict
            # work on a copy so the predecessor's dict is not modified
            dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input)
            gather_input = 0 in dim_partition_dict_for_input
            if gather_input:
                # mesh axes that partition dim 0; reused below as the axes of the gather action
                logical_process_axis = dim_partition_dict_for_output.pop(0)

            # NOTE(review): the source this was recovered from had lost its indentation; the
            # shift is reconstructed as unconditional (not nested under `if gather_input`) - confirm.
            # move every remaining partitioned dim down by one index
            shift_dim_partition_dict_for_output = {}
            for dim, mesh_dim_list in dim_partition_dict_for_output.items():
                shift_dim_partition_dict_for_output[dim - 1] = mesh_dim_list
            dim_partition_dict_for_output = shift_dim_partition_dict_for_output
            dim_partition_dict_mapping = {
                "input": dim_partition_dict_for_input,
                "output": dim_partition_dict_for_output,
            }
            sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
            if gather_input:
                # applied to positional argument 0 before the node runs
                input_communication_action = self.get_communication_action(
                    sharding_spec_mapping["input"],
                    communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
                    logical_process_axis=logical_process_axis,
                    comm_type=CommType.BEFORE,
                    arg_index=0)
                communication_action_mapping["input"] = input_communication_action

            # the predecessor strategy index is appended to the name
            name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}_{index}'

            strategy = self.get_sharding_strategy(name=name,
                                                  sharding_spec_mapping=sharding_spec_mapping,
                                                  communication_action_mapping=communication_action_mapping)

            strategy_list.append(strategy)

        # fill in the cost fields of every generated strategy
        for strategy in strategy_list:
            self.update_communication_cost(strategy)
            self.update_compute_cost(strategy)
            self.update_memory_cost(strategy)

        return strategy_list
class TensorTupleStrategyGenerator(GetItemStrategyGenerator):
    '''
    Deal with case 3 of operator.getitem, i.e. the input is a tuple of tensors.
    '''

    def collate_strategies(self) -> List[ShardingStrategy]:
        '''
        Derive one getitem strategy from each strategy in `self.predecessor_node.strategies_vector`.

        The output reuses the partition of the selected element of the input; the input keeps
        the whole collection of sharding specs from the predecessor. No communication action
        is attached.
        '''
        picked = self.op_data["index"].data
        collected = []

        for position, source_strategy in enumerate(self.predecessor_node.strategies_vector):
            # the sharding spec for input in this case is a tuple of ShardingSpec.
            input_specs = source_strategy.output_sharding_specs[self.op_data["input"]]

            specs = self.to_sharding_spec_mapping({"output": input_specs[picked].dim_partition_dict})
            specs["input"] = input_specs

            # every element contributes '<sharding_sequence>, ' (trailing separator included)
            listed = "".join(f'{spec.sharding_sequence}, ' for spec in input_specs)
            input_sharding_info = f"get the {picked} element from (" + listed + ")"

            collected.append(
                self.get_sharding_strategy(
                    name=f'{specs["output"].sharding_sequence} = {input_sharding_info}_{position}',
                    sharding_spec_mapping=specs,
                    communication_action_mapping={}))

        return collected
import copy
import operator
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
ignore_sharding_exception,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
__all__ = ['LayerNormGenerator']
class LayerNormGenerator(StrategyGenerator):
    """
    LayerNormGenerator generates strategies for the LayerNorm operation.
    The operation data is defined as `output = input x other + bias`.
    """

    def validate(self) -> bool:
        """Delegate validation to the parent generator."""
        return super().validate()

    def update_compute_cost(self, strategy: ShardingStrategy):
        '''
        Store the per-device computation size of `strategy` in `strategy.compute_cost`.

        Note: compute_cost need to be divided by TFLOPS, now it just shows the computation size.
        '''
        # TODO: compute_cost need to be divided by TFLOPS, now it just shows the computation size.
        # TODO: a constant coefficient need to be added.
        input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
        weight_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device()

        # in LayerNorm context, batch dimensions mean all the dimensions do not join the normalization.
        batch_shape = input_shape[:-len(weight_shape)]
        batch_numel = reduce(operator.mul, batch_shape, 1)
        kernel_numel = reduce(operator.mul, weight_shape, 1)

        # forward, backward-activation and backward-weight are each estimated as
        # batch_numel * kernel_numel
        per_pass_cost = batch_numel * kernel_numel
        fwd_cost = per_pass_cost
        bwd_cost = per_pass_cost + per_pass_cost

        if self.has_bias:
            # the bias term is estimated as the element count of the sharded weight shape
            bias_cost = reduce(operator.mul, weight_shape)
            fwd_cost += bias_cost
            bwd_cost += bias_cost

        strategy.compute_cost = TrainCycleItem(fwd=fwd_cost, bwd=bwd_cost, total=fwd_cost + bwd_cost)

    def update_memory_cost(self, strategy: ShardingStrategy):
        '''
        Compute the memory cost per device with this specific strategy.

        Sizes come from `self._compute_size_in_bytes`; each operand is booked as parameter or
        activation memory depending on `self.is_param`.
        '''
        fwd_sizes = {key: self._compute_size_in_bytes(strategy, key) for key in ('input', 'other', 'output')}
        if self.has_bias:
            fwd_sizes['bias'] = self._compute_size_in_bytes(strategy, "bias")
        # bwd_cost = input_grad + other_grad + bias_grad, the output is not counted
        bwd_sizes = {key: size for key, size in fwd_sizes.items() if key != "output"}

        def split_by_kind(sizes):
            activation = sum(size for key, size in sizes.items() if not self.is_param(key))
            parameter = sum(size for key, size in sizes.items() if self.is_param(key))
            return activation, parameter

        # fwd_cost = input + other + bias + output
        fwd_activation, fwd_parameter = split_by_kind(fwd_sizes)
        bwd_activation, bwd_parameter = split_by_kind(bwd_sizes)

        strategy.memory_cost = TrainCycleItem(
            fwd=MemoryCost(activation=fwd_activation, parameter=fwd_parameter),
            bwd=MemoryCost(activation=bwd_activation, parameter=bwd_parameter),
            total=MemoryCost(activation=fwd_activation + bwd_activation,
                             parameter=fwd_parameter + bwd_parameter))

    @ignore_sharding_exception
    def _generate_strategy_with_dim_partition(self, dim_partition):
        """
        Build the strategy in which "input" and "output" both use `dim_partition` while "other"
        and (when present) "bias" have no partitioned dimension.

        "other" and "bias" each get a HOOK action with the IDENTITY_FWD_ALLREDUCE_BWD pattern
        over every mesh axis that appears in `dim_partition`.
        """
        partition = {
            "input": dim_partition,
            "other": {},
            "output": dim_partition,
        }
        if self.has_bias:
            partition["bias"] = {}
        specs = self.to_sharding_spec_mapping(partition)

        name = f'{specs["output"].sharding_sequence} = {specs["input"].sharding_sequence} x {specs["other"].sharding_sequence}'

        used_mesh_dims = [mesh_dim for mesh_dims in dim_partition.values() for mesh_dim in mesh_dims]
        # if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
        process_axis = used_mesh_dims[0] if len(used_mesh_dims) == 1 else used_mesh_dims

        comm_actions = {}
        operands = ("other", "bias") if self.has_bias else ("other",)
        for operand in operands:
            comm_actions[operand] = self.get_communication_action(
                sharding_spec=specs[operand],
                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                logical_process_axis=process_axis,
                comm_type=CommType.HOOK)

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=specs,
                                          communication_action_mapping=comm_actions)

    def split_input_batch_single_mesh_dim(self, mesh_dim_0, batch_dimension_length):
        """
        Build one strategy per partition returned by
        `enumerate_all_possible_1d_sharding(mesh_dim_0, batch_dimension_length)`.
        """
        return [
            self._generate_strategy_with_dim_partition(candidate)
            for candidate in enumerate_all_possible_1d_sharding(mesh_dim_0, batch_dimension_length)
        ]

    def split_input_batch_both_mesh_dim(self, mesh_dim_0, mesh_dim_1, batch_dimension_length):
        """
        Build one strategy per partition returned by
        `enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, batch_dimension_length)`.
        """
        return [
            self._generate_strategy_with_dim_partition(candidate)
            for candidate in enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, batch_dimension_length)
        ]

    @ignore_sharding_exception
    def non_split(self):
        """Build `RR = RR x R`: nothing is partitioned and no communication action is attached."""
        partition = {key: {} for key in ("input", "other", "output")}
        if self.has_bias:
            partition["bias"] = {}
        return self.get_sharding_strategy(name='RR = RR x R',
                                          sharding_spec_mapping=self.to_sharding_spec_mapping(partition),
                                          communication_action_mapping={})

    def collate_strategies(self) -> List[ShardingStrategy]:
        '''
        Generate every possible strategies for a LayerNorm node and return them as a list.
        '''
        input_rank = len(self.op_data["input"].logical_shape)
        weight_rank = len(self.op_data["other"].logical_shape)
        # in LayerNorm context, batch dimensions mean all the dimensions do not join the normalization.
        batch_dimension_length = input_rank - weight_rank

        collected = []
        # SR = SR x R with single mesh dim on batch dimensions
        for mesh_dim in (0, 1):
            collected.extend(self.split_input_batch_single_mesh_dim(mesh_dim, batch_dimension_length))
        # SR = SR x R with both mesh dims on batch dimensions
        collected.extend(self.split_input_batch_both_mesh_dim(0, 1, batch_dimension_length))
        # RR = RR x R
        collected.append(self.non_split())
        return collected
import operator
from ast import arg
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
class MatMulStrategyGenerator(StrategyGenerator):
    """
    MatMulStrategyGenerator is the shared base covering all matrix multiplication cases.
    The operation data is defined as `output = input x other + bias`.
    """

    def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
        """
        Store the per-device memory cost of `strategy` in `strategy.memory_cost`.

        Nothing is returned explicitly; the result is written onto `strategy`.
        """
        sizes = {key: self._compute_size_in_bytes(strategy, key) for key in ('input', 'other', 'output')}
        if self.has_bias:
            sizes['bias'] = self._compute_size_in_bytes(strategy, "bias")

        # fwd_cost = input + other + bias + output, split by whether the operand is a parameter
        fwd_activation = sum(size for key, size in sizes.items() if not self.is_param(key))
        fwd_parameter = sum(size for key, size in sizes.items() if self.is_param(key))

        # bwd_cost = input + other + bias sizes, all booked as activation memory
        grad_keys = ['input', 'other', 'bias']
        bwd_activation = sum(size for key, size in sizes.items() if key in grad_keys)

        strategy.memory_cost = TrainCycleItem(
            fwd=MemoryCost(activation=fwd_activation, parameter=fwd_parameter),
            bwd=MemoryCost(activation=bwd_activation, parameter=0),
            total=MemoryCost(activation=fwd_activation + bwd_activation, parameter=fwd_parameter + 0))
class DotProductStrategyGenerator(MatMulStrategyGenerator):
    """
    Generate sharding strategies for the dot product of two 1D tensors.
    """

    def validate(self) -> bool:
        """Assert that both ``input`` and ``other`` are 1D tensors."""
        input_op_data = self.op_data['input']
        other_op_data = self.op_data['other']
        assert input_op_data.data.dim() == 1 and other_op_data.data.dim() == 1

    def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
        """
        Compute the per-device compute cost and record it on the strategy.

        The forward cost is the first dimension of the sharded input shape and the
        backward cost is twice the forward cost.

        Args:
            strategy: the strategy whose ``compute_cost`` attribute is filled in.

        Returns:
            The computed ``TrainCycleItem``, which is also stored in ``strategy.compute_cost``.
        """
        sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
        fwd_compute_cost = sharded_input_shape[0]
        bwd_compute_cost = fwd_compute_cost * 2
        compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
                                      bwd=bwd_compute_cost,
                                      total=fwd_compute_cost + bwd_compute_cost)
        # Store the cost on the strategy, in the same way update_memory_cost stores
        # the memory cost. The value is still returned so that callers which rely
        # on the return value keep working.
        strategy.compute_cost = compute_cost
        return compute_cost

    @ignore_sharding_exception
    def no_split(self):
        """Build the strategy ``R = R dot R`` without sharding or communication."""
        name = 'R = R dot R'
        dim_partition_dict = {"input": {}, "other": {}, "output": {}, 'bias': {}}
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
        communication_action_mapping = {}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    @ignore_sharding_exception
    def split_one_dim(self, mesh_dim):
        """
        Build the strategy ``R = S dot S``: both vectors are sharded along ``mesh_dim``.

        Args:
            mesh_dim: the device mesh axis along which dim 0 of both operands is split.
        """
        name = f'R = S{mesh_dim} dot S{mesh_dim}'

        # get sharding spec
        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "output": {}, "bias": {0: [mesh_dim]}}
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

        # get communication action: the output is all-reduced along mesh_dim after the op
        output_comm_action = self.get_communication_action(
            sharding_spec=sharding_spec_mapping['output'],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=mesh_dim,
            comm_type=CommType.AFTER)
        communication_action_mapping = {"output": output_comm_action}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def collate_strategies(self) -> List[ShardingStrategy]:
        """Collect the unsplit strategy and the split strategies for mesh dims 0 and 1."""
        strategy_list = []

        # do not split dimensions for dot product
        # R = R dot R
        strategy_list.append(self.no_split())

        # split two tensors in the same dimensions
        # S = S dot S
        strategy_list.append(self.split_one_dim(0))
        strategy_list.append(self.split_one_dim(1))

        return strategy_list
class MatVecStrategyGenerator(MatMulStrategyGenerator):
    """
    Generate sharding strategies for the matrix-vector product of a 2D and a 1D tensor.
    """

    def validate(self) -> bool:
        """Assert that ``input`` is a 2D tensor and ``other`` is a 1D tensor."""
        input_op_data = self.op_data['input']
        other_op_data = self.op_data['other']
        assert input_op_data.data.dim() == 2 and other_op_data.data.dim() == 1

    def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
        """
        Compute the per-device compute cost and record it on the strategy.

        The forward cost is the first dimension of the sharded input shape and the
        backward cost is twice the forward cost.

        Args:
            strategy: the strategy whose ``compute_cost`` attribute is filled in.

        Returns:
            The computed ``TrainCycleItem``, which is also stored in ``strategy.compute_cost``.
        """
        sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
        fwd_compute_cost = sharded_input_shape[0]
        bwd_compute_cost = fwd_compute_cost * 2
        compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
                                      bwd=bwd_compute_cost,
                                      total=fwd_compute_cost + bwd_compute_cost)
        # Store the cost on the strategy, in the same way update_memory_cost stores
        # the memory cost. The value is still returned so that callers which rely
        # on the return value keep working.
        strategy.compute_cost = compute_cost
        return compute_cost

    def _grad_sync_action(self, op_name, sharding_spec, mesh_dim, arg_index):
        """
        Build the IDENTITY_FWD_ALLREDUCE_BWD communication action for one operand.

        Operands for which ``self.is_param`` is true use ``CommType.HOOK``; all other
        operands use ``CommType.BEFORE`` together with their positional ``arg_index``.
        """
        if self.is_param(op_name):
            return self.get_communication_action(
                sharding_spec=sharding_spec,
                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                logical_process_axis=mesh_dim,
                comm_type=CommType.HOOK)
        return self.get_communication_action(
            sharding_spec=sharding_spec,
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=mesh_dim,
            comm_type=CommType.BEFORE,
            arg_index=arg_index)

    @ignore_sharding_exception
    def no_split(self):
        """Build the strategy ``R = R x R`` without sharding or communication."""
        name = "R = R x R"
        dim_partition_dict = {"input": {}, "other": {}, "output": {}}

        if self.has_bias:
            dim_partition_dict['bias'] = {}

        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping={})

    @ignore_sharding_exception
    def split_input_batch(self, mesh_dim):
        """
        Build the strategy ``SR = SR x R``: dim 0 of the input and of the output is
        sharded along ``mesh_dim`` while ``other`` (and ``bias``) stay replicated.

        Args:
            mesh_dim: the device mesh axis used for the split.
        """
        name = f'S{mesh_dim}R = S{mesh_dim}R x R'

        # get sharding spec
        dim_partition_dict = {
            "input": {
                0: [mesh_dim]
            },
            "other": {},
            "output": {
                0: [mesh_dim]
            },
        }

        if self.has_bias:
            dim_partition_dict['bias'] = {}
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

        # get communication action
        communication_action_mapping = {}
        communication_action_mapping['other'] = self._grad_sync_action('other',
                                                                       sharding_spec_mapping['other'],
                                                                       mesh_dim,
                                                                       arg_index=1)

        if self.has_bias:
            communication_action_mapping['bias'] = self._grad_sync_action('bias',
                                                                          sharding_spec_mapping['bias'],
                                                                          mesh_dim,
                                                                          arg_index=2)

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def collate_strategies(self) -> List[ShardingStrategy]:
        """Collect the unsplit strategy and the input-batch splits for mesh dims 0 and 1."""
        strategy_list = []

        # no split
        strategy_list.append(self.no_split())

        # split the batch dim for the first tensor only
        strategy_list.append(self.split_input_batch(0))
        strategy_list.append(self.split_input_batch(1))

        return strategy_list
class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
    """
    Generate sharding strategies for a linear projection ``output = input x other + bias``
    where ``other`` is a 2D matrix.

    ``linear_projection_type`` is either ``'linear'`` or ``'addmm'``. It only affects how
    the bias is sharded: a linear bias has one dimension, while an addmm bias logically
    has the same dimensions as the output.
    """

    def __init__(self, operation_data_mapping, device_mesh, linear_projection_type='linear'):
        super().__init__(operation_data_mapping, device_mesh)
        self.linear_projection_type = linear_projection_type

    def _select_bias_partition(self, linear_partition, addmm_partition):
        """
        Return the bias partition dict matching ``self.linear_projection_type``.

        Raises:
            ValueError: if the projection type is neither 'linear' nor 'addmm'.
        """
        if self.linear_projection_type == 'linear':
            return linear_partition
        if self.linear_projection_type == 'addmm':
            return addmm_partition
        raise ValueError(f'Unsupported linear projection type: {self.linear_projection_type!r}')

    def _grad_sync_action(self, op_name, sharding_spec, logical_process_axis, **before_kwargs):
        """
        Build the IDENTITY_FWD_ALLREDUCE_BWD communication action for one operand.

        Operands for which ``self.is_param`` is true use ``CommType.HOOK``; all other
        operands use ``CommType.BEFORE`` and ``before_kwargs`` (``arg_index`` or
        ``key_for_kwarg``) tells where the operand sits in the call.
        """
        if self.is_param(op_name):
            return self.get_communication_action(
                sharding_spec=sharding_spec,
                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                logical_process_axis=logical_process_axis,
                comm_type=CommType.HOOK)
        return self.get_communication_action(
            sharding_spec=sharding_spec,
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=logical_process_axis,
            comm_type=CommType.BEFORE,
            **before_kwargs)

    def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
        """
        Compute the per-device compute cost and store it in ``strategy.compute_cost``.

        For C = AB with C: [M, N], A: [M, P], B: [P, N] the forward cost is M * N * P
        (only multiplications are counted) and the backward cost is twice the forward cost.
        All leading dimensions of the input are folded into M.
        """
        sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
        sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device()
        # the initial value 1 keeps the product defined when the input is 1D
        # and therefore has no leading dimensions
        dim_m_val = reduce(operator.mul, sharded_input_shape[:-1], 1)
        dim_n_val = sharded_other_shape[-1]
        dim_p_val = sharded_other_shape[0]

        fwd_compute_cost = dim_m_val * dim_n_val * dim_p_val
        bwd_compute_cost = fwd_compute_cost * 2
        compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
                                      bwd=bwd_compute_cost,
                                      total=fwd_compute_cost + bwd_compute_cost)
        strategy.compute_cost = compute_cost

    def collate_strategies(self) -> List[ShardingStrategy]:
        """Collect every candidate strategy of the linear projection."""
        strategies = []

        # SS = SR x RS
        strategies.append(self.split_lhs_space_rhs_space(0, 1))
        strategies.append(self.split_lhs_space_rhs_space(1, 0))

        # SR = SS x SR
        strategies.append(self.split_lhs_space_both_contract(0, 1))
        strategies.append(self.split_lhs_space_both_contract(1, 0))

        # RS = RS x SS
        strategies.append(self.split_rhs_space_both_contract(0, 1))
        strategies.append(self.split_rhs_space_both_contract(1, 0))

        # RR= RS x SR
        strategies.append(self.recompute_split_both_contract(0))
        strategies.append(self.recompute_split_both_contract(1))

        # RS = RR x RS
        strategies.append(self.split_rhs_space_only(0))
        strategies.append(self.split_rhs_space_only(1))

        # S01R = S01R x RR
        strategies.append(self.split_lhs_1st_dim_1d(0, 1))

        # RR = RS01 x S01R
        strategies.append(self.split_lhs_2nd_dim_1d(0, 1))

        # RS01 = RR x RS01
        strategies.append(self.split_rhs_2nd_dim_1d(0, 1))

        # RR = RR x RR
        strategies.append(self.non_split())

        return strategies

    @ignore_sharding_exception
    def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
        """
        Handle the case ``SS = SR x RS``: dim 0 of the input is split along
        ``mesh_dim_0`` and the last dim of ``other`` along ``mesh_dim_1``.
        """
        name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
        dim_partition_dict_mapping = {
            "input": {
                0: [mesh_dim_0]
            },
            "other": {
                -1: [mesh_dim_1]
            },
            "output": {
                0: [mesh_dim_0],
                -1: [mesh_dim_1]
            },
        }

        # linear bias only has one dimension, but addmm bias has same dimensions
        # as the output logically.
        dim_partition_dict_mapping['bias'] = self._select_bias_partition(
            linear_partition={-1: [mesh_dim_1]},
            addmm_partition={
                0: [mesh_dim_0],
                -1: [mesh_dim_1]
            })

        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # set communication action
        communication_action_mapping = {}
        input_comm_action = self.get_communication_action(
            sharding_spec=sharding_spec_mapping["input"],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=mesh_dim_1,
            comm_type=CommType.BEFORE,
            arg_index=0)
        other_comm_action = self._grad_sync_action('other', sharding_spec_mapping["other"], mesh_dim_0, arg_index=1)

        communication_action_mapping['input'] = input_comm_action
        communication_action_mapping['other'] = other_comm_action

        # we only add allreduce comm action for linear bias, because
        # allreduce comm action for addmm bias will be considered in post processing
        if self.has_bias and self.linear_projection_type == 'linear':
            communication_action_mapping['bias'] = self._grad_sync_action('bias',
                                                                          sharding_spec_mapping["bias"],
                                                                          mesh_dim_0,
                                                                          key_for_kwarg='bias')

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    @ignore_sharding_exception
    def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
        """
        Handle the case ``SR = SS x SR``: dim 0 of the input is split along
        ``mesh_dim_0`` and the contracted dim of both operands along ``mesh_dim_1``.
        """
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'

        # linear bias only has one dimension, but addmm bias has same dimensions
        # as the output logically.
        bias_partition = self._select_bias_partition(linear_partition={}, addmm_partition={0: [mesh_dim_0]})

        # get sharding spec mapping
        dim_partition_dict_mapping = {
            "input": {
                0: [mesh_dim_0],
                -1: [mesh_dim_1]
            },
            "other": {
                0: [mesh_dim_1]
            },
            "bias": bias_partition,
            "output": {
                0: [mesh_dim_0]
            },
        }

        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # get communication action mapping
        communication_action_mapping = {}

        # the partial results along the contracted dim are all-reduced after the op
        output_comm_action = self.get_communication_action(
            sharding_spec=sharding_spec_mapping["output"],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=mesh_dim_1,
            comm_type=CommType.AFTER)
        other_comm_action = self._grad_sync_action('other', sharding_spec_mapping["other"], mesh_dim_0, arg_index=1)

        communication_action_mapping['other'] = other_comm_action
        communication_action_mapping['output'] = output_comm_action

        # we only add allreduce comm action for linear bias, because
        # allreduce comm action for addmm bias will be considered in post processing
        if self.has_bias and self.linear_projection_type == 'linear':
            communication_action_mapping['bias'] = self._grad_sync_action('bias',
                                                                          sharding_spec_mapping["bias"],
                                                                          mesh_dim_0,
                                                                          key_for_kwarg='bias')

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    @ignore_sharding_exception
    def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
        """
        Handle the case ``RS = RS x SS``: the contracted dim of both operands is split
        along ``mesh_dim_0`` and the last dim of ``other`` along ``mesh_dim_1``.
        """
        name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'

        # get sharding specs
        # the bias uses the same partition dict as the output, so the projection
        # type needs no special handling here
        dim_partition_dict_mapping = {
            "input": {
                -1: [mesh_dim_0]
            },
            "other": {
                0: [mesh_dim_0],
                -1: [mesh_dim_1]
            },
            "bias": {
                -1: [mesh_dim_1]
            },
            "output": {
                -1: [mesh_dim_1]
            },
        }

        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # get communication actions
        communication_action_mapping = {}
        output_comm_action = self.get_communication_action(
            sharding_spec=sharding_spec_mapping['output'],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=mesh_dim_0,
            comm_type=CommType.AFTER)
        input_comm_action = self.get_communication_action(
            sharding_spec=sharding_spec_mapping['input'],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=mesh_dim_1,
            comm_type=CommType.BEFORE,
            arg_index=0)

        communication_action_mapping["input"] = input_comm_action
        communication_action_mapping['output'] = output_comm_action
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    @ignore_sharding_exception
    def recompute_split_both_contract(self, mesh_dim):
        """
        Handle the case ``RR = RS x SR``: only the contracted dim of both operands
        is split along ``mesh_dim``.
        """
        name = f'RR = RS{mesh_dim} x S{mesh_dim}R'

        # get sharding spec
        # the bias uses the same (empty) partition dict as the output
        dim_partition_dict_mapping = {
            "input": {
                -1: [mesh_dim]
            },
            "other": {
                0: [mesh_dim]
            },
            "bias": {},
            "output": {},
        }

        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # get communication action
        communication_action_mapping = {}
        output_comm_action = self.get_communication_action(
            sharding_spec=sharding_spec_mapping['output'],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=mesh_dim,
            comm_type=CommType.AFTER)

        communication_action_mapping['output'] = output_comm_action
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    @ignore_sharding_exception
    def split_rhs_space_only(self, mesh_dim):
        """
        Handle the case ``RS = RR x RS``: only the last dim of ``other`` (and hence of
        the bias and the output) is split along ``mesh_dim``.
        """
        name = f'RS{mesh_dim} = RR x RS{mesh_dim}'

        # get sharding spec
        # the bias uses the same partition dict as the output
        dim_partition_dict_mapping = {
            "input": {},
            "other": {
                -1: [mesh_dim]
            },
            "bias": {
                -1: [mesh_dim]
            },
            "output": {
                -1: [mesh_dim]
            },
        }

        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # get communication actions
        communication_action_mapping = {}
        input_comm_action = self.get_communication_action(
            sharding_spec=sharding_spec_mapping['input'],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=mesh_dim,
            comm_type=CommType.BEFORE,
            arg_index=0)

        communication_action_mapping['input'] = input_comm_action
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    @ignore_sharding_exception
    def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
        """
        Handle the case ``S01R = S01R x RR``: dim 0 of the input is split along both
        mesh dims at once.
        """
        name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'

        # linear bias only has one dimension, but addmm bias has same dimensions
        # as the output logically.
        bias_partition = self._select_bias_partition(linear_partition={},
                                                     addmm_partition={0: [mesh_dim_0, mesh_dim_1]})

        # get sharding spec
        dim_partition_dict_mapping = {
            "input": {
                0: [mesh_dim_0, mesh_dim_1]
            },
            "other": {},
            "bias": bias_partition,
            "output": {
                0: [mesh_dim_0, mesh_dim_1]
            },
        }

        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # get communication action
        communication_action_mapping = {}
        communication_action_mapping['other'] = self._grad_sync_action('other',
                                                                       sharding_spec_mapping['other'],
                                                                       [mesh_dim_0, mesh_dim_1],
                                                                       arg_index=1)

        # we only add allreduce comm action for linear bias, because
        # allreduce comm action for addmm bias will be considered in post processing
        if self.has_bias and self.linear_projection_type == 'linear':
            communication_action_mapping['bias'] = self._grad_sync_action('bias',
                                                                          sharding_spec_mapping['bias'],
                                                                          [mesh_dim_0, mesh_dim_1],
                                                                          key_for_kwarg='bias')

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    @ignore_sharding_exception
    def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
        """
        Handle the case ``RR = RS01 x S01R``: the contracted dim of both operands is
        split along both mesh dims at once.
        """
        name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'

        # get sharding spec
        # the bias uses the same (empty) partition dict as the output
        dim_partition_dict_mapping = {
            "input": {
                -1: [mesh_dim_0, mesh_dim_1]
            },
            "other": {
                0: [mesh_dim_0, mesh_dim_1]
            },
            "bias": {},
            "output": {},
        }

        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # get communication action
        communication_action_mapping = {}
        output_comm_action = self.get_communication_action(
            sharding_spec=sharding_spec_mapping['output'],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=[mesh_dim_0, mesh_dim_1],
            comm_type=CommType.AFTER)
        communication_action_mapping['output'] = output_comm_action

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    @ignore_sharding_exception
    def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
        """
        Handle the case ``RS01 = RR x RS01``: the last dim of ``other`` (and hence of
        the bias and the output) is split along both mesh dims at once.
        """
        name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'

        # get sharding spec
        # the bias uses the same partition dict as the output
        dim_partition_dict_mapping = {
            "input": {},
            "other": {
                -1: [mesh_dim_0, mesh_dim_1]
            },
            "bias": {
                -1: [mesh_dim_0, mesh_dim_1]
            },
            "output": {
                -1: [mesh_dim_0, mesh_dim_1]
            },
        }

        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # get communication action
        communication_action_mapping = {}
        input_comm_action = self.get_communication_action(
            sharding_spec=sharding_spec_mapping['input'],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=[mesh_dim_0, mesh_dim_1],
            comm_type=CommType.BEFORE,
            arg_index=0)
        communication_action_mapping['input'] = input_comm_action

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    @ignore_sharding_exception
    def non_split(self):
        """Handle the case ``RR = RR x RR``: nothing is split and nothing is communicated."""
        name = 'RR = RR x RR'

        # get sharding spec
        dim_partition_dict_mapping = {
            "input": {},
            "other": {},
            "bias": {},
            "output": {},
        }

        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # get communication action
        communication_action_mapping = {}

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def validate(self) -> bool:
        """
        Assert that the operation data fits a linear projection: ``input`` and ``other``
        exist, ``other`` is 2D, the contracted dimensions agree and, when a bias is
        present, its last dimension equals the last dimension of ``other``.
        """
        assert "input" in self.op_data
        assert "other" in self.op_data

        # make sure the other has 2 dim
        input_data = self.op_data['input']
        other_data = self.op_data['other']
        assert input_data.data.dim() > 0 and other_data.data.dim() == 2
        assert other_data.logical_shape[0] == input_data.logical_shape[-1]

        if self.has_bias:
            bias_data = self.op_data['bias']
            assert bias_data.logical_shape[-1] == other_data.logical_shape[-1]
class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
    """
    Generate sharding strategies for the batched matrix multiplication.

    A batched matrix multiplication can be viewed as
    [b, i, k] x [b, k, j] -> [b, i, j]

    The bias term is considered to have a 2D logical shape.

    Note: This class will be used to generate strategies for torch.bmm
    and torch.addbmm. However, the result of torch.addbmm is not correct,
    some extra runtime apply actions are required to keep numerical correctness.
    """

    # TODO: torch.addbmm correctness issue need to be fixed.
    def __init__(self, *args, **kwargs):
        # NOTE(review): the flag is assigned before super().__init__ runs, presumably
        # because the base initializer may call validate(), which sets it -- confirm.
        self.squeeze_batch_dim = False
        super().__init__(*args, **kwargs)

    def _pop_batch_dim_sharding_for_output(self, dim_partition_dict):
        """
        Drop the batch dimension from the output partition dict, in place.

        The entry for dim 0 is removed and every remaining dim index is decreased by 1.
        """
        output_partition = dim_partition_dict['output']
        shifted_partition = {dim - 1: mesh_dims for dim, mesh_dims in output_partition.items() if dim != 0}
        # mutate the existing dict so that the caller's mapping sees the change
        output_partition.clear()
        output_partition.update(shifted_partition)

    def _bias_comm_action(self, sharding_spec_mapping, logical_process_axis):
        """
        Build the IDENTITY_FWD_ALLREDUCE_BWD action for the bias.

        The action uses ``CommType.BEFORE`` with ``arg_index=0`` in every strategy.
        """
        return self.get_communication_action(
            sharding_spec=sharding_spec_mapping['bias'],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=logical_process_axis,
            comm_type=CommType.BEFORE,
            arg_index=0)

    def validate(self) -> bool:
        """
        Assert that at least one operand has a 3D logical shape and that the bias, if any,
        has fewer than 3 real dimensions and a 2D logical shape. When the output tensor
        is 2D, ``self.squeeze_batch_dim`` is switched on.
        """
        input_op_data = self.op_data['input']
        other_op_data = self.op_data['other']
        assert len(input_op_data.logical_shape) == 3 or len(other_op_data.logical_shape) == 3

        if 'bias' in self.op_data:
            bias_op_data = self.op_data['bias']
            assert bias_op_data.data.dim() < 3 and len(bias_op_data.logical_shape) == 2

        if self.op_data['output'].data.dim() == 2:
            # addbmm will shrink the first batch dim
            self.squeeze_batch_dim = True

    def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
        """
        Compute the compute cost and store it in ``strategy.compute_cost``.

        The forward cost is the last dim of the input times the number of output
        elements; the backward cost is twice the forward cost. The shapes are read from
        ``self.op_data`` and not from the sharding specs of ``strategy``.
        """
        fwd_compute_cost = self.op_data['input'].data.shape[-1] * reduce(operator.mul,
                                                                         self.op_data['output'].data.shape)
        bwd_compute_cost = fwd_compute_cost * 2
        compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
                                      bwd=bwd_compute_cost,
                                      total=fwd_compute_cost + bwd_compute_cost)
        strategy.compute_cost = compute_cost

    @ignore_sharding_exception
    def split_one_batch_dim(self, mesh_dim):
        """Build ``Sb = Sb x Sb``: the batch dim is split along a single mesh dim."""
        name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'

        # get sharding_spec
        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "bias": {}, "output": {0: [mesh_dim]}}
        if self.squeeze_batch_dim:
            self._pop_batch_dim_sharding_for_output(dim_partition_dict)
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

        # get communication actions
        communication_action_mapping = {}
        if self.has_bias:
            communication_action_mapping['bias'] = self._bias_comm_action(sharding_spec_mapping, mesh_dim)
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    @ignore_sharding_exception
    def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1):
        """Build ``Sb01 = Sb01 x Sb01``: the batch dim is split along both mesh dims."""
        name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}'
        dim_partition_dict = {
            "input": {
                0: [mesh_dim_0, mesh_dim_1]
            },
            "other": {
                0: [mesh_dim_0, mesh_dim_1]
            },
            "bias": {},
            "output": {
                0: [mesh_dim_0, mesh_dim_1]
            }
        }
        if self.squeeze_batch_dim:
            self._pop_batch_dim_sharding_for_output(dim_partition_dict)
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

        # get communication actions
        communication_action_mapping = {}
        if self.has_bias:
            communication_action_mapping['bias'] = self._bias_comm_action(sharding_spec_mapping,
                                                                          [mesh_dim_0, mesh_dim_1])

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    @ignore_sharding_exception
    def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1):
        """
        Build ``SbSi = SbSi x Sb``: the batch dim is split along ``mesh_dim_0`` and the
        i dim of the first tensor along ``mesh_dim_1``.
        """
        name = f'Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}'
        dim_partition_dict = {
            "input": {
                0: [mesh_dim_0],
                1: [mesh_dim_1]
            },
            "other": {
                0: [mesh_dim_0]
            },
            "bias": {
                0: [mesh_dim_1]
            },
            "output": {
                0: [mesh_dim_0],
                1: [mesh_dim_1]
            }
        }
        if self.squeeze_batch_dim:
            self._pop_batch_dim_sharding_for_output(dim_partition_dict)
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

        # get communication actions
        communication_action_mapping = {}
        other_comm_action = self.get_communication_action(
            sharding_spec=sharding_spec_mapping['other'],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=mesh_dim_1,
            comm_type=CommType.BEFORE,
            arg_index=1)
        communication_action_mapping['other'] = other_comm_action

        if self.has_bias:
            communication_action_mapping['bias'] = self._bias_comm_action(sharding_spec_mapping,
                                                                          [mesh_dim_0, mesh_dim_1])
            # for addbmm case, other is the third argument instead of second.
            communication_action_mapping['other'].arg_index += 1

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    @ignore_sharding_exception
    def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1):
        """
        Build ``SbSj = Sb x SbSj``: the batch dim is split along ``mesh_dim_0`` and the
        j dim of the second tensor along ``mesh_dim_1``.
        """
        name = f'Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}'
        dim_partition_dict = {
            "input": {
                0: [mesh_dim_0]
            },
            "other": {
                0: [mesh_dim_0],
                2: [mesh_dim_1]
            },
            "bias": {
                1: [mesh_dim_1]
            },
            "output": {
                0: [mesh_dim_0],
                2: [mesh_dim_1]
            }
        }
        if self.squeeze_batch_dim:
            self._pop_batch_dim_sharding_for_output(dim_partition_dict)
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

        # get communication actions
        communication_action_mapping = {}
        input_comm_action = self.get_communication_action(
            sharding_spec=sharding_spec_mapping['input'],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=mesh_dim_1,
            comm_type=CommType.BEFORE,
            arg_index=0)
        communication_action_mapping['input'] = input_comm_action

        if self.has_bias:
            # the bias action carries arg_index=0 like in every other strategy of this class
            communication_action_mapping['bias'] = self._bias_comm_action(sharding_spec_mapping, mesh_dim_0)
            # for addbmm case, input is the second argument instead of first.
            communication_action_mapping['input'].arg_index += 1

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    @ignore_sharding_exception
    def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1):
        """
        Build ``Sb = SbSk x SbSk``: the batch dim is split along ``mesh_dim_0`` and the
        contracted k dim of both tensors along ``mesh_dim_1``; the output is all-reduced
        along ``mesh_dim_1`` afterwards.
        """
        name = f'Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}'
        dim_partition_dict = {
            "input": {
                0: [mesh_dim_0],
                2: [mesh_dim_1]
            },
            "other": {
                0: [mesh_dim_0],
                1: [mesh_dim_1]
            },
            "bias": {},
            "output": {
                0: [mesh_dim_0],
            }
        }
        if self.squeeze_batch_dim:
            self._pop_batch_dim_sharding_for_output(dim_partition_dict)
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

        # get communication actions
        communication_action_mapping = {}
        output_comm_action = self.get_communication_action(
            sharding_spec=sharding_spec_mapping['output'],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=mesh_dim_1,
            comm_type=CommType.AFTER)
        communication_action_mapping['output'] = output_comm_action

        if self.has_bias:
            communication_action_mapping['bias'] = self._bias_comm_action(sharding_spec_mapping, mesh_dim_0)

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def collate_strategies(self) -> List[ShardingStrategy]:
        """
        Collect the candidate strategies.

        A device mesh counts as 1D unless its shape has exactly two entries and none of
        them is 1. A 1D mesh only yields the single batch-dim split; a 2D mesh yields the
        combined splits listed below.
        """
        strategy_list = []
        device_mesh_is_1d = True
        if len(self.device_mesh.mesh_shape) == 2 and 1 not in self.device_mesh.mesh_shape:
            device_mesh_is_1d = False

        if device_mesh_is_1d:
            # split only the batch dimension
            # Sb = Sb x Sb
            if len(self.device_mesh.mesh_shape) == 1:
                mesh_dim = 0
            else:
                # NOTE(review): index(1) is the position of the mesh axis whose size is 1,
                # not of the axis that holds more than one device -- confirm this is intended.
                mesh_dim = self.device_mesh.mesh_shape.index(1)
            strategy_list.append(self.split_one_batch_dim(mesh_dim))
        else:
            # for 2D device mesh
            # split batch dim of two inputs and the i dim of the first tensor
            # SbSi = SbSi x Sb
            strategy_list.append(self.split_batch_dim_lhs_space(0, 1))
            strategy_list.append(self.split_batch_dim_lhs_space(1, 0))

            # split batch dim of two inputs and the j of the second tensor
            # SbSj = Sb x SbSj
            strategy_list.append(self.split_batch_dim_rhs_space(0, 1))
            strategy_list.append(self.split_batch_dim_rhs_space(1, 0))

            # split batch dim of two inputs and the k dim of two inputs
            # Sb = SbSk x SbSk, need to all-reduce by k dim
            strategy_list.append(self.split_batch_dim_both_contract(0, 1))
            strategy_list.append(self.split_batch_dim_both_contract(1, 0))

            # split two batch dim
            strategy_list.append(self.split_two_batch_dim(0, 1))

        return strategy_list
import copy
import operator
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from colossalai.auto_parallel.tensor_shard.utils import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
ignore_sharding_exception,
)
from .strategy_generator import StrategyGenerator
class NormalPoolStrategyGenerator(StrategyGenerator):
    """
    NormalPoolStrategyGenerator is a generic class to generate strategies for pool operation like MaxPoolxd.
    The reason we call this normal pool is AvgPoolxd and MaxPoolxd are taking the kernel size element from image,
    and reduce them depending on the operation type.
    """

    def validate(self) -> bool:
        '''
        In sanity check, we need to make sure the input data has a correct dimension size.
        For Pool1d, the dim of input data should be 3([N, C, L]).
        For Pool2d, the dim of input data should be 4([N, C, H, W]).
        For Pool3d, the dim of input data should be 5([N, C, H, W, D]).

        Returns:
            bool: True when the input has 3, 4 or 5 dimensions.

        Raises:
            AssertionError: if the input has any other number of dimensions.
        '''
        input_op_data = self.op_data['input']
        assert input_op_data.data.dim() in (
            3, 4, 5), 'We suppose the dim of input fed into Pool op should in range of [3, 5].'
        # honour the declared ``-> bool`` contract instead of implicitly returning None
        return True

    def update_compute_cost(self, strategy: ShardingStrategy) -> None:
        '''
        Compute the computation cost per device with this specific strategy and store it
        in ``strategy.compute_cost``. Nothing is returned.

        Note: compute_cost need to be divided by TFLOPS, now it just shows the computation size.
        '''
        # TODO: compute_cost need to be divided by TFLOPS, now it just shows the computation size.
        # fwd: (number of elements of the sharded output) * (product of the kernel sizes)
        # bwd: (number of elements of the sharded input) * (product of the kernel sizes)
        sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
        sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()

        kernel_size = self.op_data["other"].data
        if isinstance(kernel_size, int):
            # an int kernel size is repeated once per dimension after the first two
            kernel_size = [kernel_size] * (len(sharded_output_shape) - 2)
        kernel_size_product = reduce(operator.mul, kernel_size)
        output_size_product = reduce(operator.mul, sharded_output_shape)
        input_size_product = reduce(operator.mul, sharded_input_shape)

        forward_compute_cost = output_size_product * kernel_size_product
        backward_compute_cost = input_size_product * kernel_size_product
        total_compute_cost = forward_compute_cost + backward_compute_cost

        strategy.compute_cost = TrainCycleItem(fwd=forward_compute_cost,
                                               bwd=backward_compute_cost,
                                               total=total_compute_cost)

    def update_memory_cost(self, strategy: ShardingStrategy) -> None:
        '''
        Compute the memory cost per device with this specific strategy and store it
        in ``strategy.memory_cost``. Nothing is returned.
        '''
        input_size = self._compute_size_in_bytes(strategy, "input")
        output_size = self._compute_size_in_bytes(strategy, "output")

        # compute fwd cost incurred
        # fwd_cost = input + output
        fwd_activation_cost = input_size + output_size
        fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=0)

        # compute bwd cost incurred
        # bwd_cost = input_grad
        bwd_activation_cost = input_size
        bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=0)

        # compute total cost
        total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, parameter=0)
        strategy.memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)

    @ignore_sharding_exception
    def _generate_strategy_with_dim_partition(self, dim_partition):
        '''
        Build one strategy in which the input and the output use the same dim partition.

        The strategy is named "<output sharding sequence> = <input sharding sequence>" and
        carries no communication action.
        '''
        dim_partition_dict_mapping = {"input": dim_partition, "output": dim_partition}
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'
        communication_action_mapping = {}

        strategy = self.get_sharding_strategy(name=name,
                                              sharding_spec_mapping=sharding_spec_mapping,
                                              communication_action_mapping=communication_action_mapping)
        return strategy

    def enumerate_all_possible_batch_dimensions_dim_partition(self, mesh_dim_0, mesh_dim_1):
        '''
        List the dim partitions to try: the 1D shardings for each of the two mesh axes, the
        2D shardings over both axes, and finally the empty (non-split) partition.

        NOTE(review): the literal ``2`` handed to the enumerate helpers presumably limits the
        sharding to the first two tensor dimensions -- confirm against the helpers.
        '''
        dim_partition_list = []
        dim_partition_list.extend(enumerate_all_possible_1d_sharding(mesh_dim_0, 2))
        dim_partition_list.extend(enumerate_all_possible_1d_sharding(mesh_dim_1, 2))
        dim_partition_list.extend(enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, 2))
        # append {} for non_split case
        dim_partition_list.append({})
        return dim_partition_list

    def collate_strategies(self) -> List[ShardingStrategy]:
        '''
        Generate one strategy for every dim partition enumerated over mesh axes 0 and 1.
        '''
        dim_partition_list = self.enumerate_all_possible_batch_dimensions_dim_partition(0, 1)
        return [self._generate_strategy_with_dim_partition(dim_partition) for dim_partition in dim_partition_list]
from typing import Dict, List
from torch.fx import Node
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.device.device_mesh import DeviceMesh
from .strategy_generator import OutputStrategyGenerator
__all__ = ['OutputGenerator']
class OutputGenerator(OutputStrategyGenerator):
    """
    OutputGenerator is a generic class to generate strategies for Output Node.
    """

    def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
                 predecessor_nodes: List[Node], output_option: str):
        super().__init__(operation_data_mapping, device_mesh, predecessor_nodes)
        # 'replicated' and 'distributed' are the two values handled by collate_strategies
        self.output_option = output_option

    def validate(self) -> bool:
        return super().validate()

    def update_compute_cost(self, strategy: ShardingStrategy):
        '''
        Attach a constant compute cost (fwd=10, bwd=10, total=20) to the strategy.
        '''
        compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
        strategy.compute_cost = compute_cost

    def update_memory_cost(self, strategy: ShardingStrategy):
        '''
        Compute the memory cost per device with this specific strategy.
        Every component (fwd, bwd, total) is set to zero activation and zero parameter cost.
        '''
        fwd_mem_cost = MemoryCost(activation=0, parameter=0)

        bwd_mem_cost = MemoryCost(activation=0, parameter=0)

        # compute total cost
        total_mem_cost = MemoryCost(activation=0, parameter=0)
        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
        strategy.memory_cost = memory_cost

    @staticmethod
    def _replicated_partition(data):
        '''
        Return a fresh, empty dim partition for ``data``: a list holding one empty dict per
        element when ``data`` is a tuple or list, a single empty dict otherwise.
        '''
        if isinstance(data, (tuple, list)):
            return [{} for _ in range(len(data))]
        return {}

    def replica_strategy(self) -> ShardingStrategy:
        """
        Generate replica strategy for output node.

        Every ``input_<index>`` entry and the ``output`` entry receive empty dim partitions.
        With a single predecessor the output partition mirrors that input's partition; with
        several predecessors it is a tuple holding one partition per predecessor.

        Returns:
            ShardingStrategy: the strategy named 'Replica Output'.
        """
        dim_partition_dict_mapping = {}
        dim_partition_dict_for_output = []

        for index, _ in enumerate(self.predecessor_nodes):
            mapping_name = f"input_{index}"
            data = self.op_data[mapping_name].data
            dim_partition_dict_mapping[mapping_name] = self._replicated_partition(data)
            # build a second, independent object for the output entry so that the input and
            # output partitions never share mutable state
            dim_partition_dict_for_output.append(self._replicated_partition(data))

        if len(dim_partition_dict_for_output) == 1:
            dim_partition_dict_for_output = dim_partition_dict_for_output[0]
        else:
            dim_partition_dict_for_output = tuple(dim_partition_dict_for_output)

        dim_partition_dict_mapping['output'] = dim_partition_dict_for_output

        communication_action_mapping = {}
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        name = 'Replica Output'

        strategy = self.get_sharding_strategy(name=name,
                                              sharding_spec_mapping=sharding_spec_mapping,
                                              communication_action_mapping=communication_action_mapping)
        return strategy

    def distributed_strategy(self, mesh_list: List[int] = None) -> ShardingStrategy:
        """
        Generate distributed strategy for output node.

        Dimension 0 of the output (of every element when the output data is a tuple) and of
        every ``input_<index>`` is mapped to ``mesh_list``.

        Args:
            mesh_list: the value assigned to dimension 0 in each dim partition;
                ``collate_strategies`` passes ``[0, 1]``.
                NOTE(review): the ``None`` default is forwarded unchanged into the
                partitions -- confirm whether that is ever a valid input.

        Returns:
            ShardingStrategy: the strategy named 'Distributed Output'.
        """
        # TODO: need to take care of the case when the first element of output only need to be sharded.
        output_op_data = self.op_data['output']
        # NOTE(review): only ``tuple`` is checked here while replica_strategy accepts
        # ``(tuple, list)`` -- confirm whether list outputs can reach this branch.
        if isinstance(output_op_data.data, tuple):
            length = len(output_op_data.data)
            # one dict per element; ``[{...}] * length`` would repeat a single shared dict
            output_partition = [{0: mesh_list} for _ in range(length)]
        else:
            output_partition = {0: mesh_list}

        dim_partition_dict_mapping = {"output": output_partition}
        for index, _ in enumerate(self.predecessor_nodes):
            mapping_name = f"input_{index}"
            dim_partition_dict_mapping[mapping_name] = {0: mesh_list}

        communication_action_mapping = {}
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        name = 'Distributed Output'

        strategy = self.get_sharding_strategy(name=name,
                                              sharding_spec_mapping=sharding_spec_mapping,
                                              communication_action_mapping=communication_action_mapping)
        return strategy

    def collate_strategies(self) -> List[ShardingStrategy]:
        """
        Return a one-element list with the strategy selected by ``self.output_option``.

        NOTE(review): any option other than 'replicated' or 'distributed' yields an empty
        list here, whereas PlaceholderGenerator asserts on unsupported options -- confirm
        which behaviour is wanted.
        """
        strategy_list = []
        mesh_list = [0, 1]
        if self.output_option == 'replicated':
            strategy_list.append(self.replica_strategy())
        elif self.output_option == 'distributed':
            strategy_list.append(self.distributed_strategy(mesh_list))

        return strategy_list
from typing import Dict, List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.device.device_mesh import DeviceMesh
from .strategy_generator import StrategyGenerator
__all__ = ['PlaceholderGenerator']
class PlaceholderGenerator(StrategyGenerator):
    """
    Strategy generator for placeholder nodes.

    Depending on ``placeholder_option`` it emits either a replicated strategy or one whose
    output dimension 0 is mapped to mesh axes ``[0, 1]``.
    """

    def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
                 placeholder_option: str):
        super().__init__(operation_data_mapping, device_mesh)
        self.placeholder_option = placeholder_option

    def validate(self) -> bool:
        # nothing placeholder-specific to check; defer to the parent class
        return super().validate()

    def update_compute_cost(self, strategy: ShardingStrategy):
        '''
        Attach a constant compute cost (fwd=10, bwd=10, total=20) to the strategy.
        '''
        strategy.compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)

    def update_memory_cost(self, strategy: ShardingStrategy):
        '''
        Store the per-device memory cost on ``strategy.memory_cost``.

        Only the output is counted: it is the forward activation cost, the backward cost is
        zero, so the total equals the forward cost.
        '''
        sizes_in_fwd = {'output': self._compute_size_in_bytes(strategy, "output")}

        # fwd_cost = output
        activation_in_fwd = sum(sizes_in_fwd.values())

        strategy.memory_cost = TrainCycleItem(
            fwd=MemoryCost(activation=activation_in_fwd, parameter=0),
            bwd=MemoryCost(activation=0, parameter=0),
            total=MemoryCost(activation=activation_in_fwd, parameter=0),
        )

    def replica_placeholder(self) -> ShardingStrategy:
        """
        Build the 'Replica Placeholder' strategy: the output gets an empty dim partition and
        no communication action is attached.
        """
        spec_mapping = self.to_sharding_spec_mapping({"output": {}})
        return self.get_sharding_strategy(name='Replica Placeholder',
                                          sharding_spec_mapping=spec_mapping,
                                          communication_action_mapping={})

    def distributed_placeholder(self, mesh_list) -> ShardingStrategy:
        """
        Build the 'Distributed Placeholder' strategy: dimension 0 of the output is mapped to
        ``mesh_list`` and no communication action is attached.
        """
        spec_mapping = self.to_sharding_spec_mapping({"output": {0: mesh_list}})
        return self.get_sharding_strategy(name='Distributed Placeholder',
                                          sharding_spec_mapping=spec_mapping,
                                          communication_action_mapping={})

    def collate_strategies(self) -> List[ShardingStrategy]:
        """
        Return a one-element list holding the strategy selected by ``placeholder_option``.

        Raises:
            AssertionError: if the option is neither 'distributed' nor 'replicated'.
        """
        if self.placeholder_option == 'distributed':
            return [self.distributed_placeholder([0, 1])]

        assert self.placeholder_option == 'replicated', f'placeholder_option {self.placeholder_option} is not supported'
        return [self.replica_placeholder()]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment