Commit 7bc5a8e3 authored by zhuwenwen's avatar zhuwenwen
Browse files
parents e6748d82 0f785cb1
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.logging import get_dist_logger
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpecException
from .strategy_generator import FollowingStrategyGenerator
__all__ = ['GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator']
class GetItemStrategyGenerator(FollowingStrategyGenerator):
"""
GetItemStrategyGenerator is a generic class to generate strategies for operator.getitem.
The operation data is defined as `output = input[other]`.
There are mainly three use cases:
1. args_0._meta_data: torch.Tensor, args_1._meta_data: int
2. args_0._meta_data: torch.Tensor, args_1._meta_data: slice
3. args_0._meta_data: Tuple[torch.Tensor], args_1._meta_data: int
"""
def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
Compute the memory cost per device with this specific strategy.
'''
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
# compute fwd cost incurred
# fwd_cost = input + output
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
# compute bwd cost incurred
# bwd_cost = input_grad
bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
class TensorStrategyGenerator(GetItemStrategyGenerator):
'''
Deal with case 1 and 2.
'''
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
getitem_index = self.op_data['index'].data
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
try:
logger = get_dist_logger()
dim_partition_dict_mapping = {}
communication_action_mapping = {}
dim_partition_dict_for_input = copy.deepcopy(
strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict)
int_index = False
if isinstance(getitem_index, int):
int_index = True
getitem_dims = [
0,
]
shift_length = 1
elif isinstance(getitem_index, slice):
getitem_dims = [
0,
]
else:
getitem_dims = [i for i in range(len(getitem_index))]
if isinstance(getitem_index[0], int):
int_index = True
shift_length = len(getitem_index)
gather_dims = []
for dim in getitem_dims:
if dim in dim_partition_dict_for_input:
gather_dims.append(dim)
for dim in gather_dims:
dim_partition_dict_for_input.pop(dim)
dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input)
if int_index:
shift_dim_partition_dict_for_output = {}
for dim, mesh_dim_list in dim_partition_dict_for_output.items():
shift_dim_partition_dict_for_output[dim - shift_length] = mesh_dim_list
dim_partition_dict_for_output = shift_dim_partition_dict_for_output
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
except ShardingSpecException as e:
logger.debug(e)
continue
strategy_list.append(strategy)
for strategy in strategy_list:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return strategy_list
class TensorTupleStrategyGenerator(GetItemStrategyGenerator):
'''
Deal with case 3.
'''
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
index = self.op_data["index"].data
for strategy_index, strategy in enumerate(self.predecessor_node.strategies_vector):
# the sharding spec for input in this case is a tuple of ShardingSpec.
sharding_spec_for_input = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_output = sharding_spec_for_input[index].dim_partition_dict
dim_partition_dict_mapping = {}
communication_action_mapping = {}
dim_partition_dict_mapping = {
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
sharding_spec_mapping["input"] = sharding_spec_for_input
input_sharding_info = f"get the {index} element from ("
for sharding_spec in sharding_spec_for_input:
input_sharding_info += f'{sharding_spec.sharding_sequence}, '
input_sharding_info += ")"
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {input_sharding_info}_{strategy_index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
import copy
import operator
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
ignore_sharding_exception,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
__all__ = ['LayerNormGenerator']
class LayerNormGenerator(StrategyGenerator):
"""
LayerNormGenerator is a generic class to generate strategies for LayerNorm operation.
The operation data is defined as `output = input x other + bias`.
"""
def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
'''
Compute the computation cost per device with this specific strategy.
Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
'''
# TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
# TODO: a constant coefficient need to be added.
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
sharded_weight_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device()
if self.has_bias:
# bias add is an element wise operation, so the cost is equal to product of output shape.
bias_compute_cost = reduce(operator.mul, sharded_weight_shape)
# in LayerNorm context, batch dimensions mean all the dimensions do not join the normalization.
input_batch_shape = sharded_input_shape[:-len(sharded_weight_shape)]
input_batch_product = reduce(operator.mul, input_batch_shape, 1)
norm_kernel_product = reduce(operator.mul, sharded_weight_shape, 1)
forward_compute_cost = input_batch_product * norm_kernel_product
backward_activation_compute_cost = input_batch_product * norm_kernel_product
# To compute gradient of on norm kernel element requires input_batch_product times computation, so
# the total cost is input_batch_product * norm_kernel_product
backward_weight_compute_cost = input_batch_product * norm_kernel_product
backward_compute_cost = backward_activation_compute_cost + backward_weight_compute_cost
if self.has_bias:
forward_compute_cost += bias_compute_cost
backward_compute_cost += bias_compute_cost
total_compute_cost = forward_compute_cost + backward_compute_cost
compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
Compute the memory cost per device with this specific strategy.
'''
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'other': self._compute_size_in_bytes(strategy, "other"),
'output': self._compute_size_in_bytes(strategy, "output")
}
if self.has_bias:
bias_size = self._compute_size_in_bytes(strategy, "bias")
forward_size_mapping['bias'] = bias_size
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
# compute fwd cost incurred
# fwd_cost = input + other + bias + output
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
# compute bwd cost incurred
# bwd_cost = input_grad + other_grad + bias_grad
bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@ignore_sharding_exception
def _generate_strategy_with_dim_partition(self, dim_partition):
dim_partition_dict_mapping = {
"input": dim_partition,
"other": {},
"output": dim_partition,
}
if self.has_bias:
dim_partition_dict_mapping["bias"] = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence} x {sharding_spec_mapping["other"].sharding_sequence}'
total_mesh_dim_list = []
for mesh_dim_list in dim_partition.values():
total_mesh_dim_list.extend(mesh_dim_list)
# if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
if len(total_mesh_dim_list) == 1:
total_mesh_dim_list = total_mesh_dim_list[0]
communication_action_mapping = {}
other_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=total_mesh_dim_list,
comm_type=CommType.HOOK)
communication_action_mapping["other"] = other_comm_action
if self.has_bias:
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=total_mesh_dim_list,
comm_type=CommType.HOOK)
communication_action_mapping["bias"] = bias_comm_action
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return strategy
def split_input_batch_single_mesh_dim(self, mesh_dim_0, batch_dimension_length):
strategy_list = []
dim_partition_list = enumerate_all_possible_1d_sharding(mesh_dim_0, batch_dimension_length)
for dim_partition in dim_partition_list:
strategy = self._generate_strategy_with_dim_partition(dim_partition)
strategy_list.append(strategy)
return strategy_list
def split_input_batch_both_mesh_dim(self, mesh_dim_0, mesh_dim_1, batch_dimension_length):
strategy_list = []
dim_partition_list = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, batch_dimension_length)
for dim_partition in dim_partition_list:
strategy = self._generate_strategy_with_dim_partition(dim_partition)
strategy_list.append(strategy)
return strategy_list
@ignore_sharding_exception
def non_split(self):
name = f'RR = RR x R'
dim_partition_dict_mapping = {
"input": {},
"other": {},
"output": {},
}
if self.has_bias:
dim_partition_dict_mapping["bias"] = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
communication_action_mapping = {}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def collate_strategies(self) -> List[ShardingStrategy]:
'''
Generate every possible strategies for a LayerNorm node, and record all strategies into the strategies_vector.
'''
strategy_list = []
input_data_dim = len(self.op_data["input"].logical_shape)
weight_data_dim = len(self.op_data["other"].logical_shape)
# in LayerNorm context, batch dimensions mean all the dimensions do not join the normalization.
batch_dimension_length = input_data_dim - weight_data_dim
# SR = SR x R with single mesh dim on batch dimensions
strategy_list.extend(self.split_input_batch_single_mesh_dim(0, batch_dimension_length))
strategy_list.extend(self.split_input_batch_single_mesh_dim(1, batch_dimension_length))
# SR = SR x R with both mesh dims on batch dimensions
strategy_list.extend(self.split_input_batch_both_mesh_dim(0, 1, batch_dimension_length))
# RR = RR x R
strategy_list.append(self.non_split())
return strategy_list
import operator
from ast import arg
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.options import SolverPerference
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
class MatMulStrategyGenerator(StrategyGenerator):
"""
MatMulStrategyGenerator is a generic class to cover all matrix multiplication cases.
The operation data is defined as `output = input x other + bias`.
"""
def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'other': self._compute_size_in_bytes(strategy, "other"),
'output': self._compute_size_in_bytes(strategy, "output")
}
if self.has_bias:
bias_size = self._compute_size_in_bytes(strategy, "bias")
size_mapping['bias'] = bias_size
# compute fwd cost incurred
# fwd_cost = input + other + bias + output
fwd_activation_cost = sum([v for k, v in size_mapping.items() if not self.is_param(k)])
fwd_parameter_cost = sum([v for k, v in size_mapping.items() if self.is_param(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
# compute bwd cost incurred
# bwd_cost = input_grad + bias_grad
bwd_activation_cost = sum([v for k, v in size_mapping.items() if k in ['input', 'other', 'bias']])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=0)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + 0)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
class DotProductStrategyGenerator(MatMulStrategyGenerator):
def validate(self) -> bool:
input_op_data = self.op_data['input']
other_op_data = self.op_data['other']
assert input_op_data.data.dim() == 1 and other_op_data.data.dim() == 1
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
fwd_compute_cost = sharded_input_shape[0]
bwd_compute_cost = fwd_compute_cost * 2
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
return compute_cost
@ignore_sharding_exception
def no_split(self):
name = f'R = R dot R'
dim_partition_dict = {"input": {}, "other": {}, "output": {}, 'bias': {}}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
communication_action_mapping = {}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_one_dim(self, mesh_dim):
name = f'R = S{mesh_dim} dot S{mesh_dim}'
# get sharding spec
dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "output": {}, "bias": {0: [mesh_dim]}}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
# get communication action
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['output'],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.AFTER)
communication_action_mapping = {"output": output_comm_action}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
# do not split dimensions for dot product
# R = R dot R
strategy_list.append(self.no_split())
# split two tensors in the same dimensions
# S = S dot S
strategy_list.append(self.split_one_dim(0))
strategy_list.append(self.split_one_dim(1))
return strategy_list
class MatVecStrategyGenerator(MatMulStrategyGenerator):
def validate(self) -> bool:
input_op_data = self.op_data['input']
other_op_data = self.op_data['other']
assert input_op_data.data.dim() == 2 and other_op_data.data.dim() == 1
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
fwd_compute_cost = sharded_input_shape[0]
bwd_compute_cost = fwd_compute_cost * 2
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
return compute_cost
@ignore_sharding_exception
def no_split(self):
name = "R = R x R"
dim_partition_dict = {"input": {}, "other": {}, "output": {}}
if self.has_bias:
dim_partition_dict['bias'] = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
@ignore_sharding_exception
def split_input_batch(self, mesh_dim):
name = f'S{mesh_dim}R = S{mesh_dim}R x R'
# get sharding spec
dim_partition_dict = {
"input": {
0: [mesh_dim]
},
"other": {},
"output": {
0: [mesh_dim]
},
}
if self.has_bias:
dim_partition_dict['bias'] = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
# get communication action
communication_action_mapping = {}
if self.is_param('other'):
other_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['other'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.HOOK)
else:
other_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['other'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping['other'] = other_comm_action
if self.has_bias:
if self.is_param('bias'):
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.HOOK)
else:
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
arg_index=2)
communication_action_mapping['bias'] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
# no split
strategy_list.append(self.no_split())
# split the batch dim for the first tensor only
strategy_list.append(self.split_input_batch(0))
strategy_list.append(self.split_input_batch(1))
return strategy_list
class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
def __init__(self,
operation_data_mapping,
device_mesh,
linear_projection_type='linear',
solver_perference=SolverPerference.STANDARD):
super().__init__(operation_data_mapping, device_mesh)
self.linear_projection_type = linear_projection_type
self.solver_perference = solver_perference
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
# C = AB
# C: [M, N], A: [M, P], B: [P, N]
# fwd cost = MNP (only count mul)
# bwd: 2 x fwd_cost
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device()
dim_m_val = reduce(operator.mul, sharded_input_shape[:-1])
dim_n_val = sharded_other_shape[-1]
dim_p_val = sharded_other_shape[0]
fwd_compute_cost = dim_m_val * dim_n_val * dim_p_val
bwd_compute_cost = fwd_compute_cost * 2
compute_cost = TrainCycleItem(fwd=bwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
strategy.compute_cost = compute_cost
def dp_strategies(self) -> List[ShardingStrategy]:
strategies = []
# S01R = S01R x RR
strategies.append(self.split_lhs_1st_dim_1d(0, 1))
return strategies
def tp_strategies(self) -> List[ShardingStrategy]:
strategies = []
# RR = RS01 x S01R
strategies.append(self.split_lhs_2nd_dim_1d(0, 1))
# RS01 = RR x RS01
strategies.append(self.split_rhs_2nd_dim_1d(0, 1))
# RS = RS x SS
strategies.append(self.split_rhs_space_both_contract(0, 1))
strategies.append(self.split_rhs_space_both_contract(1, 0))
# RR= RS x SR
strategies.append(self.recompute_split_both_contract(0))
strategies.append(self.recompute_split_both_contract(1))
# RS = RR x RS
strategies.append(self.split_rhs_space_only(0))
strategies.append(self.split_rhs_space_only(1))
return strategies
def mix_strategies(self) -> List[ShardingStrategy]:
strategies = []
# SS = SR x RS
strategies.append(self.split_lhs_space_rhs_space(0, 1))
strategies.append(self.split_lhs_space_rhs_space(1, 0))
# SR = SS x SR
strategies.append(self.split_lhs_space_both_contract(0, 1))
strategies.append(self.split_lhs_space_both_contract(1, 0))
# RR = RR x RR
strategies.append(self.non_split())
return strategies
def collate_strategies(self) -> List[ShardingStrategy]:
strategies = []
if self.solver_perference == SolverPerference.STANDARD:
strategies.extend(self.dp_strategies())
strategies.extend(self.tp_strategies())
strategies.extend(self.mix_strategies())
elif self.solver_perference == SolverPerference.DP:
strategies.extend(self.dp_strategies())
elif self.solver_perference == SolverPerference.TP:
strategies.extend(self.tp_strategies())
return strategies
@ignore_sharding_exception
def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
# handle case SS = SR x RS
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0]
},
"other": {
-1: [mesh_dim_1]
},
"output": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
},
}
# linear bias only has one dimension, but addmm bias has same dimensions
# as the output logically.
if self.linear_projection_type == 'linear':
dim_partition_dict_mapping['bias'] = {-1: [mesh_dim_1]}
elif self.linear_projection_type == 'addmm':
dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0], -1: [mesh_dim_1]}
else:
raise ('Unsupported linear projection type')
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# set communication action
communication_action_mapping = {}
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
arg_index=0)
if self.is_param('other'):
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
else:
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping['input'] = input_comm_action
communication_action_mapping['other'] = other_comm_action
# we only add allreduce comm action for linear bias, because
# allreduce comm action for addmm bias will be considered in post processing
if self.has_bias and self.linear_projection_type == 'linear':
if self.is_param('bias'):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
key_for_kwarg='bias')
communication_action_mapping['bias'] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
# handle the case SR = SS x SR
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
# get sharding spec mapping
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
},
"other": {
0: [mesh_dim_1]
},
"bias": {},
"output": {
0: [mesh_dim_0]
},
}
# linear bias only has one dimension, but addmm bias has same dimensions
# as the output logically.
if self.linear_projection_type == 'linear':
dim_partition_dict_mapping['bias'] = {}
elif self.linear_projection_type == 'addmm':
dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0]}
else:
raise ('Unsupported linear projection type')
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action mapping
communication_action_mapping = {}
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.AFTER)
if self.is_param('other'):
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
else:
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping['other'] = other_comm_action
communication_action_mapping['output'] = output_comm_action
# we only add allreduce comm action for linear bias, because
# allreduce comm action for addmm bias will be considered in post processing
if self.has_bias and self.linear_projection_type == 'linear':
if self.is_param('bias'):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
key_for_kwarg='bias')
communication_action_mapping['bias'] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
# get sharding specs
dim_partition_dict_mapping = {
"input": {
-1: [mesh_dim_0]
},
"other": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
},
"bias": {
-1: [mesh_dim_1]
},
"output": {
-1: [mesh_dim_1]
},
}
# We don't have to do anything special for bias here, because
# the bias is already the same sharding spec as the output.
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication actions
communication_action_mapping = {}
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['output'],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.AFTER)
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['input'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping["input"] = input_comm_action
communication_action_mapping['output'] = output_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def recompute_split_both_contract(self, mesh_dim):
name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
# get sharding spec
dim_partition_dict_mapping = {
"input": {
-1: [mesh_dim]
},
"other": {
0: [mesh_dim]
},
"bias": {},
"output": {},
}
# We don't have to do anything special for bias here, because
# the bias is already the same sharding spec as the output.
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action
communication_action_mapping = {}
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['output'],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.AFTER)
communication_action_mapping['output'] = output_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_rhs_space_only(self, mesh_dim):
name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
# get sharding spec
dim_partition_dict_mapping = {
"input": {},
"other": {
-1: [mesh_dim]
},
"bias": {
-1: [mesh_dim]
},
"output": {
-1: [mesh_dim]
},
}
# We don't have to do anything special for bias here, because
# the bias is already the same sharding spec as the output.
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication actions
communication_action_mapping = {}
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['input'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['input'] = input_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
# get sharding spec
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0, mesh_dim_1]
},
"other": {},
"bias": {},
"output": {
0: [mesh_dim_0, mesh_dim_1]
},
}
# linear bias only has one dimension, but addmm bias has same dimensions
# as the output logically.
if self.linear_projection_type == 'linear':
dim_partition_dict_mapping['bias'] = {}
elif self.linear_projection_type == 'addmm':
dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0, mesh_dim_1]}
else:
raise ('Unsupported linear projection type')
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action
communication_action_mapping = {}
if self.is_param('other'):
other_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['other'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.HOOK)
else:
other_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['other'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping['other'] = other_comm_action
# we only add allreduce comm action for linear bias, because
# allreduce comm action for addmm bias will be considered in post processing
if self.has_bias and self.linear_projection_type == 'linear':
if self.is_param('bias'):
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.HOOK)
else:
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
key_for_kwarg='bias')
communication_action_mapping['bias'] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
# get sharding spec
dim_partition_dict_mapping = {
"input": {
-1: [mesh_dim_0, mesh_dim_1]
},
"other": {
0: [mesh_dim_0, mesh_dim_1]
},
"bias": {},
"output": {},
}
# We don't have to do anything special for bias here, because
# the bias is already the same sharding spec as the output.
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action
communication_action_mapping = {}
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['output'],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.AFTER)
communication_action_mapping['output'] = output_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
# get sharding spec
dim_partition_dict_mapping = {
"input": {},
"other": {
-1: [mesh_dim_0, mesh_dim_1]
},
"bias": {
-1: [mesh_dim_0, mesh_dim_1]
},
"output": {
-1: [mesh_dim_0, mesh_dim_1]
},
}
# We don't have to do anything special for bias here, because
# the bias is already the same sharding spec as the output.
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action
communication_action_mapping = {}
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['input'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['input'] = input_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def non_split(self):
name = f'RR = RR x RR'
# get sharding spec
dim_partition_dict_mapping = {
"input": {},
"other": {},
"bias": {},
"output": {},
}
# We don't have to do anything special for bias here, because
# the bias is already the same sharding spec as the output.
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action
communication_action_mapping = {}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def validate(self) -> bool:
assert "input" in self.op_data
assert "other" in self.op_data
# make sure the other has 2 dim
input_data = self.op_data['input']
other_data = self.op_data['other']
assert input_data.data.dim() > 0 and other_data.data.dim() == 2
assert other_data.logical_shape[0] == input_data.logical_shape[-1]
if self.has_bias:
bias_data = self.op_data['bias']
assert bias_data.logical_shape[-1] == other_data.logical_shape[-1]
class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
"""
Generate sharding strategies for the batched matrix multiplication.
A batched matrix multiplication can be viewed as
[b, i, k] x [b, k, j] -> [b, i, j]
The bias term is considered to have a 2D logical shape.
Note: This class will be used to generate strategies for torch.bmm
and torch.addbmm. However, the result of torch.addbmm is not correct,
some extra runtime apply actions are required to keep numerical correctness.
"""
# TODO: torch.addbmm correctness issue need to be fixed.
def __init__(self, *args, **kwargs):
self.squeeze_batch_dim = False
super().__init__(*args, **kwargs)
def _pop_batch_dim_sharding_for_output(self, dim_partition_dict):
# remove partition dict for dim 0
dim_partition_dict['output'].pop(0, None)
# decrease the remaining dim index by 1
temp_dim_partition = {}
keys = list(dim_partition_dict['output'].keys())
for key in keys:
val = dim_partition_dict['output'].pop(key)
temp_dim_partition[key - 1] = val
dim_partition_dict['output'].update(temp_dim_partition)
def validate(self) -> bool:
input_op_data = self.op_data['input']
other_op_data = self.op_data['other']
assert len(input_op_data.logical_shape) == 3 or len(other_op_data.logical_shape) == 3
if 'bias' in self.op_data:
bias_op_data = self.op_data['bias']
assert bias_op_data.data.dim() < 3 and len(bias_op_data.logical_shape) == 2
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
fwd_compute_cost = self.op_data['input'].data.shape[-1] * reduce(operator.mul,
self.op_data['output'].data.shape)
bwd_compute_cost = fwd_compute_cost * 2
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
strategy.compute_cost = compute_cost
@ignore_sharding_exception
def split_one_batch_dim(self, mesh_dim):
name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
# get sharding_spec
dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "bias": {}, "output": {0: [mesh_dim]}}
if self.squeeze_batch_dim:
self._pop_batch_dim_sharding_for_output(dim_partition_dict)
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
# get communication actions
communication_action_mapping = {}
if self.has_bias:
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['bias'] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1):
name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}'
dim_partition_dict = {
"input": {
0: [mesh_dim_0, mesh_dim_1]
},
"other": {
0: [mesh_dim_0, mesh_dim_1]
},
"bias": {},
"output": {
0: [mesh_dim_0, mesh_dim_1]
}
}
if self.squeeze_batch_dim:
self._pop_batch_dim_sharding_for_output(dim_partition_dict)
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
# get communication actions
communication_action_mapping = {}
if self.has_bias:
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['bias'] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1):
name = f'Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}'
dim_partition_dict = {
"input": {
0: [mesh_dim_0],
1: [mesh_dim_1]
},
"other": {
0: [mesh_dim_0]
},
"bias": {
0: [mesh_dim_1]
},
"output": {
0: [mesh_dim_0],
1: [mesh_dim_1]
}
}
if self.squeeze_batch_dim:
self._pop_batch_dim_sharding_for_output(dim_partition_dict)
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
# get communication actions
communication_action_mapping = {}
other_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['other'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping['other'] = other_comm_action
if self.has_bias:
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['bias'] = bias_comm_action
# for addbmm case, other is the third argument instead of second.
communication_action_mapping['other'].arg_index += 1
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1):
name = f'Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}'
dim_partition_dict = {
"input": {
0: [mesh_dim_0]
},
"other": {
0: [mesh_dim_0],
2: [mesh_dim_1]
},
"bias": {
1: [mesh_dim_1]
},
"output": {
0: [mesh_dim_0],
2: [mesh_dim_1]
}
}
if self.squeeze_batch_dim:
self._pop_batch_dim_sharding_for_output(dim_partition_dict)
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
# get communication actions
communication_action_mapping = {}
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['input'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['input'] = input_comm_action
if self.has_bias:
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE)
communication_action_mapping['bias'] = bias_comm_action
# for addbmm case, other is the second argument instead of first.
communication_action_mapping['input'].arg_index += 1
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1):
name = f'Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}'
dim_partition_dict = {
"input": {
0: [mesh_dim_0],
2: [mesh_dim_1]
},
"other": {
0: [mesh_dim_0],
1: [mesh_dim_1]
},
"bias": {},
"output": {
0: [mesh_dim_0],
}
}
if self.squeeze_batch_dim:
self._pop_batch_dim_sharding_for_output(dim_partition_dict)
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
# get communication actions
communication_action_mapping = {}
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['output'],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.AFTER)
communication_action_mapping['output'] = output_comm_action
if self.has_bias:
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['bias'] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
device_mesh_is_1d = True
if len(self.device_mesh.mesh_shape) == 2 and 1 not in self.device_mesh.mesh_shape:
device_mesh_is_1d = False
if device_mesh_is_1d:
# split only the batch dimension
# Sb = Sb x Sb
# can be None as it is only for 1D device mesh
# only for 1D device mesh
if len(self.device_mesh.mesh_shape) == 1:
mesh_dim = 0
else:
mesh_dim = self.device_mesh.mesh_shape.index(1)
strategy_list.append(self.split_one_batch_dim(mesh_dim))
else:
# for 2D device mesh
# split batch dim of two inputs and the i dim of the first tensor
# SbSi = SbSi x Sb
strategy_list.append(self.split_batch_dim_lhs_space(0, 1))
strategy_list.append(self.split_batch_dim_lhs_space(1, 0))
# split batch dim of two inputs and the j of the second tensor
# SbSj = Sb x SbSj
strategy_list.append(self.split_batch_dim_rhs_space(0, 1))
strategy_list.append(self.split_batch_dim_rhs_space(1, 0))
# split batch dim of two inputs and the k dim of two inputs
# Sb = SbSk x SbSk, need to all-reduce by k dim
strategy_list.append(self.split_batch_dim_both_contract(0, 1))
strategy_list.append(self.split_batch_dim_both_contract(1, 0))
# split two batch dim
strategy_list.append(self.split_two_batch_dim(0, 1))
return strategy_list
import copy
import operator
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from colossalai.auto_parallel.tensor_shard.utils import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
ignore_sharding_exception,
)
from .strategy_generator import StrategyGenerator
class NormalPoolStrategyGenerator(StrategyGenerator):
"""
NormalPoolStrategyGenerator is a generic class to generate strategies for pool operation like MaxPoolxd.
The reason we call this normal pool is AvgPoolxd and MaxPoolxd are taking the kernel size element from image,
and reduce them depening on the operation type.
"""
def validate(self) -> bool:
'''
In sanity check, we need make sure the input data having correct dimension size.
For Pool1d, the dim of input data should be 3([N, C, L]).
For Pool2d, the dim of input data should be 4([N, C, H, W]).
For Pool3d, the dim of input data should be 5([N, C, H, W, D]).
'''
input_op_data = self.op_data['input']
assert input_op_data.data.dim() in (
3, 4, 5), f'We suppose the dim of input fed into Pool op should in range of [3, 5].'
def update_compute_cost(self, strategy: ShardingStrategy) -> TrainCycleItem:
'''
Compute the computation cost per device with this specific strategy.
Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
'''
# TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
# 1D: (Lout) * N * C * kernel
# 2D: (H * W) * N * Cout * Cin * kernel
# 3D: (H * W * D) * N * Cout * Cin * kernel
sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
kernel_size = self.op_data["other"].data
if isinstance(kernel_size, int):
kernel_size = [kernel_size] * (len(sharded_output_shape) - 2)
kernel_size_product = reduce(operator.mul, kernel_size)
output_size_product = reduce(operator.mul, sharded_output_shape)
input_size_product = reduce(operator.mul, sharded_input_shape)
forward_compute_cost = output_size_product * kernel_size_product
backward_compute_cost = input_size_product * kernel_size_product
total_compute_cost = forward_compute_cost + backward_compute_cost
compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
# compute fwd cost incurred
# fwd_cost = input + output
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items()])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=0)
# compute bwd cost incurred
# bwd_cost = input_grad
bwd_activation_cost = sum([v for k, v in backward_size_mapping.items()])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=0)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, parameter=0)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@ignore_sharding_exception
def _generate_strategy_with_dim_partition(self, dim_partition):
dim_partition_dict_mapping = {"input": dim_partition, "output": dim_partition}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'
communication_action_mapping = {}
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return strategy
def enumerate_all_possible_batch_dimensions_dim_partition(self, mesh_dim_0, mesh_dim_1):
dim_partition_list = []
dim_partition_list.extend(enumerate_all_possible_1d_sharding(mesh_dim_0, 2))
dim_partition_list.extend(enumerate_all_possible_1d_sharding(mesh_dim_1, 2))
dim_partition_list.extend(enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, 2))
# append {} for non_split case
dim_partition_list.append({})
return dim_partition_list
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
dim_partition_list = self.enumerate_all_possible_batch_dimensions_dim_partition(0, 1)
for dim_partition in dim_partition_list:
strategy = self._generate_strategy_with_dim_partition(dim_partition)
strategy_list.append(strategy)
return strategy_list
from typing import Dict, List
from torch.fx import Node
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.device.device_mesh import DeviceMesh
from .strategy_generator import OutputStrategyGenerator
__all__ = ['OutputGenerator']
class OutputGenerator(OutputStrategyGenerator):
"""
OutputGenerator is a generic class to generate strategies for Output Node.
"""
def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
predecessor_nodes: List[Node], output_option: str):
super().__init__(operation_data_mapping, device_mesh, predecessor_nodes)
self.output_option = output_option
def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
Compute the memory cost per device with this specific strategy.
'''
fwd_mem_cost = MemoryCost(activation=0, parameter=0)
bwd_mem_cost = MemoryCost(activation=0, parameter=0)
# compute total cost
total_mem_cost = MemoryCost(activation=0, parameter=0)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
def replica_strategy(self) -> List[ShardingStrategy]:
"""
Generate replica strategy for output node.
"""
dim_partition_dict_mapping = {}
dim_partition_dict_for_output = []
for index, _ in enumerate(self.predecessor_nodes):
mapping_name = f"input_{index}"
if isinstance(self.op_data[mapping_name].data, (tuple, list)):
dim_partition_dict_for_input = [{} for _ in range(len(self.op_data[mapping_name].data))]
else:
dim_partition_dict_for_input = {}
dim_partition_dict_mapping[mapping_name] = dim_partition_dict_for_input
dim_partition_dict_for_output.append(dim_partition_dict_for_input)
if len(dim_partition_dict_for_output) == 1:
dim_partition_dict_for_output = dim_partition_dict_for_output[0]
else:
dim_partition_dict_for_output = tuple(dim_partition_dict_for_output)
dim_partition_dict_mapping['output'] = dim_partition_dict_for_output
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
name = 'Replica Output'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return strategy
def distributed_strategy(self, mesh_list: List[List[int]] = None) -> List[ShardingStrategy]:
"""
Generate distributed strategy for output node.
"""
# TODO: need to take care of the case when the first element of output only need to be sharded.
output_op_data = self.op_data['output']
if isinstance(output_op_data.data, tuple):
length = len(output_op_data.data)
dim_partition_dict_mapping = {
"output": [{
0: mesh_list
}] * length,
}
else:
dim_partition_dict_mapping = {
"output": {
0: mesh_list
},
}
for index, _ in enumerate(self.predecessor_nodes):
mapping_name = f"input_{index}"
dim_partition_dict_mapping[mapping_name] = {0: mesh_list}
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
name = 'Distributed Output'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return strategy
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
mesh_list = [0, 1]
if self.output_option == 'replicated':
strategy_list.append(self.replica_strategy())
elif self.output_option == 'distributed':
strategy_list.append(self.distributed_strategy(mesh_list))
return strategy_list
from typing import Dict, List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.device.device_mesh import DeviceMesh
from .strategy_generator import StrategyGenerator
__all__ = ['PlaceholderGenerator']
class PlaceholderGenerator(StrategyGenerator):
"""
PlaceholderGenerator is a generic class to generate strategies for placeholder node.
"""
def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
placeholder_option: str):
super().__init__(operation_data_mapping, device_mesh)
self.placeholder_option = placeholder_option
def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
Compute the memory cost per device with this specific strategy.
'''
forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")}
# compute fwd cost incurred
# fwd_cost = output
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items()])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=0)
bwd_mem_cost = MemoryCost(activation=0, parameter=0)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=0)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
def replica_placeholder(self) -> ShardingStrategy:
"""
Generate replica strategy for placeholder node.
"""
dim_partition_dict_mapping = {
"output": {},
}
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
name = 'Replica Placeholder'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return strategy
def distributed_placeholder(self, mesh_list) -> ShardingStrategy:
"""
Generate distributed strategy for placeholder node.
"""
dim_partition_dict_mapping = {
"output": {
0: mesh_list
},
}
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
name = 'Distributed Placeholder'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return strategy
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
if self.placeholder_option == 'distributed':
mesh_list = [0, 1]
distributed_strategy = self.distributed_placeholder(mesh_list)
strategy_list.append(distributed_strategy)
else:
assert self.placeholder_option == 'replicated', f'placeholder_option {self.placeholder_option} is not supported'
replicated_strategy = self.replica_placeholder()
strategy_list.append(replicated_strategy)
return strategy_list
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import (
check_keep_sharding_status,
detect_reshape_mapping,
infer_output_dim_partition_dict,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = ['ReshapeGenerator', 'ViewGenerator', 'PermuteGenerator', 'TransposeGenerator', 'SplitGenerator']
class ReshapeGenerator(FollowingStrategyGenerator):
"""
ReshapeGenerator is the base class for all the reshape operation.
"""
def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
Compute the memory cost per device with this specific strategy.
'''
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
# compute fwd cost incurred
# fwd_cost = input + output
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
# compute bwd cost incurred
# bwd_cost = input_grad
bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
def collate_strategies(self) -> List[ShardingStrategy]:
return super().collate_strategies()
class ViewGenerator(ReshapeGenerator):
"""
ViewGenerator deals with the sharding strategies of view op.
"""
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
origin_shape = self.op_data['input'].data.shape
tgt_shape = self.op_data['tgt_shape'].data
reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape)
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
keep_sharding_status = check_keep_sharding_status(dim_partition_dict_for_input, reshape_mapping_dict)
if keep_sharding_status:
dim_partition_dict_for_output = infer_output_dim_partition_dict(dim_partition_dict_for_input,
reshape_mapping_dict)
else:
dim_partition_dict_for_output = {}
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# add index into name to pass the duplicated check
# we keep same strategies with different name for node merging, and it will not increase the searching space,
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
if keep_sharding_status:
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
else:
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> FULLY REPLICATED_{index}'
# add comm action for converting input to fully replicated
total_mesh_dim_list = []
for mesh_dim_list in dim_partition_dict_for_input.values():
total_mesh_dim_list.extend(mesh_dim_list)
# if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
if len(total_mesh_dim_list) == 1:
total_mesh_dim_list = total_mesh_dim_list[0]
# the total mesh dim list only has one element, so the shard dim has only one element as well.
shard_dim = list(dim_partition_dict_for_input.keys())[0]
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=total_mesh_dim_list,
comm_type=CommType.BEFORE,
arg_index=0)
# it will gather the input through gather_dim during forward phase.
input_comm_action.comm_spec.gather_dim = shard_dim
# it will split the input activation grad through shard_dim during backward phase.
input_comm_action.comm_spec.shard_dim = shard_dim
elif len(total_mesh_dim_list) >= 2:
source_spec = sharding_spec_mapping["input"]
target_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=source_spec.entire_shape,
dim_partition_dict={})
comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
else:
input_comm_action = None
if input_comm_action is not None:
communication_action_mapping["input"] = input_comm_action
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
class PermuteGenerator(ReshapeGenerator):
"""
PermuteGenerator deals with the sharding strategies of permute op.
"""
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
permute_dims = self.op_data['permute_dims'].data
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
dim_partition_dict_for_output = {}
for dim_index, permute_dim in enumerate(permute_dims):
if permute_dim in dim_partition_dict_for_input:
dim_partition_dict_for_output[dim_index] = dim_partition_dict_for_input[permute_dim]
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# add index into name to pass the duplicated check
# we keep same strategies with different name for node merging, and it will not increase the searching space,
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
class TransposeGenerator(ReshapeGenerator):
"""
TransposeGenerator deals with the sharding strategies of permute op.
"""
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
dim_partition_dict_for_output = {}
transpose_dims = self.op_data['transpose_dims'].data
dim_0 = transpose_dims[0]
dim_1 = transpose_dims[1]
for dim, sharded_dims in dim_partition_dict_for_input.items():
if dim == dim_0:
dim_partition_dict_for_output[dim_1] = dim_partition_dict_for_input[dim_0]
elif dim == dim_1:
dim_partition_dict_for_output[dim_0] = dim_partition_dict_for_input[dim_1]
else:
dim_partition_dict_for_output[dim] = sharded_dims
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# add index into name to pass the duplicated check
# we keep same strategies with different name for node merging, and it will not increase the searching space,
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
class SplitGenerator(ReshapeGenerator):
"""
SplitGenerator deals with the sharding strategies of split op.
"""
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
recover_dims = None
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
split_size, split_dim = self.op_data['split_info'].data
if split_dim in dim_partition_dict_for_input:
recover_dims = dim_partition_dict_for_input.pop(split_dim)
dim_partition_dict_for_output = [
copy.deepcopy(dim_partition_dict_for_input) for _ in range(len(self.op_data["output"].data))
]
assert len(dim_partition_dict_for_output) >= 2
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# add index into name to pass the duplicated check
# we keep same strategies with different name for node merging, and it will not increase the searching space,
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
name = f'{sharding_spec_mapping["input"].sharding_sequence}_{index}'
# add comm action if the input need to be recovered to replica in the split dimension.
if recover_dims:
# if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
if len(recover_dims) == 1:
recover_dims = recover_dims[0]
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=recover_dims,
comm_type=CommType.BEFORE,
arg_index=0)
# it will gather the input through gather_dim during forward phase.
input_comm_action.comm_spec.gather_dim = split_dim
# it will split the input activation grad through split_dim during backward phase.
input_comm_action.comm_spec.shard_dim = split_dim
elif len(recover_dims) >= 2:
# original sharding spec
source_spec = input_sharding_spec
# target sharding spec
target_spec = sharding_spec_mapping["input"]
comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
else:
input_comm_action = None
if input_comm_action is not None:
communication_action_mapping["input"] = input_comm_action
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
class DefaultReshapeGenerator(ReshapeGenerator):
"""
DefaultReshapeGenerator which deals with the sharding strategies of Reshape Op which have to recover the tensor
to Replica status.
"""
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
# For default reshape strategy, to keep the computing correctness we keep the
# sharding spec of input is fully replicated. In addition, we will keep the output
# in replica status and let the successor node choose the way to resharding the
# output node. Therefore, the different strategies of input node with same
# output sharding spec will generate same strategy for reshape function.
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
dim_partition_dict_for_output = {}
if isinstance(self.op_data["output"].data, tuple):
dim_partition_dict_for_output = [{} for _ in range(len(self.op_data["output"].data))]
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# add index into name to pass the duplicated check
# we keep same strategies with different name for node merging, and it will not increase the searching space,
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> FULLY REPLICATED_{index}'
total_mesh_dim_list = []
for mesh_dim_list in dim_partition_dict_for_input.values():
total_mesh_dim_list.extend(mesh_dim_list)
# if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
if len(total_mesh_dim_list) == 1:
total_mesh_dim_list = total_mesh_dim_list[0]
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=total_mesh_dim_list,
comm_type=CommType.BEFORE,
arg_index=0)
input_comm_action.comm_spec.gather_dim = total_mesh_dim_list
input_comm_action.comm_spec.shard_dim = total_mesh_dim_list
elif len(total_mesh_dim_list) >= 2:
source_spec = sharding_spec_mapping["input"]
target_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=source_spec.entire_shape,
dim_partition_dict={})
comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
else:
input_comm_action = None
if input_comm_action is not None:
communication_action_mapping["input"] = input_comm_action
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
import copy
import operator
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import (
check_keep_sharding_status,
detect_reshape_mapping,
infer_output_dim_partition_dict,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
__all__ = ['SoftmaxGenerator']
class SoftmaxGenerator(FollowingStrategyGenerator):
"""
SoftmaxGenerator is used to generate strategies for torch.nn.Softmax or F.softmax.
"""
def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
'''
Compute the computation cost per device with this specific strategy.
'''
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
input_size_product = reduce(operator.mul, sharded_input_shape)
output_size_product = reduce(operator.mul, sharded_output_shape)
forward_compute_cost = output_size_product * 2
backward_compute_cost = input_size_product
total_compute_cost = forward_compute_cost + backward_compute_cost
compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
Compute the memory cost per device with this specific strategy.
'''
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
# compute fwd cost incurred
# fwd_cost = input + output
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
# compute bwd cost incurred
# bwd_cost = input_grad
bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
softmax_dim = self.op_data['softmax_dim'].data
if softmax_dim in dim_partition_dict_for_input:
recover_dims = dim_partition_dict_for_input.pop(softmax_dim)
dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input)
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# add index into name to pass the duplicated check
# we keep same strategies with different name for node merging, and it will not increase the searching space,
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
import operator
from abc import ABC, abstractmethod
from functools import reduce
from typing import Any, Dict, List, Union
import torch
from torch.fx import Node
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
OperationData,
OperationDataType,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.tensor.utils import convert_dim_partition_dict
class StrategyGenerator(ABC):
"""
StrategyGenerator is used to generate the same group of sharding strategies.
TODO: remove the original strategy_generator.py after refactoring
"""
def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh):
self.op_data = operation_data_mapping
self.device_mesh = device_mesh
# validate the whether operation data is of desired value
self.validate()
@property
def has_bias(self):
"""
A utility method to check for the existence of bias operand for convenience.
"""
return 'bias' in self.op_data
def is_param(self, op_data_name):
other_data = self.op_data[op_data_name]
return other_data.type == OperationDataType.PARAM
def is_buffer(self, op_data_name):
other_data = self.op_data[op_data_name]
return other_data.type == OperationDataType.BUFFER
def get_sharding_strategy(self, name: str, sharding_spec_mapping: Dict[str, ShardingSpec],
communication_action_mapping: Dict[str, CommSpec]):
"""
A factory method to produce a ShardingStrategy object.
Args:
sharding_spec_mapping (Dict[str, ShardingSpec]): the mapping between the operation data name and the ShardingSpec object.
communication_action_mapping (Dict[str, CommSpec]): the mapping between the operation data name and the CommSpec object.
"""
sharding_specs = self.replace_op_name_with_op_data(sharding_spec_mapping)
communication_actions = self.replace_op_name_with_op_data(communication_action_mapping)
return ShardingStrategy(name=name, sharding_specs=sharding_specs, communication_actions=communication_actions)
def to_sharding_spec_mapping(self, mapping: Dict[str, Dict[int, List[int]]]):
"""
A utility method to convert the the dim partition dict to a ShardingSpec object.
Args:
mapping (Dict[str, Dict[int, List[int]]]): the key of the mapping is the operation data name and the value is a dim partition dictionary.
Notes:
The op_data.data is commonly type of torch.Tensor, torch.nn.Parameter, so the sharding spec is easy to create from the shape of the data.
However, if the op_data.data is of other non-iterative types, such as float or int, we should return None. If the op_data.data is of some iterative types, such as
list or tuple, we should return a list of ShardingSpec objects follow the same rule as above mentioned.
"""
results = {}
for op_data_name, dim_partition_dict in mapping.items():
if op_data_name in self.op_data:
op_data = self.op_data[op_data_name]
def _to_sharding_spec(
data: any, logical_shape: any,
dim_partition_dict: Dict[int, List[int]]) -> Union[ShardingSpec, List[ShardingSpec], None]:
"""
This is a recursive function to convert the dim partition dict to a ShardingSpec object.
"""
if isinstance(data, torch.Tensor):
dim_size = len(logical_shape)
dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict)
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=logical_shape,
dim_partition_dict=dim_partition_dict)
return sharding_spec
elif isinstance(data, (list, tuple)):
sharding_spec = []
for data_element, logical_shape_element, dim_partition_dict_element in zip(
data, logical_shape, dim_partition_dict):
sharding_spec.append(
_to_sharding_spec(data_element, logical_shape_element, dim_partition_dict_element))
return sharding_spec
else:
return None
sharding_spec = _to_sharding_spec(op_data.data, op_data.logical_shape, dim_partition_dict)
results[op_data_name] = sharding_spec
return results
def replace_op_name_with_op_data(self, mapping: Dict[str, Any]):
"""
Convert the key of the dictionary from the operation data name to an OperationData object.
"""
results = {}
for k, v in mapping.items():
op_data = self.op_data[k]
results[op_data] = v
return results
def get_communication_spec(self, sharding_spec: ShardingSpec, communication_pattern: CollectiveCommPattern,
logical_process_axis: Union[int, List[int]]):
"""
A factory method to produce a CommSpec object.
"""
return CommSpec(comm_pattern=communication_pattern,
sharding_spec=sharding_spec,
logical_process_axis=logical_process_axis)
def get_communication_action(self,
sharding_spec: ShardingSpec,
communication_pattern: CollectiveCommPattern,
logical_process_axis: Union[int, List[int]],
comm_type: CommType,
arg_index: int = -1,
key_for_kwarg: any = None) -> CommAction:
"""
A factory method to produce a CommAction object.
"""
return CommAction(comm_spec=self.get_communication_spec(sharding_spec=sharding_spec,
communication_pattern=communication_pattern,
logical_process_axis=logical_process_axis),
comm_type=comm_type,
arg_index=arg_index,
key_for_kwarg=key_for_kwarg)
def update_communication_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
"""
Compute the communication cost involved in the forward and backward iteration.
"""
comm_cost = TrainCycleItem(fwd=0, bwd=0, total=0)
def _compute_and_add(op_data: OperationData, comm_spec: CommSpec):
num_ele_in_comm = comm_spec.get_comm_cost()
dtype = op_data.data.dtype
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
for phase, cost in num_ele_in_comm.items():
num_ele_in_comm[phase] = num_ele_in_comm[phase] * size_per_elem_bytes
comm_cost.fwd += num_ele_in_comm['forward']
comm_cost.bwd += num_ele_in_comm['backward']
comm_cost.total += num_ele_in_comm['total']
# check if communication action exists
# if so, loop over each action and compute the cost of each action
if strategy.communication_actions is not None:
for operand, comm_action in strategy.communication_actions.items():
if isinstance(comm_action, CommAction):
comm_spec = comm_action.comm_spec
else:
# this condition branch will be removed after all the handler updated.
comm_spec = comm_action
if isinstance(comm_spec, dict):
src_spec = comm_spec['src_spec']
tgt_spec = comm_spec['tgt_spec']
shape_consistency_manager = ShapeConsistencyManager()
_, comm_action_sequence, _ = shape_consistency_manager.shape_consistency(src_spec, tgt_spec)
for comm_spec_ in comm_action_sequence:
_compute_and_add(operand, comm_spec_)
else:
_compute_and_add(operand, comm_spec)
# update the communication cost attribute in-place
strategy.communication_cost = comm_cost
return strategy
@abstractmethod
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
"""
Customize this method to compute the computation flops.
"""
pass
@abstractmethod
def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
"""
Customize this method to compute the memory cost in bytes.
"""
pass
def _compute_size_in_bytes(self, strategy: ShardingStrategy, key: str):
"""
Compute the size of a tensor in bytes.
Args:
strategy (ShardingStrategy): the ShardingStrategy generated.
key (str): the name of the operation data defined by the generator.
"""
op_data = self.op_data[key]
def _compute_size_in_bytes_helper(sharding_spec, meta_data):
sharded_shape = sharding_spec.get_sharded_shape_per_device()
if len(sharded_shape) == 0:
num_elements = 1
else:
num_elements = reduce(operator.mul, sharded_shape)
dtype = getattr(meta_data, 'dtype')
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
return num_elements * size_per_elem_bytes
if isinstance(op_data.data, tuple):
assert isinstance(strategy.sharding_specs[op_data], list), \
'sharding_spec of op_data should be a list of sharding specs if op_data.data is a tuple.'
total_bytes = 0
for index, sharding_spec in enumerate(strategy.sharding_specs[op_data]):
meta_data = op_data.data[index]
if isinstance(meta_data, torch.Tensor):
element_bytes = _compute_size_in_bytes_helper(sharding_spec, meta_data)
else:
# if meta_data is not a tensor, we count the memroy as 0
element_bytes = 0
total_bytes += element_bytes
else:
if isinstance(op_data.data, torch.Tensor):
total_bytes = _compute_size_in_bytes_helper(strategy.sharding_specs[op_data], op_data.data)
else:
# if op_data.data is not a tensor, we count the memroy as 0
total_bytes = 0
return total_bytes
def generate(self) -> List[ShardingStrategy]:
"""
Generate all possible sharding strategies for this operation.
"""
strategies = self.collate_strategies()
# some strategies may be None as ignore_sharding_exception may return None
# when ShardingSpecException occurs.
# thus, remove those None values
strategies = [strategy for strategy in strategies if strategy]
# update the costs
# update mete info on cost
# these update methods are all in-place, the default method will do nothing
# the cost info will only be added if the child class overrides these methods
for strategy in strategies:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return strategies
@abstractmethod
def collate_strategies(self) -> List[ShardingStrategy]:
pass
@abstractmethod
def validate(self) -> bool:
"""
Validate if the operands are of desired shape.
If True, means this generator can be used for the current operation.
"""
pass
class FollowingStrategyGenerator(StrategyGenerator):
"""
FollowingStrategyGenerator is used to generate the sharding strategies which depends on its predecessor node.
TODO: remove the original strategy_generator.py after refactoring
"""
def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
predecessor_node: Node):
self.op_data = operation_data_mapping
self.device_mesh = device_mesh
self.predecessor_node = predecessor_node
class OutputStrategyGenerator(StrategyGenerator):
"""
OutputStrategyGenerator is used to generate the sharding strategies for Output Node.
"""
def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
predecessor_nodes: List[Node]):
super().__init__(operation_data_mapping, device_mesh)
self.predecessor_nodes = predecessor_nodes
import copy
import operator
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import (
check_keep_sharding_status,
detect_reshape_mapping,
infer_output_dim_partition_dict,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = ['SumGenerator']
class SumGenerator(FollowingStrategyGenerator):
"""
SumGenerator deals with the sharding strategies of torch.sum op.
"""
def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
input_size_product = reduce(operator.mul, sharded_input_shape)
output_size_product = reduce(operator.mul, sharded_output_shape)
compute_cost = TrainCycleItem(fwd=input_size_product,
bwd=output_size_product,
total=input_size_product + output_size_product)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
Compute the memory cost per device with this specific strategy.
'''
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
# compute fwd cost incurred
# fwd_cost = input + output
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
# compute bwd cost incurred
# bwd_cost = input_grad
bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
sum_dims, sum_mapping_dict = self.op_data['sum_info'].data
# TODO: a better way to handle the distributed sum is sum all the data on chip and then do all reduce
# among all the shard groups
recover_dims = []
dim_partition_dict_for_output = {}
for dim in dim_partition_dict_for_input:
if dim in sum_dims:
recover_dims.append(dim)
elif dim in sum_mapping_dict:
dim_partition_dict_for_output[sum_mapping_dict[dim]] = dim_partition_dict_for_input[dim]
else:
raise RuntimeError(f'dim {dim} is not in sum_mapping_dict or sum_dims')
for dim in recover_dims:
dim_partition_dict_for_input.pop(dim)
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# add index into name to pass the duplicated check
# we keep same strategies with different name for node merging, and it will not increase the searching space,
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpec
from .strategy_generator import StrategyGenerator
__all__ = ['TensorConstructorGenerator']
class TensorConstructorGenerator(StrategyGenerator):
"""
TensorConstructorGenerator which deals with
the sharding strategies for tensor constructor operation, such as torch.arange.
"""
def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
Compute the memory cost per device with this specific strategy.
'''
forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")}
# compute fwd cost incurred
# fwd_cost = input + output
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
# compute bwd cost incurred
bwd_mem_cost = MemoryCost(activation=0, parameter=0)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
dim_partition_dict_mapping = {
"output": {},
}
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
name = 'Replica Tensor Constructor'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from .strategy_generator import FollowingStrategyGenerator
__all__ = ['UnaryElementwiseGenerator']
class UnaryElementwiseGenerator(FollowingStrategyGenerator):
"""
UnaryElementwiseGenerator which deals with the sharding strategies of UnaryElementwiseOp.
"""
def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
Compute the memory cost per device with this specific strategy.
'''
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
# compute fwd cost incurred
# fwd_cost = input + output
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
# compute bwd cost incurred
# bwd_cost = input_grad
bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
# For element-wise function, we keep the sharding spec of output node same as
# the input. Therefore, the different strategies of input node with same
# output sharding spec will generate same strategy for element-wise function.
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input)
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# add index into name to pass the duplicated check
# we keep same strategies with different name for node merging, and it will not increase the searching space,
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from colossalai.auto_parallel.tensor_shard.utils import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
ignore_sharding_exception,
)
from .strategy_generator import StrategyGenerator
__all__ = ['WhereGenerator']
class WhereGenerator(StrategyGenerator):
"""
WhereGenerator is a generic class to generate strategies for Where operation.
"""
def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
Compute the memory cost per device with this specific strategy.
'''
forward_size_mapping = {
'condition': self._compute_size_in_bytes(strategy, "condition"),
'x': self._compute_size_in_bytes(strategy, "x"),
'y': self._compute_size_in_bytes(strategy, "y"),
'output': self._compute_size_in_bytes(strategy, "output")
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
# compute fwd cost incurred
# fwd_cost = condition + x + y + output
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items()])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=0)
# compute bwd cost incurred
# bwd_cost = condition_grad + x_grad + y_grad
bwd_activation_cost = sum([v for k, v in backward_size_mapping.items()])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=0)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, parameter=0)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@ignore_sharding_exception
def _generate_strategy_with_dim_partition(self, dim_partition):
dim_partition_dict_mapping = {
"condition": dim_partition,
"x": dim_partition,
"y": dim_partition,
"output": dim_partition
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["condition"].sharding_sequence} x {sharding_spec_mapping["x"].sharding_sequence} x {sharding_spec_mapping["y"].sharding_sequence}'
communication_action_mapping = {}
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return strategy
def enumerate_all_possible_output_spec(self, mesh_dim_0, mesh_dim_1, dimension_length):
dim_partition_list = []
dim_partition_list.extend(enumerate_all_possible_1d_sharding(mesh_dim_0, dimension_length))
dim_partition_list.extend(enumerate_all_possible_1d_sharding(mesh_dim_1, dimension_length))
dim_partition_list.extend(enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dimension_length))
# append {} for non_split case
dim_partition_list.append({})
return dim_partition_list
def collate_strategies(self) -> List[ShardingStrategy]:
'''
Generate every possible strategies for a where node, and record all strategies into the strategies_vector.
'''
strategy_list = []
dimension_length = len(self.op_data["output"].logical_shape)
dim_partition_list = self.enumerate_all_possible_output_spec(0, 1, dimension_length)
for dim_partition in dim_partition_list:
strategy = self._generate_strategy_with_dim_partition(dim_partition)
strategy_list.append(strategy)
return strategy_list
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator, SumGenerator
__all__ = ['SumHandler']
@operator_registry.register(torch.Tensor.sum)
@operator_registry.register(torch.sum)
class SumHandler(NodeHandler):
"""
A SumHandler which deals with the sharding strategies for torch.sum or torch.Tensor.sum.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(SumGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# check if the input operand is a parameter
if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
input_data = self.node.args[0]._meta_data
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
if len(self.node.args) > 1:
sum_dims = self.node.args[1]
else:
sum_dims = tuple(range(self.node.args[0]._meta_data.dim()))
if isinstance(sum_dims, int):
sum_dims = (sum_dims,)
# recover negative value to positive
num_dims = self.node.args[0]._meta_data.dim()
for i in range(len(sum_dims)):
if sum_dims[i] < 0:
sum_dims[i] += num_dims
# mapping the input dims to output dims
# For examples:
# input: torch.rand(2, 3, 4, 5)
# output: torch.sum(input, (0, 2))
# sum_mapping_dict = {1: 0, 3: 1}
# sum_mapping_dict[1] = 0 means the 0th dim of output is the 1st dim of input
# sum_mapping_dict[3] = 1 means the 1st dim of output is the 3rd dim of input
sum_mapping_dict = {}
if 'keepdim' in self.node.kwargs and self.node.kwargs['keepdim']:
for i in range(num_dims):
sum_mapping_dict.update({i: i})
else:
output_index = 0
for i in range(num_dims):
if i not in sum_dims:
sum_mapping_dict.update({i: output_index})
output_index += 1
assert output_index == self.node._meta_data.dim()
sum_info = (sum_dims, sum_mapping_dict)
physical_shape_operand = OperationData(name='sum_info', type=OperationDataType.ARG, data=sum_info)
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
mapping = {
"input": physical_input_operand,
"sum_info": physical_shape_operand,
"output": physical_output_operand
}
return mapping
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator
from .strategy.tensor_constructor_generator import TensorConstructorGenerator
__all__ = ['TensorConstructorHandler']
@operator_registry.register(torch.arange)
class TensorConstructorHandler(NodeHandler):
"""
A TensorConstructorHandler which deals with the sharding strategies for tensor constructor operations, such as torch.arange.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(TensorConstructorGenerator(op_data_mapping, self.device_mesh))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
mapping = {"output": physical_output_operand}
return mapping
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator, TransposeGenerator
__all__ = ['TransposeHandler']
@operator_registry.register(torch.Tensor.transpose)
@operator_registry.register(torch.transpose)
class TransposeHandler(NodeHandler):
"""
A TransposeHandler which deals with the sharding strategies for torch.permute or torch.transpose.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(TransposeGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# check if the input operand is a parameter
if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
input_data = self.node.args[0]._meta_data
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
transpose_dims = []
# torch.transpose (input, dim0, dim1)
for arg in self.node.args:
if isinstance(arg, torch.fx.Node):
if isinstance(arg._meta_data, int):
transpose_dims.append(arg._meta_data)
else:
transpose_dims.append(arg)
num_dims = self.node._meta_data.dim()
for i in range(2):
# recover negative value to positive
if transpose_dims[i] < 0:
transpose_dims[i] += num_dims
physical_shape_operand = OperationData(name='transpose_dims',
type=OperationDataType.ARG,
data=list(transpose_dims))
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
mapping = {
"input": physical_input_operand,
"transpose_dims": physical_shape_operand,
"output": physical_output_operand
}
return mapping
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import MetaInfoNodeHandler, NodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator, UnaryElementwiseGenerator
__all__ = ['UnaryElementwiseHandler']
@operator_registry.register(torch.Tensor.to)
@operator_registry.register(torch.Tensor.type)
@operator_registry.register(torch.abs)
@operator_registry.register(torch.nn.ReLU)
@operator_registry.register(torch.nn.Tanh)
@operator_registry.register(torch.tanh)
@operator_registry.register(torch.nn.modules.dropout.Dropout)
@operator_registry.register(torch.Tensor.contiguous)
@operator_registry.register(torch.nn.functional.dropout)
class UnaryElementwiseHandler(MetaInfoNodeHandler):
"""
A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(UnaryElementwiseGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data)
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {"input": physical_input_operand, "output": physical_output}
return mapping
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator, ViewGenerator
__all__ = ['ViewHandler']
@operator_registry.register(torch.Tensor.reshape)
@operator_registry.register(torch.reshape)
@operator_registry.register(torch.Tensor.view)
class ViewHandler(NodeHandler):
"""
A ViewHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(ViewGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
# check if the input operand is a parameter
if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
input_data = self.node.args[0]._meta_data
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
target_shape = self.node._meta_data.shape
physical_shape_operand = OperationData(name='tgt_shape', type=OperationDataType.ARG, data=target_shape)
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
mapping = {
"input": physical_input_operand,
"tgt_shape": physical_shape_operand,
"output": physical_output_operand
}
return mapping
import copy
import operator
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
from ..utils import recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator, WhereGenerator
__all__ = ['WhereHandler']
@operator_registry.register(torch.where)
class WhereHandler(NodeHandler):
"""
A WhereHandler which deals with the sharding strategies for torch.where.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
logical_op_data_mapping, _ = self.get_operation_data_mapping()
generators = []
generators.append(WhereGenerator(logical_op_data_mapping, self.device_mesh))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
physical_condition_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data)
physical_x_operand = OperationData(name=str(self.node.args[1]),
type=OperationDataType.ARG,
data=self.node.args[1]._meta_data)
physical_y_operand = OperationData(name=str(self.node.args[2]),
type=OperationDataType.ARG,
data=self.node.args[2]._meta_data)
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
physical_mapping = {
"condition": physical_condition_operand,
"x": physical_x_operand,
"y": physical_y_operand,
"output": physical_output
}
logical_shape_for_all = self.node._meta_data.shape
logical_mapping = {}
for key, physical_operand in physical_mapping.items():
logical_mapping[key] = self.convert_physical_operand_to_logical_operand(physical_operand,
logical_shape_for_all)
return logical_mapping, physical_mapping
def convert_physical_operand_to_logical_operand(self, physical_operand, target_shape):
logical_operand = copy.deepcopy(physical_operand)
logical_operand.logical_shape = target_shape
return logical_operand
def post_process(self, strategy: ShardingStrategy):
logical_op_data_mapping, physical_op_data_mapping = self.get_operation_data_mapping()
for key in logical_op_data_mapping.keys():
logical_sharding_spec = strategy.sharding_specs[logical_op_data_mapping[key]]
logical_shape = logical_op_data_mapping[key].logical_shape
physical_shape = physical_op_data_mapping[key].logical_shape
physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
logical_sharding_spec, logical_shape, physical_shape)
strategy.sharding_specs.pop(logical_op_data_mapping[key])
strategy.sharding_specs[physical_op_data_mapping[key]] = physical_sharding_spec
strategy.name = f"{strategy.sharding_specs[physical_op_data_mapping['output']].sharding_sequence} = {strategy.sharding_specs[physical_op_data_mapping['condition']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['x']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['y']].sharding_sequence}"
return strategy
from dataclasses import dataclass
from enum import Enum
__all__ = ['SolverOptions', 'SolverPerference', 'DataloaderOption', 'ShardOption']
class SolverPerference(Enum):
"""
This enum class is to define the solver preference.
"""
STANDARD = 0
DP = 1
TP = 2
class ShardOption(Enum):
"""
This enum class is to define the shard level required in node strategies.
Notes:
STANDARD: We do not add any extra shard requirements.
SHARD: We require the node to be shard using at least one device mesh axis.
SHARD_ONE_AXIS: We require the node to be shard using the last device mesh axis.
FULL_SHARD: We require the node to be shard using all device mesh axes.
TP_SHARD: We require the node to be shard using tensor parallel strategies on last device mesh axis.
TP_FULL_SHARD: We require the node to be shard using tensor parallel strategies on all device mesh axes.
"""
STANDARD = 0
SHARD = 1
SHARD_LAST_AXIS = 2
FULL_SHARD = 3
class DataloaderOption(Enum):
"""
This enum class is to define the dataloader option.
"""
REPLICATED = 0
DISTRIBUTED = 1
@dataclass
class SolverOptions:
"""
SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search.
"""
solver_perference: SolverPerference = SolverPerference.STANDARD
dataloader_option: DataloaderOption = DataloaderOption.REPLICATED
shard_option: ShardOption = ShardOption.STANDARD
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment