Commit 9e768b59 authored by zhuwenwen
parents 7bc5a8e3 8aed02b9
class Registry:
# TODO: refactor the registry classes used in colossalai.registry, colossalai.fx and here
def __init__(self, name):
self.name = name
self.store = {}
def register(self, source):
def wrapper(func):
if isinstance(source, (list, tuple)):
# support registering a list of items for this func
......@@ -19,7 +16,7 @@ class Registry:
return wrapper
def get(self, source):
assert source in self.store, f"{source} not found in the {self.name} registry"
target = self.store[source]
return target
......@@ -27,4 +24,4 @@ class Registry:
return source in self.store
operator_registry = Registry("operator")
......@@ -7,7 +7,7 @@ from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import SoftmaxGenerator, StrategyGenerator
__all__ = ["SoftmaxHandler"]
@operator_registry.register(torch.nn.Softmax)
......@@ -34,14 +34,14 @@ class SoftmaxHandler(NodeHandler):
input_data = self.node.args[0]._meta_data
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
softmax_dim = self.node.kwargs["dim"]
num_dims = self.node.args[0]._meta_data.dim()
# recover negative value to positive
if softmax_dim < 0:
softmax_dim += num_dims
physical_dim_operand = OperationData(name="softmax_dim", type=OperationDataType.ARG, data=softmax_dim)
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
......@@ -49,7 +49,7 @@ class SoftmaxHandler(NodeHandler):
mapping = {
"input": physical_input_operand,
"softmax_dim": physical_dim_operand,
"output": physical_output_operand
"output": physical_output_operand,
}
return mapping
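The negative-dim recovery above mirrors PyTorch's own convention; a minimal standalone check (shapes are made up):

import torch

x = torch.randn(2, 3, 4)           # num_dims = 3
softmax_dim = -1                   # as it might appear in node.kwargs["dim"]
if softmax_dim < 0:
    softmax_dim += x.dim()         # -1 -> 2, the last dimension
assert torch.allclose(torch.softmax(x, dim=-1), torch.softmax(x, dim=softmax_dim))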
......@@ -7,7 +7,7 @@ from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import SplitGenerator, StrategyGenerator
__all__ = ["SplitHandler"]
@operator_registry.register(torch.Tensor.split)
......@@ -38,7 +38,7 @@ class SplitHandler(NodeHandler):
split_dim = self.node.args[2]
else:
if self.node.kwargs:
split_dim = self.node.kwargs["dim"]
else:
split_dim = 0
......@@ -48,7 +48,7 @@ class SplitHandler(NodeHandler):
split_dim += num_dims
split_info = (split_size, split_dim)
physical_shape_operand = OperationData(name="split_info", type=OperationDataType.ARG, data=split_info)
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
......@@ -56,7 +56,7 @@ class SplitHandler(NodeHandler):
mapping = {
"input": physical_input_operand,
"split_info": physical_shape_operand,
"output": physical_output_operand
"output": physical_output_operand,
}
return mapping
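For reference, the `(split_size, split_dim)` pair extracted above matches how `torch.Tensor.split` is called; a small sketch with made-up shapes:

import torch

x = torch.randn(4, 6)
# dim defaults to 0 when neither the third positional arg nor the "dim" kwarg is given
chunks_default = x.split(2)            # split_info == (2, 0)
chunks_kwarg = x.split(3, dim=1)       # split_info == (3, 1)
assert [c.shape for c in chunks_default] == [torch.Size([2, 6])] * 2
assert [c.shape for c in chunks_kwarg] == [torch.Size([4, 3])] * 2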
......@@ -29,11 +29,31 @@ from .unary_elementwise_generator import UnaryElementwiseGenerator
from .where_generator import WhereGenerator
__all__ = [
"StrategyGenerator",
"DotProductStrategyGenerator",
"MatVecStrategyGenerator",
"LinearProjectionStrategyGenerator",
"BatchedMatMulStrategyGenerator",
"ConvStrategyGenerator",
"UnaryElementwiseGenerator",
"BatchNormStrategyGenerator",
"GetItemStrategyGenerator",
"TensorStrategyGenerator",
"TensorTupleStrategyGenerator",
"LayerNormGenerator",
"PlaceholderGenerator",
"OutputGenerator",
"WhereGenerator",
"NormalPoolStrategyGenerator",
"BinaryElementwiseStrategyGenerator",
"GetattrGenerator",
"TensorConstructorGenerator",
"EmbeddingStrategyGenerator",
"SumGenerator",
"SoftmaxGenerator",
"ViewGenerator",
"PermuteGenerator",
"TransposeGenerator",
"SplitGenerator",
"DefaultReshapeGenerator",
]
......@@ -14,7 +14,7 @@ from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
__all__ = ["BatchNormStrategyGenerator"]
class BatchNormStrategyGenerator(StrategyGenerator):
......@@ -24,34 +24,37 @@ class BatchNormStrategyGenerator(StrategyGenerator):
To keep the math consistent, there are two ways to do BatchNorm if the input
is sharded along the batch dimension:
1. We gather the input partitions along the batch dimension, then do the normal BatchNorm.
2. We do SyncBatchNorm on each input partition separately; the SyncBN op will help
us to keep the computation correct.
In this generator, both methods will be considered.
"""
def validate(self) -> bool:
"""
In the sanity check, we need to make sure the input data has the correct dimension size.
For BatchNorm1d, the dim of input data should be 3 ([N, C, L]).
For BatchNorm2d, the dim of input data should be 4 ([N, C, H, W]).
For BatchNorm3d, the dim of input data should be 5 ([N, C, H, W, D]).
"""
input_op_data = self.op_data["input"]
assert input_op_data.data.dim() in (
3,
4,
5,
), "We expect the dim of the input fed into the batch norm op to be in the range [3, 5]."
def update_compute_cost(self, strategy: ShardingStrategy):
"""
Compute the computation cost per device with this specific strategy.
Note: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
"""
# TODO: a constant coefficient needs to be added.
# 1D: (L) * N * Cin
# 2D: (H * W) * N * Cin
# 3D: (H * W * D) * N * Cin
sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device()
if self.has_bias:
# bias add is an element-wise operation, so the cost is equal to the product of the output shape.
bias_compute_cost = reduce(operator.mul, sharded_output_shape)
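A small sketch of the cost formula in the comments above, assuming a made-up sharded BatchNorm2d input of shape [N, C, H, W]:

import operator
from functools import reduce

sharded_input_shape = (16, 64, 28, 28)                    # [N, Cin, H, W], invented
# 2D case: (H * W) * N * Cin is just the product of the input shape
forward_compute_cost = reduce(operator.mul, sharded_input_shape)
assert forward_compute_cost == 28 * 28 * 16 * 64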
......@@ -69,23 +72,24 @@ class BatchNormStrategyGenerator(StrategyGenerator):
def update_memory_cost(self, strategy: ShardingStrategy):
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'other': self._compute_size_in_bytes(strategy, "other"),
'output': self._compute_size_in_bytes(strategy, "output"),
'running_mean': self._compute_size_in_bytes(strategy, "running_mean"),
'running_var': self._compute_size_in_bytes(strategy, "running_var"),
"input": self._compute_size_in_bytes(strategy, "input"),
"other": self._compute_size_in_bytes(strategy, "other"),
"output": self._compute_size_in_bytes(strategy, "output"),
"running_mean": self._compute_size_in_bytes(strategy, "running_mean"),
"running_var": self._compute_size_in_bytes(strategy, "running_var"),
}
if self.has_bias:
bias_size = self._compute_size_in_bytes(strategy, "bias")
forward_size_mapping["bias"] = bias_size
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
# compute fwd cost incurred
# fwd_cost = input + other + bias + output
fwd_activation_cost = sum(
[v for k, v in forward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)]
)
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
fwd_buffer_cost = sum([v for k, v in forward_size_mapping.items() if self.is_buffer(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost, buffer=fwd_buffer_cost)
......@@ -93,36 +97,29 @@ class BatchNormStrategyGenerator(StrategyGenerator):
# compute bwd cost incurred
# bwd_cost = input_grad + other_grad + bias_grad
bwd_activation_cost = sum(
[v for k, v in backward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)]
)
bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(
activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost,
buffer=fwd_buffer_cost,
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
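The activation/parameter/buffer split above can be sanity-checked with plain dictionaries (the byte counts below are invented):

# a minimal sketch of the fwd cost split, with made-up byte counts
forward_size_mapping = {"input": 4096, "other": 256, "output": 4096,
                        "running_mean": 256, "running_var": 256}
params = {"other"}                          # assume the scale weight is a parameter
buffers = {"running_mean", "running_var"}
fwd_activation = sum(v for k, v in forward_size_mapping.items()
                     if k not in params and k not in buffers)   # input + output
fwd_parameter = sum(v for k, v in forward_size_mapping.items() if k in params)
fwd_buffer = sum(v for k, v in forward_size_mapping.items() if k in buffers)
assert (fwd_activation, fwd_parameter, fwd_buffer) == (8192, 256, 512)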
@ignore_sharding_exception
def split_input_channel(self, mesh_dim_0):
name = f"RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}"
dim_partition_dict_mapping = {
"input": {
1: [mesh_dim_0]
},
"other": {
0: [mesh_dim_0]
},
"output": {
1: [mesh_dim_0]
},
"running_mean": {
0: [mesh_dim_0]
},
"running_var": {
0: [mesh_dim_0]
},
"input": {1: [mesh_dim_0]},
"other": {0: [mesh_dim_0]},
"output": {1: [mesh_dim_0]},
"running_mean": {0: [mesh_dim_0]},
"running_var": {0: [mesh_dim_0]},
"num_batches_tracked": {},
}
if self.has_bias:
......@@ -132,29 +129,21 @@ class BatchNormStrategyGenerator(StrategyGenerator):
communication_action_mapping = {}
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
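For readers new to the notation: a dim_partition_dict maps a tensor dimension to the device-mesh axes it is sharded over, so `"input": {1: [mesh_dim_0]}` shards the channel dimension. A hedged sketch of how that translates to per-device shapes (mesh and tensor shapes are made up):

mesh_shape = (2, 4)                       # assumed 2D device mesh
global_shape = [16, 64, 28, 28]           # [N, C, H, W], invented
dim_partition_dict = {1: [0]}             # shard dim 1 (channels) over mesh axis 0

sharded = list(global_shape)
for tensor_dim, mesh_axes in dim_partition_dict.items():
    for axis in mesh_axes:
        sharded[tensor_dim] //= mesh_shape[axis]
print(sharded)                            # [16, 32, 28, 28]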
@ignore_sharding_exception
def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
name = f"RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}"
dim_partition_dict_mapping = {
"input": {
1: [mesh_dim_0, mesh_dim_1]
},
"other": {
0: [mesh_dim_0, mesh_dim_1]
},
"output": {
1: [mesh_dim_0, mesh_dim_1]
},
"running_mean": {
0: [mesh_dim_0, mesh_dim_1]
},
"running_var": {
0: [mesh_dim_0, mesh_dim_1]
},
"input": {1: [mesh_dim_0, mesh_dim_1]},
"other": {0: [mesh_dim_0, mesh_dim_1]},
"output": {1: [mesh_dim_0, mesh_dim_1]},
"running_mean": {0: [mesh_dim_0, mesh_dim_1]},
"running_var": {0: [mesh_dim_0, mesh_dim_1]},
"num_batches_tracked": {},
}
if self.has_bias:
......@@ -164,13 +153,15 @@ class BatchNormStrategyGenerator(StrategyGenerator):
communication_action_mapping = {}
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def non_split(self):
name = "RR = RR x R"
dim_partition_dict_mapping = {
"input": {},
"other": {},
......@@ -186,21 +177,19 @@ class BatchNormStrategyGenerator(StrategyGenerator):
communication_action_mapping = {}
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_input_batch(self, mesh_dim_0):
name = f"S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN"
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0]
},
"input": {0: [mesh_dim_0]},
"other": {},
"output": {
0: [mesh_dim_0]
},
"output": {0: [mesh_dim_0]},
"running_mean": {},
"running_var": {},
"num_batches_tracked": {},
......@@ -212,33 +201,32 @@ class BatchNormStrategyGenerator(StrategyGenerator):
# set communication action
# For the SyncBN case, we don't need to do communication for weight and bias.
# TODO: the communication happens internally in the SyncBN operation. We need to replace the BN operation
# with the SyncBN operation instead of inserting a communication node.
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.IMPLICIT,
)
# TODO: The temporary solution has no communication cost;
# the action above should be added back once the SyncBN replace pass is completed.
communication_action_mapping = {}
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
name = f"S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN"
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0, mesh_dim_1]
},
"input": {0: [mesh_dim_0, mesh_dim_1]},
"other": {},
"output": {
0: [mesh_dim_0, mesh_dim_1]
},
"output": {0: [mesh_dim_0, mesh_dim_1]},
"running_mean": {},
"running_var": {},
"num_batches_tracked": {},
......@@ -250,25 +238,28 @@ class BatchNormStrategyGenerator(StrategyGenerator):
# set communication action
# For the SyncBN case, we don't need to do communication for the gradients of weight and bias.
# TODO: the communication happens internally in the SyncBN operation. We need to replace the BN operation
# with the SyncBN operation instead of inserting a communication node.
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.IMPLICIT,
)
# TODO: The temporary solution has no communication cost;
# the action above should be added back once the SyncBN replace pass is completed.
communication_action_mapping = {}
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
name = f"S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN"
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0],
......@@ -298,26 +289,29 @@ class BatchNormStrategyGenerator(StrategyGenerator):
# set communication action
# For the SyncBN case, we don't need to do communication for the gradients of weight and bias.
# TODO: the communication happens internally in the SyncBN operation. We need to replace the BN operation
# with the SyncBN operation instead of inserting a communication node.
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0],
comm_type=CommType.IMPLICIT,
)
# TODO: The temporary solution has no communication cost;
# the action above should be added back once the SyncBN replace pass is completed.
communication_action_mapping = {}
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
def collate_strategies(self) -> List[ShardingStrategy]:
"""
Generate every possible strategy for a BatchNorm node, and record all strategies in the strategies_vector.
"""
strategy_list = []
# RS = RS x S
......
......@@ -14,7 +14,7 @@ from colossalai.tensor.sharding_spec import ShardingSpecException
from .strategy_generator import StrategyGenerator
__all__ = ["BinaryElementwiseStrategyGenerator"]
class BinaryElementwiseStrategyGenerator(StrategyGenerator):
......@@ -26,36 +26,37 @@ class BinaryElementwiseStrategyGenerator(StrategyGenerator):
"""
def validate(self) -> bool:
assert (
len(self.op_data) == 3
), f"BinaryElementwiseStrategyGenerator only accepts three operation data (input, other and output), but got {len(self.op_data)}"
for name, op_data in self.op_data.items():
if not isinstance(op_data.data, (torch.Tensor, int, float)):
raise TypeError(f"The operation data {name} is not a torch.Tensor/int/float.")
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
# since elementwise ops are not compute-intensive,
# we approximate the backward compute cost
# to be twice the fwd compute cost
fwd_compute_cost = reduce(operator.mul, shape)
bwd_compute_cost = fwd_compute_cost * 2
compute_cost = TrainCycleItem(
fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
)
strategy.compute_cost = compute_cost
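A tiny sketch of this approximation; the factor of two reflects that a binary op produces gradients for both operands (the shape below is made up):

import operator
from functools import reduce

shape = (16, 1024)                              # invented sharded shape
fwd_compute_cost = reduce(operator.mul, shape)  # roughly one op per output element
bwd_compute_cost = fwd_compute_cost * 2         # grads w.r.t. both input and other
total = fwd_compute_cost + bwd_compute_cost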
def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
# the input, other and output all have the same shape
# compute fwd memory cost in bytes
# as the elementwise ops are not memory-intensive
# we approximate the fwd memory cost to be the output
# and the backward memory cost to be grad of input and other
input_bytes = self._compute_size_in_bytes(strategy, "input")
other_bytes = self._compute_size_in_bytes(strategy, "other")
output_bytes = self._compute_size_in_bytes(strategy, "output")
fwd_memory_cost = MemoryCost(activation=output_bytes)
bwd_memory_cost = MemoryCost(activation=input_bytes + other_bytes)
total_memory_cost = MemoryCost(activation=input_bytes + other_bytes + output_bytes)
......@@ -66,7 +67,7 @@ class BinaryElementwiseStrategyGenerator(StrategyGenerator):
def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
# we check for the output logical shape to get the number of dimensions
dim_partition_list = []
dim_size = len(self.op_data["output"].logical_shape)
# enumerate all the 2D sharding cases
sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
......@@ -86,21 +87,22 @@ class BinaryElementwiseStrategyGenerator(StrategyGenerator):
# convert these dim partition dict to sharding strategy
for dim_partition_dict in dim_partition_list:
dim_partition_dict_mapping = dict(
input=dim_partition_dict, other=dim_partition_dict, output=dim_partition_dict
)
try:
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
communication_action_mapping = {}
# get name
sharding_seq = sharding_spec_mapping["input"].sharding_sequence
name = f"{sharding_seq} = {sharding_seq} <binary-elementwise-op> {sharding_seq}"
sharding_strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
strategy_list.append(sharding_strategy)
except ShardingSpecException:
continue
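The real `enumerate_all_possible_2d_sharding` lives in colossalai's tensor-shard utils; the sketch below is a hypothetical re-implementation for intuition only and may not match its exact output format:

# hypothetical enumerator: place two mesh axes on the dims of a dim_size-D tensor
def enumerate_2d_sharding_sketch(mesh_dim_0, mesh_dim_1, dim_size):
    cases = []
    for d0 in range(dim_size):
        for d1 in range(dim_size):
            if d0 == d1:
                cases.append({d0: [mesh_dim_0, mesh_dim_1]})   # both axes on one dim
            else:
                cases.append({d0: [mesh_dim_0], d1: [mesh_dim_1]})
    return cases

print(len(enumerate_2d_sharding_sketch(0, 1, 2)))   # 4 candidate dicts for a 2-D tensor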
......
import copy
import operator
import warnings
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
MemoryCost,
ShardingStrategy,
......@@ -24,29 +22,32 @@ class ConvStrategyGenerator(StrategyGenerator):
"""
def validate(self) -> bool:
"""
In the sanity check, we need to make sure the input data has the correct dimension size.
For Conv1d, the dim of input data should be 3 ([N, C, L]).
For Conv2d, the dim of input data should be 4 ([N, C, H, W]).
For Conv3d, the dim of input data should be 5 ([N, C, H, W, D]).
"""
input_op_data = self.op_data["input"]
assert input_op_data.data.dim() in (
3,
4,
5,
), "We expect the dim of the input fed into the conv op to be in the range [3, 5]."
def update_compute_cost(self, strategy: ShardingStrategy):
"""
Compute the computation cost per device with this specific strategy.
Note: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
"""
# TODO: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
# 1D: (L) * N * Cout * Cin * kernel
# 2D: (H * W) * N * Cout * Cin * kernel
# 3D: (H * W * D) * N * Cout * Cin * kernel
sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
sharded_other_shape = strategy.sharding_specs[self.op_data["other"]].get_sharded_shape_per_device()
sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device()
if self.has_bias:
# bias add is an element-wise operation, so the cost is equal to the product of the output shape.
bias_compute_cost = reduce(operator.mul, sharded_output_shape)
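A sketch of the 2D conv forward-cost formula in the comments above, with invented shapes:

# 2D case: (H_out * W_out) * N * C_out * C_in * kernel_size, all shapes made up
N, C_in, C_out = 8, 64, 128
H_out, W_out = 28, 28
kh, kw = 3, 3
forward_compute_cost = (H_out * W_out) * N * C_out * C_in * (kh * kw)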
......@@ -76,14 +77,14 @@ class ConvStrategyGenerator(StrategyGenerator):
def update_memory_cost(self, strategy: ShardingStrategy):
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'other': self._compute_size_in_bytes(strategy, "other"),
'output': self._compute_size_in_bytes(strategy, "output")
"input": self._compute_size_in_bytes(strategy, "input"),
"other": self._compute_size_in_bytes(strategy, "other"),
"output": self._compute_size_in_bytes(strategy, "output"),
}
if self.has_bias:
bias_size = self._compute_size_in_bytes(strategy, "bias")
forward_size_mapping["bias"] = bias_size
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
......@@ -100,26 +101,20 @@ class ConvStrategyGenerator(StrategyGenerator):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(
activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@ignore_sharding_exception
def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
name = f"S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}"
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0]
},
"other": {
1: [mesh_dim_1]
},
"output": {
0: [mesh_dim_0],
1: [mesh_dim_1]
},
"input": {0: [mesh_dim_0]},
"other": {1: [mesh_dim_1]},
"output": {0: [mesh_dim_0], 1: [mesh_dim_1]},
}
if self.has_bias:
dim_partition_dict_mapping["bias"] = {0: [mesh_dim_1]}
......@@ -132,7 +127,8 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
arg_index=0,
)
communication_action_mapping = {"input": input_comm_action}
if self.is_param("other"):
......@@ -140,7 +136,8 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK,
)
else:
other_comm_action = self.get_communication_action(
......@@ -148,38 +145,41 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=1,
)
communication_action_mapping["other"] = other_comm_action
if self.has_bias:
if self.is_param("bias"):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK,
)
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
key_for_kwarg="bias",
)
communication_action_mapping["bias"] = bias_comm_action
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_input_batch(self, mesh_dim_0):
name = f"S{mesh_dim_0}R = S{mesh_dim_0}R x RR"
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0]
},
"input": {0: [mesh_dim_0]},
"other": {},
"output": {
0: [mesh_dim_0],
......@@ -196,7 +196,8 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK,
)
else:
other_comm_action = self.get_communication_action(
......@@ -204,42 +205,45 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=1,
)
communication_action_mapping["other"] = other_comm_action
if self.has_bias:
if self.is_param("bias"):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK,
)
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
key_for_kwarg="bias",
)
communication_action_mapping["bias"] = bias_comm_action
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
name = f"S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R"
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0],
1: [mesh_dim_1],
},
"other": {
0: [mesh_dim_1]
},
"other": {0: [mesh_dim_1]},
"output": {
0: [mesh_dim_0],
},
......@@ -254,7 +258,8 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.AFTER,
)
communication_action_mapping = {"output": output_comm_action}
......@@ -263,7 +268,8 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK,
)
else:
other_comm_action = self.get_communication_action(
......@@ -271,7 +277,8 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=1,
)
communication_action_mapping["other"] = other_comm_action
if self.has_bias:
if self.is_param("bias"):
......@@ -279,23 +286,27 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK,
)
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
key_for_kwarg="bias",
)
communication_action_mapping["bias"] = bias_comm_action
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
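A toy check of why this strategy all-reduces the output in the forward pass: when the contracted (in-channel) dimension is sharded, each device holds only a partial sum. A plain matmul stands in for the conv here:

import torch

a = torch.randn(4, 6)
b = torch.randn(6, 5)
a0, a1 = a[:, :3], a[:, 3:]                   # shard the inner dim across two "devices"
b0, b1 = b[:3, :], b[3:, :]
partial0, partial1 = a0 @ b0, a1 @ b1         # per-device partial results
assert torch.allclose(a @ b, partial0 + partial1, atol=1e-5)  # all-reduce sums the partials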
@ignore_sharding_exception
def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
name = f"RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}"
dim_partition_dict_mapping = {
"input": {
......@@ -322,23 +333,27 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.AFTER,
)
input_comm_action = self.get_communication_action(
sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
arg_index=0,
)
communication_action_mapping = {"output": output_comm_action, "input": input_comm_action}
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
name = f"RR = RS{mesh_dim_0} x S{mesh_dim_0}R"
dim_partition_dict_mapping = {
"input": {
......@@ -360,17 +375,20 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.AFTER,
)
communication_action_mapping = {"output": output_comm_action}
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_weight_out_channel(self, mesh_dim_0):
name = f"RS{mesh_dim_0} = RR x RS{mesh_dim_0}"
dim_partition_dict_mapping = {
"input": {},
......@@ -395,17 +413,20 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=0,
)
communication_action_mapping = {"input": input_comm_action}
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def non_split(self):
name = "RR = RR x RR"
dim_partition_dict_mapping = {
"input": {},
......@@ -418,13 +439,13 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
return self.get_sharding_strategy(
name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping={}
)
@ignore_sharding_exception
def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
name = f"S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR"
dim_partition_dict_mapping = {
"input": {
......@@ -447,14 +468,16 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.HOOK,
)
else:
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
arg_index=1,
)
communication_action_mapping["other"] = other_comm_action
......@@ -464,23 +487,27 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.HOOK,
)
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
key_for_kwarg="bias",
)
communication_action_mapping["bias"] = bias_comm_action
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
name = f"RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R"
dim_partition_dict_mapping = {
"input": {
1: [mesh_dim_0, mesh_dim_1],
......@@ -501,17 +528,20 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.AFTER,
)
communication_action_mapping = {"output": output_comm_action}
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_1d_parallel_on_out_channel(self, mesh_dim_0, mesh_dim_1):
name = f"RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}"
dim_partition_dict_mapping = {
"input": {},
"other": {
......@@ -535,13 +565,16 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
arg_index=0,
)
communication_action_mapping = {"input": input_comm_action}
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
def collate_strategies(self) -> List[ShardingStrategy]:
strategies = []
......
import copy
import operator
import warnings
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
MemoryCost,
ShardingStrategy,
......@@ -27,16 +25,16 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
"""
Compute the computation cost per device with this specific strategy.
Note: The computation cost for the embedding handler is estimated as dense computing now.
It may not be accurate.
"""
# TODO: estimate the embedding computation cost as sparse operation
sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
sharded_other_shape = strategy.sharding_specs[self.op_data["other"]].get_sharded_shape_per_device()
sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device()
input_size_product = reduce(operator.mul, sharded_input_shape)
other_size_product = reduce(operator.mul, sharded_other_shape)
......@@ -55,9 +53,9 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
def update_memory_cost(self, strategy: ShardingStrategy):
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'other': self._compute_size_in_bytes(strategy, "other"),
'output': self._compute_size_in_bytes(strategy, "output")
"input": self._compute_size_in_bytes(strategy, "input"),
"other": self._compute_size_in_bytes(strategy, "other"),
"output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
......@@ -75,14 +73,15 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(
activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@ignore_sharding_exception
def non_split(self):
name = "RR = R x RR"
dim_partition_dict_mapping = {
"input": {},
......@@ -92,18 +91,16 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
return self.get_sharding_strategy(
name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping={}
)
@ignore_sharding_exception
def split_input(self, mesh_dim_0):
name = f"S{mesh_dim_0}R = S{mesh_dim_0} x RR"
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0]
},
"input": {0: [mesh_dim_0]},
"other": {},
"output": {
0: [mesh_dim_0],
......@@ -118,7 +115,8 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK,
)
else:
other_comm_action = self.get_communication_action(
......@@ -126,17 +124,20 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=1,
)
communication_action_mapping["other"] = other_comm_action
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_input_and_embedding_dim(self, mesh_dim_0, mesh_dim_1):
name = f"S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0} x RS{mesh_dim_1}"
dim_partition_dict_mapping = {
"input": {
......@@ -159,7 +160,8 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
arg_index=0,
)
communication_action_mapping = {"input": input_comm_action}
if self.is_param("other"):
......@@ -167,7 +169,8 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK,
)
else:
other_comm_action = self.get_communication_action(
......@@ -175,22 +178,23 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=1,
)
communication_action_mapping["other"] = other_comm_action
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_1d_parallel_on_input(self, mesh_dim_0, mesh_dim_1):
name = f"S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1} x RR"
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0, mesh_dim_1]
},
"input": {0: [mesh_dim_0, mesh_dim_1]},
"other": {},
"output": {
0: [mesh_dim_0, mesh_dim_1],
......@@ -207,7 +211,8 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.HOOK,
)
else:
other_comm_action = self.get_communication_action(
......@@ -215,17 +220,20 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
arg_index=1,
)
communication_action_mapping["other"] = other_comm_action
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_embedding_dim(self, mesh_dim_0):
name = f"RS{mesh_dim_0} = R x RS{mesh_dim_0}"
dim_partition_dict_mapping = {
"input": {},
......@@ -245,17 +253,20 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=0,
)
communication_action_mapping = {"input": input_comm_action}
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
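A toy check for this RS = R x RS strategy: splitting the embedding (column) dimension of the weight simply splits the output columns, so no cross-device reduction is needed:

import torch

weight = torch.randn(10, 8)                     # vocab=10, embedding_dim=8, made up
ids = torch.tensor([1, 4, 7])
w0, w1 = weight[:, :4], weight[:, 4:]           # shard columns across two "devices"
out = torch.nn.functional.embedding(ids, weight)
out0 = torch.nn.functional.embedding(ids, w0)
out1 = torch.nn.functional.embedding(ids, w1)
assert torch.equal(out, torch.cat([out0, out1], dim=-1))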
@ignore_sharding_exception
def split_1d_parallel_on_embedding_dim(self, mesh_dim_0, mesh_dim_1):
name = f"RS{mesh_dim_0}{mesh_dim_1} = R x RS{mesh_dim_0}{mesh_dim_1}"
dim_partition_dict_mapping = {
"input": {},
......@@ -275,13 +286,16 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
arg_index=0,
)
communication_action_mapping = {"input": input_comm_action}
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
def collate_strategies(self) -> List[ShardingStrategy]:
strategies = []
......
......@@ -10,7 +10,7 @@ from colossalai.tensor.sharding_spec import ShardingSpecException
from .strategy_generator import StrategyGenerator
__all__ = ["GetattrGenerator"]
class GetattrGenerator(StrategyGenerator):
......@@ -26,10 +26,10 @@ class GetattrGenerator(StrategyGenerator):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
"""
Compute the memory cost per device with this specific strategy.
"""
forward_size_mapping = {"output": self._compute_size_in_bytes(strategy, "output")}
# compute fwd cost incurred
# fwd_cost = output
......@@ -47,7 +47,7 @@ class GetattrGenerator(StrategyGenerator):
def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
# we check for the output logical shape to get the number of dimensions
dim_partition_list = []
dim_size = len(self.op_data["output"].logical_shape)
# enumerate all the 2D sharding cases
sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
......@@ -78,7 +78,8 @@ class GetattrGenerator(StrategyGenerator):
sharding_strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
strategy_list.append(sharding_strategy)
except ShardingSpecException:
continue
......
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from colossalai.logging import get_dist_logger
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpecException
from .strategy_generator import FollowingStrategyGenerator
__all__ = ["GetItemStrategyGenerator", "TensorStrategyGenerator", "TensorTupleStrategyGenerator"]
class GetItemStrategyGenerator(FollowingStrategyGenerator):
......@@ -35,12 +29,12 @@ class GetItemStrategyGenerator(FollowingStrategyGenerator):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
"""
Compute the memory cost per device with this specific strategy.
"""
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
"input": self._compute_size_in_bytes(strategy, "input"),
"output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
......@@ -58,27 +52,29 @@ class GetItemStrategyGenerator(FollowingStrategyGenerator):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(
activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
class TensorStrategyGenerator(GetItemStrategyGenerator):
"""
Deal with cases 1 and 2.
"""
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
getitem_index = self.op_data["index"].data
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
try:
logger = get_dist_logger()
dim_partition_dict_mapping = {}
communication_action_mapping = {}
dim_partition_dict_for_input = copy.deepcopy(
strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict)
strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict
)
int_index = False
if isinstance(getitem_index, int):
......@@ -120,9 +116,11 @@ class TensorStrategyGenerator(GetItemStrategyGenerator):
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
except ShardingSpecException as e:
logger.debug(e)
continue
......@@ -137,9 +135,9 @@ class TensorStrategyGenerator(GetItemStrategyGenerator):
class TensorTupleStrategyGenerator(GetItemStrategyGenerator):
"""
Deal with case 3.
"""
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
......@@ -158,13 +156,15 @@ class TensorTupleStrategyGenerator(GetItemStrategyGenerator):
sharding_spec_mapping["input"] = sharding_spec_for_input
input_sharding_info = f"get the {index} element from ("
for sharding_spec in sharding_spec_for_input:
input_sharding_info += f"{sharding_spec.sharding_sequence}, "
input_sharding_info += ")"
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {input_sharding_info}_{strategy_index}'
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
strategy_list.append(strategy)
......
......@@ -18,7 +18,7 @@ from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
__all__ = ["LayerNormGenerator"]
class LayerNormGenerator(StrategyGenerator):
......@@ -31,21 +31,21 @@ class LayerNormGenerator(StrategyGenerator):
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
"""
Compute the computation cost per device with this specific strategy.
Note: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
"""
# TODO: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
# TODO: a constant coefficient needs to be added.
sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
sharded_weight_shape = strategy.sharding_specs[self.op_data["other"]].get_sharded_shape_per_device()
if self.has_bias:
# bias add is an element-wise operation, so the cost is equal to the product of the output shape.
bias_compute_cost = reduce(operator.mul, sharded_weight_shape)
# in the LayerNorm context, batch dimensions are all the dimensions that do not take part in the normalization.
input_batch_shape = sharded_input_shape[: -len(sharded_weight_shape)]
input_batch_product = reduce(operator.mul, input_batch_shape, 1)
norm_kernel_product = reduce(operator.mul, sharded_weight_shape, 1)
forward_compute_cost = input_batch_product * norm_kernel_product
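A sketch of the batch/kernel split above for LayerNorm, with invented shapes:

import operator
from functools import reduce

sharded_input_shape = (8, 128, 512)     # e.g. [batch, seq, hidden], made up
sharded_weight_shape = (512,)           # normalized over the last dim
input_batch_shape = sharded_input_shape[: -len(sharded_weight_shape)]   # (8, 128)
input_batch_product = reduce(operator.mul, input_batch_shape, 1)        # 1024
norm_kernel_product = reduce(operator.mul, sharded_weight_shape, 1)     # 512
forward_compute_cost = input_batch_product * norm_kernel_product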
......@@ -62,18 +62,18 @@ class LayerNormGenerator(StrategyGenerator):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
"""
Compute the memory cost per device with this specific strategy.
"""
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'other': self._compute_size_in_bytes(strategy, "other"),
'output': self._compute_size_in_bytes(strategy, "output")
"input": self._compute_size_in_bytes(strategy, "input"),
"other": self._compute_size_in_bytes(strategy, "other"),
"output": self._compute_size_in_bytes(strategy, "output"),
}
if self.has_bias:
bias_size = self._compute_size_in_bytes(strategy, "bias")
forward_size_mapping["bias"] = bias_size
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
......@@ -90,8 +90,9 @@ class LayerNormGenerator(StrategyGenerator):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
total_mem_cost = MemoryCost(
activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
......@@ -120,7 +121,8 @@ class LayerNormGenerator(StrategyGenerator):
sharding_spec=sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=total_mesh_dim_list,
comm_type=CommType.HOOK)
comm_type=CommType.HOOK,
)
communication_action_mapping["other"] = other_comm_action
if self.has_bias:
......@@ -128,12 +130,15 @@ class LayerNormGenerator(StrategyGenerator):
sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=total_mesh_dim_list,
comm_type=CommType.HOOK)
comm_type=CommType.HOOK,
)
communication_action_mapping["bias"] = bias_comm_action
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
return strategy
......@@ -155,7 +160,7 @@ class LayerNormGenerator(StrategyGenerator):
@ignore_sharding_exception
def non_split(self):
name = f'RR = RR x R'
name = f"RR = RR x R"
dim_partition_dict_mapping = {
"input": {},
"other": {},
......@@ -168,14 +173,16 @@ class LayerNormGenerator(StrategyGenerator):
communication_action_mapping = {}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
def collate_strategies(self) -> List[ShardingStrategy]:
'''
"""
Generate all possible strategies for a LayerNorm node, and record them in the strategies_vector.
'''
"""
strategy_list = []
input_data_dim = len(self.op_data["input"].logical_shape)
weight_data_dim = len(self.op_data["other"].logical_shape)
......
import operator
from ast import arg
from functools import reduce
from typing import List
......@@ -24,14 +23,14 @@ class MatMulStrategyGenerator(StrategyGenerator):
def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'other': self._compute_size_in_bytes(strategy, "other"),
'output': self._compute_size_in_bytes(strategy, "output")
"input": self._compute_size_in_bytes(strategy, "input"),
"other": self._compute_size_in_bytes(strategy, "other"),
"output": self._compute_size_in_bytes(strategy, "output"),
}
if self.has_bias:
bias_size = self._compute_size_in_bytes(strategy, "bias")
size_mapping['bias'] = bias_size
size_mapping["bias"] = bias_size
# compute fwd cost incurred
# fwd_cost = input + other + bias + output
......@@ -41,45 +40,47 @@ class MatMulStrategyGenerator(StrategyGenerator):
# compute bwd cost incurred
# bwd_cost = input_grad + bias_grad
bwd_activation_cost = sum([v for k, v in size_mapping.items() if k in ['input', 'other', 'bias']])
bwd_activation_cost = sum([v for k, v in size_mapping.items() if k in ["input", "other", "bias"]])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=0)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + 0)
total_mem_cost = MemoryCost(
activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + 0
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
class DotProductStrategyGenerator(MatMulStrategyGenerator):
def validate(self) -> bool:
input_op_data = self.op_data['input']
other_op_data = self.op_data['other']
input_op_data = self.op_data["input"]
other_op_data = self.op_data["other"]
assert input_op_data.data.dim() == 1 and other_op_data.data.dim() == 1
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
fwd_compute_cost = sharded_input_shape[0]
bwd_compute_cost = fwd_compute_cost * 2
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
compute_cost = TrainCycleItem(
fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
)
return compute_cost
@ignore_sharding_exception
def no_split(self):
name = f'R = R dot R'
dim_partition_dict = {"input": {}, "other": {}, "output": {}, 'bias': {}}
name = f"R = R dot R"
dim_partition_dict = {"input": {}, "other": {}, "output": {}, "bias": {}}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
communication_action_mapping = {}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_one_dim(self, mesh_dim):
name = f'R = S{mesh_dim} dot S{mesh_dim}'
name = f"R = S{mesh_dim} dot S{mesh_dim}"
# get sharding spec
dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "output": {}, "bias": {0: [mesh_dim]}}
......@@ -87,14 +88,17 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
# get communication action
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['output'],
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.AFTER)
comm_type=CommType.AFTER,
)
communication_action_mapping = {"output": output_comm_action}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
......@@ -112,19 +116,18 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
class MatVecStrategyGenerator(MatMulStrategyGenerator):
def validate(self) -> bool:
input_op_data = self.op_data['input']
other_op_data = self.op_data['other']
input_op_data = self.op_data["input"]
other_op_data = self.op_data["other"]
assert input_op_data.data.dim() == 2 and other_op_data.data.dim() == 1
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
fwd_compute_cost = sharded_input_shape[0]
bwd_compute_cost = fwd_compute_cost * 2
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
compute_cost = TrainCycleItem(
fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
)
return compute_cost
@ignore_sharding_exception
......@@ -133,67 +136,69 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
dim_partition_dict = {"input": {}, "other": {}, "output": {}}
if self.has_bias:
dim_partition_dict['bias'] = {}
dim_partition_dict["bias"] = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
return self.get_sharding_strategy(
name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping={}
)
@ignore_sharding_exception
def split_input_batch(self, mesh_dim):
name = f'S{mesh_dim}R = S{mesh_dim}R x R'
name = f"S{mesh_dim}R = S{mesh_dim}R x R"
# get sharding spec
dim_partition_dict = {
"input": {
0: [mesh_dim]
},
"input": {0: [mesh_dim]},
"other": {},
"output": {
0: [mesh_dim]
},
"output": {0: [mesh_dim]},
}
if self.has_bias:
dim_partition_dict['bias'] = {}
dim_partition_dict["bias"] = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
# get communication action
communication_action_mapping = {}
if self.is_param('other'):
if self.is_param("other"):
other_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['other'],
sharding_spec=sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.HOOK)
comm_type=CommType.HOOK,
)
else:
other_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['other'],
sharding_spec=sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping['other'] = other_comm_action
arg_index=1,
)
communication_action_mapping["other"] = other_comm_action
if self.has_bias:
if self.is_param('bias'):
if self.is_param("bias"):
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.HOOK)
comm_type=CommType.HOOK,
)
else:
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
arg_index=2)
communication_action_mapping['bias'] = bias_comm_action
arg_index=2,
)
communication_action_mapping["bias"] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
......@@ -209,12 +214,13 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
def __init__(self,
operation_data_mapping,
device_mesh,
linear_projection_type='linear',
solver_perference=SolverPerference.STANDARD):
def __init__(
self,
operation_data_mapping,
device_mesh,
linear_projection_type="linear",
solver_perference=SolverPerference.STANDARD,
):
super().__init__(operation_data_mapping, device_mesh)
self.linear_projection_type = linear_projection_type
self.solver_perference = solver_perference
......@@ -224,17 +230,17 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# C: [M, N], A: [M, P], B: [P, N]
# fwd cost = MNP (only count mul)
# bwd: 2 x fwd_cost
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device()
sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
sharded_other_shape = strategy.sharding_specs[self.op_data["other"]].get_sharded_shape_per_device()
dim_m_val = reduce(operator.mul, sharded_input_shape[:-1])
dim_n_val = sharded_other_shape[-1]
dim_p_val = sharded_other_shape[0]
fwd_compute_cost = dim_m_val * dim_n_val * dim_p_val
bwd_compute_cost = fwd_compute_cost * 2
compute_cost = TrainCycleItem(fwd=bwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
compute_cost = TrainCycleItem(
fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
)
strategy.compute_cost = compute_cost
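# A small sketch of the M/N/P bookkeeping above (illustrative only; the
# per-device shapes are hypothetical):
import operator
from functools import reduce

sharded_input_shape = (8, 16, 64)   # A: [..., M, P], leading dims flatten into M
sharded_other_shape = (64, 32)      # B: [P, N]
dim_m_val = reduce(operator.mul, sharded_input_shape[:-1])  # 8 * 16 = 128
dim_n_val = sharded_other_shape[-1]                         # 32
dim_p_val = sharded_other_shape[0]                          # 64
fwd = dim_m_val * dim_n_val * dim_p_val                     # 262144 multiplications
assert (fwd, fwd * 2) == (262144, 524288)                   # bwd counts 2x fwd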
def dp_strategies(self) -> List[ShardingStrategy]:
......@@ -301,28 +307,21 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
@ignore_sharding_exception
def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
# handle case SS = SR x RS
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
name = f"S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}"
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0]
},
"other": {
-1: [mesh_dim_1]
},
"output": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
},
"input": {0: [mesh_dim_0]},
"other": {-1: [mesh_dim_1]},
"output": {0: [mesh_dim_0], -1: [mesh_dim_1]},
}
# the linear bias has only one dimension, but the addmm bias logically has
# the same dimensions as the output.
if self.linear_projection_type == 'linear':
dim_partition_dict_mapping['bias'] = {-1: [mesh_dim_1]}
elif self.linear_projection_type == 'addmm':
dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0], -1: [mesh_dim_1]}
if self.linear_projection_type == "linear":
dim_partition_dict_mapping["bias"] = {-1: [mesh_dim_1]}
elif self.linear_projection_type == "addmm":
dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0], -1: [mesh_dim_1]}
else:
raise ('Unsupported linear projection type')
raise ("Unsupported linear projection type")
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
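# To make the bias rule above concrete, a sketch of the two partition dicts
# (mesh dims and values are hypothetical, for illustration only):
mesh_dim_0, mesh_dim_1 = 0, 1
linear_bias_partition = {-1: [mesh_dim_1]}                  # 1D bias: shard the last dim only
addmm_bias_partition = {0: [mesh_dim_0], -1: [mesh_dim_1]}  # bias follows the output layout
assert linear_bias_partition.keys() < addmm_bias_partition.keys()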
......@@ -333,75 +332,75 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
arg_index=0)
arg_index=0,
)
if self.is_param('other'):
if self.is_param("other"):
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
comm_type=CommType.HOOK,
)
else:
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=1)
arg_index=1,
)
communication_action_mapping['input'] = input_comm_action
communication_action_mapping['other'] = other_comm_action
communication_action_mapping["input"] = input_comm_action
communication_action_mapping["other"] = other_comm_action
# we only add allreduce comm action for linear bias, because
# allreduce comm action for addmm bias will be considered in post processing
if self.has_bias and self.linear_projection_type == 'linear':
if self.is_param('bias'):
if self.has_bias and self.linear_projection_type == "linear":
if self.is_param("bias"):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
comm_type=CommType.HOOK,
)
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
key_for_kwarg='bias')
communication_action_mapping['bias'] = bias_comm_action
key_for_kwarg="bias",
)
communication_action_mapping["bias"] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
# handle the case SR = SS x SR
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
name = f"S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R"
# get sharding spec mapping
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
},
"other": {
0: [mesh_dim_1]
},
"input": {0: [mesh_dim_0], -1: [mesh_dim_1]},
"other": {0: [mesh_dim_1]},
"bias": {},
"output": {
0: [mesh_dim_0]
},
"output": {0: [mesh_dim_0]},
}
# the linear bias has only one dimension, but the addmm bias logically has
# the same dimensions as the output.
if self.linear_projection_type == 'linear':
dim_partition_dict_mapping['bias'] = {}
elif self.linear_projection_type == 'addmm':
dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0]}
if self.linear_projection_type == "linear":
dim_partition_dict_mapping["bias"] = {}
elif self.linear_projection_type == "addmm":
dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0]}
else:
raise ('Unsupported linear projection type')
raise ("Unsupported linear projection type")
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
......@@ -412,66 +411,64 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.AFTER)
comm_type=CommType.AFTER,
)
if self.is_param('other'):
if self.is_param("other"):
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
comm_type=CommType.HOOK,
)
else:
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=1)
arg_index=1,
)
communication_action_mapping['other'] = other_comm_action
communication_action_mapping['output'] = output_comm_action
communication_action_mapping["other"] = other_comm_action
communication_action_mapping["output"] = output_comm_action
# we only add allreduce comm action for linear bias, because
# allreduce comm action for addmm bias will be considered in post processing
if self.has_bias and self.linear_projection_type == 'linear':
if self.is_param('bias'):
if self.has_bias and self.linear_projection_type == "linear":
if self.is_param("bias"):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
comm_type=CommType.HOOK,
)
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
key_for_kwarg='bias')
communication_action_mapping['bias'] = bias_comm_action
key_for_kwarg="bias",
)
communication_action_mapping["bias"] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
name = f"RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}"
# get sharding specs
dim_partition_dict_mapping = {
"input": {
-1: [mesh_dim_0]
},
"other": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
},
"bias": {
-1: [mesh_dim_1]
},
"output": {
-1: [mesh_dim_1]
},
"input": {-1: [mesh_dim_0]},
"other": {0: [mesh_dim_0], -1: [mesh_dim_1]},
"bias": {-1: [mesh_dim_1]},
"output": {-1: [mesh_dim_1]},
}
# We don't have to do anything special for bias here, because
......@@ -482,34 +479,34 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# get communication actions
communication_action_mapping = {}
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['output'],
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.AFTER)
comm_type=CommType.AFTER,
)
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['input'],
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
arg_index=0)
arg_index=0,
)
communication_action_mapping["input"] = input_comm_action
communication_action_mapping['output'] = output_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
communication_action_mapping["output"] = output_comm_action
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def recompute_split_both_contract(self, mesh_dim):
name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
name = f"RR = RS{mesh_dim} x S{mesh_dim}R"
# get sharding spec
dim_partition_dict_mapping = {
"input": {
-1: [mesh_dim]
},
"other": {
0: [mesh_dim]
},
"input": {-1: [mesh_dim]},
"other": {0: [mesh_dim]},
"bias": {},
"output": {},
}
......@@ -520,32 +517,29 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# get communication action
communication_action_mapping = {}
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['output'],
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.AFTER)
comm_type=CommType.AFTER,
)
communication_action_mapping['output'] = output_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
communication_action_mapping["output"] = output_comm_action
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_rhs_space_only(self, mesh_dim):
name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
name = f"RS{mesh_dim} = RR x RS{mesh_dim}"
# get sharding spec
dim_partition_dict_mapping = {
"input": {},
"other": {
-1: [mesh_dim]
},
"bias": {
-1: [mesh_dim]
},
"output": {
-1: [mesh_dim]
},
"other": {-1: [mesh_dim]},
"bias": {-1: [mesh_dim]},
"output": {-1: [mesh_dim]},
}
# We don't have to do anything special for bias here, because
# the bias is already the same sharding spec as the output.
......@@ -554,93 +548,94 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# get communication actions
communication_action_mapping = {}
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['input'],
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
arg_index=0)
arg_index=0,
)
communication_action_mapping['input'] = input_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
communication_action_mapping["input"] = input_comm_action
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
name = f"S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR"
# get sharding spec
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0, mesh_dim_1]
},
"input": {0: [mesh_dim_0, mesh_dim_1]},
"other": {},
"bias": {},
"output": {
0: [mesh_dim_0, mesh_dim_1]
},
"output": {0: [mesh_dim_0, mesh_dim_1]},
}
# the linear bias has only one dimension, but the addmm bias logically has
# the same dimensions as the output.
if self.linear_projection_type == 'linear':
dim_partition_dict_mapping['bias'] = {}
elif self.linear_projection_type == 'addmm':
dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0, mesh_dim_1]}
if self.linear_projection_type == "linear":
dim_partition_dict_mapping["bias"] = {}
elif self.linear_projection_type == "addmm":
dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0, mesh_dim_1]}
else:
raise ('Unsupported linear projection type')
raise ("Unsupported linear projection type")
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action
communication_action_mapping = {}
if self.is_param('other'):
if self.is_param("other"):
other_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['other'],
sharding_spec=sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.HOOK)
comm_type=CommType.HOOK,
)
else:
other_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['other'],
sharding_spec=sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping['other'] = other_comm_action
arg_index=1,
)
communication_action_mapping["other"] = other_comm_action
# we only add allreduce comm action for linear bias, because
# allreduce comm action for addmm bias will be considered in post processing
if self.has_bias and self.linear_projection_type == 'linear':
if self.is_param('bias'):
if self.has_bias and self.linear_projection_type == "linear":
if self.is_param("bias"):
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.HOOK)
comm_type=CommType.HOOK,
)
else:
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
key_for_kwarg='bias')
communication_action_mapping['bias'] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
key_for_kwarg="bias",
)
communication_action_mapping["bias"] = bias_comm_action
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
name = f"RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R"
# get sharding spec
dim_partition_dict_mapping = {
"input": {
-1: [mesh_dim_0, mesh_dim_1]
},
"other": {
0: [mesh_dim_0, mesh_dim_1]
},
"input": {-1: [mesh_dim_0, mesh_dim_1]},
"other": {0: [mesh_dim_0, mesh_dim_1]},
"bias": {},
"output": {},
}
......@@ -652,32 +647,29 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# get communication action
communication_action_mapping = {}
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['output'],
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.AFTER)
communication_action_mapping['output'] = output_comm_action
comm_type=CommType.AFTER,
)
communication_action_mapping["output"] = output_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
name = f"RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}"
# get sharding spec
dim_partition_dict_mapping = {
"input": {},
"other": {
-1: [mesh_dim_0, mesh_dim_1]
},
"bias": {
-1: [mesh_dim_0, mesh_dim_1]
},
"output": {
-1: [mesh_dim_0, mesh_dim_1]
},
"other": {-1: [mesh_dim_0, mesh_dim_1]},
"bias": {-1: [mesh_dim_0, mesh_dim_1]},
"output": {-1: [mesh_dim_0, mesh_dim_1]},
}
# We don't have to do anything special for bias here, because
......@@ -687,20 +679,23 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# get communication action
communication_action_mapping = {}
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['input'],
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['input'] = input_comm_action
arg_index=0,
)
communication_action_mapping["input"] = input_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def non_split(self):
name = f'RR = RR x RR'
name = f"RR = RR x RR"
# get sharding spec
dim_partition_dict_mapping = {
......@@ -717,22 +712,24 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# get communication action
communication_action_mapping = {}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
def validate(self) -> bool:
assert "input" in self.op_data
assert "other" in self.op_data
# make sure the other operand has 2 dims
input_data = self.op_data['input']
other_data = self.op_data['other']
input_data = self.op_data["input"]
other_data = self.op_data["other"]
assert input_data.data.dim() > 0 and other_data.data.dim() == 2
assert other_data.logical_shape[0] == input_data.logical_shape[-1]
if self.has_bias:
bias_data = self.op_data['bias']
bias_data = self.op_data["bias"]
assert bias_data.logical_shape[-1] == other_data.logical_shape[-1]
......@@ -757,37 +754,38 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
def _pop_batch_dim_sharding_for_output(self, dim_partition_dict):
# remove partition dict for dim 0
dim_partition_dict['output'].pop(0, None)
dim_partition_dict["output"].pop(0, None)
# decrease the remaining dim index by 1
temp_dim_partition = {}
keys = list(dim_partition_dict['output'].keys())
keys = list(dim_partition_dict["output"].keys())
for key in keys:
val = dim_partition_dict['output'].pop(key)
val = dim_partition_dict["output"].pop(key)
temp_dim_partition[key - 1] = val
dim_partition_dict['output'].update(temp_dim_partition)
dim_partition_dict["output"].update(temp_dim_partition)
def validate(self) -> bool:
input_op_data = self.op_data['input']
other_op_data = self.op_data['other']
input_op_data = self.op_data["input"]
other_op_data = self.op_data["other"]
assert len(input_op_data.logical_shape) == 3 or len(other_op_data.logical_shape) == 3
if 'bias' in self.op_data:
bias_op_data = self.op_data['bias']
if "bias" in self.op_data:
bias_op_data = self.op_data["bias"]
assert bias_op_data.data.dim() < 3 and len(bias_op_data.logical_shape) == 2
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
fwd_compute_cost = self.op_data['input'].data.shape[-1] * reduce(operator.mul,
self.op_data['output'].data.shape)
fwd_compute_cost = self.op_data["input"].data.shape[-1] * reduce(
operator.mul, self.op_data["output"].data.shape
)
bwd_compute_cost = fwd_compute_cost * 2
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
compute_cost = TrainCycleItem(
fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
)
strategy.compute_cost = compute_cost
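# A tiny numeric sketch of the bmm cost rule above (hypothetical shapes, for
# illustration only): for input [B, M, K] and output [B, M, N], fwd = K * (B * M * N).
import operator
from functools import reduce

input_shape = (2, 8, 16)    # [B, M, K]
output_shape = (2, 8, 32)   # [B, M, N]
fwd = input_shape[-1] * reduce(operator.mul, output_shape)  # 16 * 512 = 8192
assert (fwd, 2 * fwd) == (8192, 16384)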
@ignore_sharding_exception
def split_one_batch_dim(self, mesh_dim):
name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
name = f"Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}"
# get sharding_spec
dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "bias": {}, "output": {0: [mesh_dim]}}
......@@ -799,30 +797,27 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
communication_action_mapping = {}
if self.has_bias:
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['bias'] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
arg_index=0,
)
communication_action_mapping["bias"] = bias_comm_action
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1):
name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}'
name = f"Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}"
dim_partition_dict = {
"input": {
0: [mesh_dim_0, mesh_dim_1]
},
"other": {
0: [mesh_dim_0, mesh_dim_1]
},
"input": {0: [mesh_dim_0, mesh_dim_1]},
"other": {0: [mesh_dim_0, mesh_dim_1]},
"bias": {},
"output": {
0: [mesh_dim_0, mesh_dim_1]
}
"output": {0: [mesh_dim_0, mesh_dim_1]},
}
if self.squeeze_batch_dim:
self._pop_batch_dim_sharding_for_output(dim_partition_dict)
......@@ -832,35 +827,28 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
communication_action_mapping = {}
if self.has_bias:
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['bias'] = bias_comm_action
arg_index=0,
)
communication_action_mapping["bias"] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1):
name = f'Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}'
name = f"Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}"
dim_partition_dict = {
"input": {
0: [mesh_dim_0],
1: [mesh_dim_1]
},
"other": {
0: [mesh_dim_0]
},
"bias": {
0: [mesh_dim_1]
},
"output": {
0: [mesh_dim_0],
1: [mesh_dim_1]
}
"input": {0: [mesh_dim_0], 1: [mesh_dim_1]},
"other": {0: [mesh_dim_0]},
"bias": {0: [mesh_dim_1]},
"output": {0: [mesh_dim_0], 1: [mesh_dim_1]},
}
if self.squeeze_batch_dim:
self._pop_batch_dim_sharding_for_output(dim_partition_dict)
......@@ -869,46 +857,40 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
# get communication actions
communication_action_mapping = {}
other_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['other'],
sharding_spec=sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping['other'] = other_comm_action
arg_index=1,
)
communication_action_mapping["other"] = other_comm_action
if self.has_bias:
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['bias'] = bias_comm_action
arg_index=0,
)
communication_action_mapping["bias"] = bias_comm_action
# for the addbmm case, other is the third argument instead of the second.
communication_action_mapping['other'].arg_index += 1
communication_action_mapping["other"].arg_index += 1
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1):
name = f'Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}'
name = f"Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}"
dim_partition_dict = {
"input": {
0: [mesh_dim_0]
},
"other": {
0: [mesh_dim_0],
2: [mesh_dim_1]
},
"bias": {
1: [mesh_dim_1]
},
"output": {
0: [mesh_dim_0],
2: [mesh_dim_1]
}
"input": {0: [mesh_dim_0]},
"other": {0: [mesh_dim_0], 2: [mesh_dim_1]},
"bias": {1: [mesh_dim_1]},
"output": {0: [mesh_dim_0], 2: [mesh_dim_1]},
}
if self.squeeze_batch_dim:
self._pop_batch_dim_sharding_for_output(dim_partition_dict)
......@@ -917,43 +899,41 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
# get communication actions
communication_action_mapping = {}
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['input'],
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['input'] = input_comm_action
arg_index=0,
)
communication_action_mapping["input"] = input_comm_action
if self.has_bias:
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE)
communication_action_mapping['bias'] = bias_comm_action
comm_type=CommType.BEFORE,
)
communication_action_mapping["bias"] = bias_comm_action
# for the addbmm case, input is the second argument instead of the first.
communication_action_mapping['input'].arg_index += 1
communication_action_mapping["input"].arg_index += 1
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1):
name = f'Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}'
name = f"Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}"
dim_partition_dict = {
"input": {
0: [mesh_dim_0],
2: [mesh_dim_1]
},
"other": {
0: [mesh_dim_0],
1: [mesh_dim_1]
},
"input": {0: [mesh_dim_0], 2: [mesh_dim_1]},
"other": {0: [mesh_dim_0], 1: [mesh_dim_1]},
"bias": {},
"output": {
0: [mesh_dim_0],
}
},
}
if self.squeeze_batch_dim:
self._pop_batch_dim_sharding_for_output(dim_partition_dict)
......@@ -962,29 +942,33 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
# get communication actions
communication_action_mapping = {}
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['output'],
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.AFTER)
communication_action_mapping['output'] = output_comm_action
comm_type=CommType.AFTER,
)
communication_action_mapping["output"] = output_comm_action
if self.has_bias:
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['bias'] = bias_comm_action
arg_index=0,
)
communication_action_mapping["bias"] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
device_mesh_is_1d = True
if len(self.device_mesh.mesh_shape) == 2 and 1 not in self.device_mesh.mesh_shape:
if len(self.device_mesh.shape) == 2 and 1 not in self.device_mesh.shape:
device_mesh_is_1d = False
if device_mesh_is_1d:
......@@ -992,10 +976,10 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
# Sb = Sb x Sb
# can be None as it is only for a 1D device mesh
if len(self.device_mesh.mesh_shape) == 1:
if len(self.device_mesh.shape) == 1:
mesh_dim = 0
else:
mesh_dim = self.device_mesh.mesh_shape.index(1)
mesh_dim = self.device_mesh.shape.index(1)
strategy_list.append(self.split_one_batch_dim(mesh_dim))
else:
# for 2D device mesh
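# The 1D-mesh check above can be read as a small predicate; a sketch
# (hypothetical mesh shapes, for illustration only):
def device_mesh_is_1d(mesh_shape):
    return not (len(mesh_shape) == 2 and 1 not in mesh_shape)

assert device_mesh_is_1d((4,))
assert device_mesh_is_1d((4, 1))      # a degenerate 2D mesh counts as 1D
assert not device_mesh_is_1d((2, 2))  # a true 2D mesh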
......
......@@ -17,32 +17,35 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
"""
NormalPoolStrategyGenerator is a generic class to generate strategies for pool operations like MaxPoolxd.
We call this a normal pool because AvgPoolxd and MaxPoolxd take the kernel-size elements from the image,
and reduce them depening on the operation type.
and reduce them depending on the operation type.
"""
def validate(self) -> bool:
'''
"""
In the sanity check, we need to make sure the input data has the correct dimension size.
For Pool1d, the dim of input data should be 3([N, C, L]).
For Pool2d, the dim of input data should be 4([N, C, H, W]).
For Pool3d, the dim of input data should be 5([N, C, H, W, D]).
'''
input_op_data = self.op_data['input']
"""
input_op_data = self.op_data["input"]
assert input_op_data.data.dim() in (
3, 4, 5), f'We suppose the dim of input fed into Pool op should in range of [3, 5].'
3,
4,
5,
), f"We suppose the dim of input fed into Pool op should in range of [3, 5]."
def update_compute_cost(self, strategy: ShardingStrategy) -> TrainCycleItem:
'''
"""
Compute the computation cost per device with this specific strategy.
Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
'''
# TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
Note: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
"""
# TODO: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
# 1D: (Lout) * N * C * kernel
# 2D: (H * W) * N * Cout * Cin * kernel
# 3D: (H * W * D) * N * Cout * Cin * kernel
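# A worked instance of the bookkeeping above (hypothetical shapes, for
# illustration only): output [N, C, H, W] = [2, 16, 8, 8] with a 3x3 kernel,
# i.e. product of the output shape times the kernel-size product.
import operator
from functools import reduce

example_output_shape = (2, 16, 8, 8)
example_kernel_size = (3, 3)
kernel_product = reduce(operator.mul, example_kernel_size)                  # 9
fwd = reduce(operator.mul, example_output_shape) * kernel_product           # 2048 * 9
assert fwd == 18432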
sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device()
sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
kernel_size = self.op_data["other"].data
if isinstance(kernel_size, int):
......@@ -61,8 +64,8 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
"input": self._compute_size_in_bytes(strategy, "input"),
"output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
......@@ -88,12 +91,16 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'
name = (
f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'
)
communication_action_mapping = {}
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
return strategy
......
......@@ -12,7 +12,7 @@ from colossalai.device.device_mesh import DeviceMesh
from .strategy_generator import OutputStrategyGenerator
__all__ = ['OutputGenerator']
__all__ = ["OutputGenerator"]
class OutputGenerator(OutputStrategyGenerator):
......@@ -20,8 +20,13 @@ class OutputGenerator(OutputStrategyGenerator):
OutputGenerator is a generic class to generate strategies for Output Node.
"""
def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
predecessor_nodes: List[Node], output_option: str):
def __init__(
self,
operation_data_mapping: Dict[str, OperationData],
device_mesh: DeviceMesh,
predecessor_nodes: List[Node],
output_option: str,
):
super().__init__(operation_data_mapping, device_mesh, predecessor_nodes)
self.output_option = output_option
......@@ -33,9 +38,9 @@ class OutputGenerator(OutputStrategyGenerator):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
"""
Compute the memory cost per device with this specific strategy.
'''
"""
fwd_mem_cost = MemoryCost(activation=0, parameter=0)
bwd_mem_cost = MemoryCost(activation=0, parameter=0)
......@@ -65,16 +70,18 @@ class OutputGenerator(OutputStrategyGenerator):
else:
dim_partition_dict_for_output = tuple(dim_partition_dict_for_output)
dim_partition_dict_mapping['output'] = dim_partition_dict_for_output
dim_partition_dict_mapping["output"] = dim_partition_dict_for_output
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
name = 'Replica Output'
name = "Replica Output"
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
return strategy
def distributed_strategy(self, mesh_list: List[List[int]] = None) -> List[ShardingStrategy]:
......@@ -82,19 +89,15 @@ class OutputGenerator(OutputStrategyGenerator):
Generate distributed strategy for output node.
"""
# TODO: need to take care of the case when only the first element of the output needs to be sharded.
output_op_data = self.op_data['output']
output_op_data = self.op_data["output"]
if isinstance(output_op_data.data, tuple):
length = len(output_op_data.data)
dim_partition_dict_mapping = {
"output": [{
0: mesh_list
}] * length,
"output": [{0: mesh_list}] * length,
}
else:
dim_partition_dict_mapping = {
"output": {
0: mesh_list
},
"output": {0: mesh_list},
}
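# A sketch of the two mappings built above (hypothetical values, for
# illustration only): a tuple output gets one dim-0 partition dict per element.
mesh_list = [0, 1]
length = 3
tuple_mapping = [{0: mesh_list}] * length
assert tuple_mapping == [{0: [0, 1]}, {0: [0, 1]}, {0: [0, 1]}]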
for index, _ in enumerate(self.predecessor_nodes):
mapping_name = f"input_{index}"
......@@ -103,19 +106,21 @@ class OutputGenerator(OutputStrategyGenerator):
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
name = 'Distributed Output'
name = "Distributed Output"
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
return strategy
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
mesh_list = [0, 1]
if self.output_option == 'replicated':
if self.output_option == "replicated":
strategy_list.append(self.replica_strategy())
elif self.output_option == 'distributed':
elif self.output_option == "distributed":
strategy_list.append(self.distributed_strategy(mesh_list))
return strategy_list
......@@ -10,7 +10,7 @@ from colossalai.device.device_mesh import DeviceMesh
from .strategy_generator import StrategyGenerator
__all__ = ['PlaceholderGenerator']
__all__ = ["PlaceholderGenerator"]
class PlaceholderGenerator(StrategyGenerator):
......@@ -18,8 +18,9 @@ class PlaceholderGenerator(StrategyGenerator):
PlaceholderGenerator is a generic class to generate strategies for placeholder node.
"""
def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
placeholder_option: str):
def __init__(
self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, placeholder_option: str
):
super().__init__(operation_data_mapping, device_mesh)
self.placeholder_option = placeholder_option
......@@ -31,10 +32,10 @@ class PlaceholderGenerator(StrategyGenerator):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
"""
Compute the memory cost per device with this specific strategy.
'''
forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")}
"""
forward_size_mapping = {"output": self._compute_size_in_bytes(strategy, "output")}
# compute fwd cost incurred
# fwd_cost = output
......@@ -58,11 +59,13 @@ class PlaceholderGenerator(StrategyGenerator):
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
name = 'Replica Placeholder'
name = "Replica Placeholder"
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
return strategy
......@@ -71,29 +74,31 @@ class PlaceholderGenerator(StrategyGenerator):
Generate distributed strategy for placeholder node.
"""
dim_partition_dict_mapping = {
"output": {
0: mesh_list
},
"output": {0: mesh_list},
}
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
name = 'Distributed Placeholder'
name = "Distributed Placeholder"
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
return strategy
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
if self.placeholder_option == 'distributed':
if self.placeholder_option == "distributed":
mesh_list = [0, 1]
distributed_strategy = self.distributed_placeholder(mesh_list)
strategy_list.append(distributed_strategy)
else:
assert self.placeholder_option == 'replicated', f'placeholder_option {self.placeholder_option} is not supported'
assert (
self.placeholder_option == "replicated"
), f"placeholder_option {self.placeholder_option} is not supported"
replicated_strategy = self.replica_placeholder()
strategy_list.append(replicated_strategy)
......
......@@ -17,7 +17,7 @@ from colossalai.auto_parallel.tensor_shard.utils import (
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = ['ReshapeGenerator', 'ViewGenerator', 'PermuteGenerator', 'TransposeGenerator', 'SplitGenerator']
__all__ = ["ReshapeGenerator", "ViewGenerator", "PermuteGenerator", "TransposeGenerator", "SplitGenerator"]
class ReshapeGenerator(FollowingStrategyGenerator):
......@@ -33,12 +33,12 @@ class ReshapeGenerator(FollowingStrategyGenerator):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
"""
Compute the memory cost per device with this specific strategy.
'''
"""
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
"input": self._compute_size_in_bytes(strategy, "input"),
"output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
......@@ -56,8 +56,9 @@ class ReshapeGenerator(FollowingStrategyGenerator):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
total_mem_cost = MemoryCost(
activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
......@@ -77,8 +78,8 @@ class ViewGenerator(ReshapeGenerator):
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
origin_shape = self.op_data['input'].data.shape
tgt_shape = self.op_data['tgt_shape'].data
origin_shape = self.op_data["input"].data.shape
tgt_shape = self.op_data["tgt_shape"].data
reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape)
......@@ -86,8 +87,9 @@ class ViewGenerator(ReshapeGenerator):
keep_sharding_status = check_keep_sharding_status(dim_partition_dict_for_input, reshape_mapping_dict)
if keep_sharding_status:
dim_partition_dict_for_output = infer_output_dim_partition_dict(dim_partition_dict_for_input,
reshape_mapping_dict)
dim_partition_dict_for_output = infer_output_dim_partition_dict(
dim_partition_dict_for_input, reshape_mapping_dict
)
else:
dim_partition_dict_for_output = {}
......@@ -119,7 +121,8 @@ class ViewGenerator(ReshapeGenerator):
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=total_mesh_dim_list,
comm_type=CommType.BEFORE,
arg_index=0)
arg_index=0,
)
# it will gather the input through gather_dim during forward phase.
input_comm_action.comm_spec.gather_dim = shard_dim
# it will split the input activation grad through shard_dim during backward phase.
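# To illustrate the GATHER_FWD_SPLIT_BWD action above (shapes are
# hypothetical, for illustration only): a [2, 8] shard on a 4-way mesh axis
# is gathered along shard_dim in forward, and the grad is split back in backward.
local_shape, mesh_axis_size, shard_dim = [2, 8], 4, 0
gathered_shape = list(local_shape)
gathered_shape[shard_dim] *= mesh_axis_size
assert gathered_shape == [8, 8]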
......@@ -127,10 +130,10 @@ class ViewGenerator(ReshapeGenerator):
elif len(total_mesh_dim_list) >= 2:
source_spec = sharding_spec_mapping["input"]
target_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=source_spec.entire_shape,
dim_partition_dict={})
comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
target_spec = ShardingSpec(
device_mesh=self.device_mesh, entire_shape=source_spec.entire_shape, dim_partition_dict={}
)
comm_spec = {"src_spec": source_spec, "tgt_spec": target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
else:
......@@ -139,9 +142,11 @@ class ViewGenerator(ReshapeGenerator):
if input_comm_action is not None:
communication_action_mapping["input"] = input_comm_action
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
strategy_list.append(strategy)
return strategy_list
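As a usage sketch of the reshape-mapping utilities called above (the shapes and shard placements are hypothetical, and this assumes the colossalai utilities behave as they are used in this file):

import torch
from colossalai.auto_parallel.tensor_shard.utils import (
    check_keep_sharding_status,
    detect_reshape_mapping,
    infer_output_dim_partition_dict,
)

# a view from (4, 8, 16) to (32, 16), with tensor dim 0 sharded over mesh axis 0
origin_shape = torch.Size([4, 8, 16])
tgt_shape = torch.Size([32, 16])
dim_partition_dict_for_input = {0: [0]}

reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape)
if check_keep_sharding_status(dim_partition_dict_for_input, reshape_mapping_dict):
    # the shard survives the view, so the output partition is inferred directly
    dim_partition_dict_for_output = infer_output_dim_partition_dict(
        dim_partition_dict_for_input, reshape_mapping_dict
    )
else:
    # otherwise the output is left replicated and a communication action
    # (a gather or a full shape-consistency transform) is attached to the input
    dim_partition_dict_for_output = {}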
......@@ -159,7 +164,7 @@ class PermuteGenerator(ReshapeGenerator):
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
permute_dims = self.op_data['permute_dims'].data
permute_dims = self.op_data["permute_dims"].data
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
dim_partition_dict_for_output = {}
for dim_index, permute_dim in enumerate(permute_dims):
......@@ -177,9 +182,11 @@ class PermuteGenerator(ReshapeGenerator):
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
strategy_list.append(strategy)
return strategy_list
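The dim-partition remapping for permute reduces to plain dictionary bookkeeping; a minimal sketch with hypothetical dims (a dim_partition_dict maps a tensor dim to the device-mesh axes it is sharded over):

# hypothetical x.permute(2, 0, 1) on a 3-D input sharded as {0: [0], 2: [1]}:
# output dim i takes the partition of input dim permute_dims[i]
permute_dims = (2, 0, 1)
dim_partition_dict_for_input = {0: [0], 2: [1]}
dim_partition_dict_for_output = {}
for dim_index, permute_dim in enumerate(permute_dims):
    if permute_dim in dim_partition_dict_for_input:
        dim_partition_dict_for_output[dim_index] = dim_partition_dict_for_input[permute_dim]
# -> {0: [1], 1: [0]}: the old dim-2 shard now lives on output dim 0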
......@@ -199,7 +206,7 @@ class TransposeGenerator(ReshapeGenerator):
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
dim_partition_dict_for_output = {}
transpose_dims = self.op_data['transpose_dims'].data
transpose_dims = self.op_data["transpose_dims"].data
dim_0 = transpose_dims[0]
dim_1 = transpose_dims[1]
for dim, sharded_dims in dim_partition_dict_for_input.items():
......@@ -221,9 +228,11 @@ class TransposeGenerator(ReshapeGenerator):
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
strategy_list.append(strategy)
return strategy_list
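Transpose is the two-dim special case of the same remapping: shards on the two swapped dims trade places while every other dim is untouched. A sketch with hypothetical placements (the loop body is a plausible reconstruction of the elided lines):

# hypothetical x.transpose(0, 1) on an input sharded as {0: [0], 2: [1]}
dim_0, dim_1 = 0, 1
dim_partition_dict_for_input = {0: [0], 2: [1]}
dim_partition_dict_for_output = {}
for dim, sharded_dims in dim_partition_dict_for_input.items():
    if dim == dim_0:
        dim_partition_dict_for_output[dim_1] = sharded_dims
    elif dim == dim_1:
        dim_partition_dict_for_output[dim_0] = sharded_dims
    else:
        dim_partition_dict_for_output[dim] = sharded_dims
# -> {1: [0], 2: [1]}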
......@@ -242,7 +251,7 @@ class SplitGenerator(ReshapeGenerator):
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
split_size, split_dim = self.op_data['split_info'].data
split_size, split_dim = self.op_data["split_info"].data
if split_dim in dim_partition_dict_for_input:
recover_dims = dim_partition_dict_for_input.pop(split_dim)
......@@ -271,7 +280,8 @@ class SplitGenerator(ReshapeGenerator):
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=recover_dims,
comm_type=CommType.BEFORE,
arg_index=0)
arg_index=0,
)
# it will gather the input through gather_dim during forward phase.
input_comm_action.comm_spec.gather_dim = split_dim
# it will split the input activation grad through split_dim during backward phase.
......@@ -282,7 +292,7 @@ class SplitGenerator(ReshapeGenerator):
source_spec = input_sharding_spec
# target sharding spec
target_spec = sharding_spec_mapping["input"]
comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
comm_spec = {"src_spec": source_spec, "tgt_spec": target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
else:
......@@ -291,9 +301,11 @@ class SplitGenerator(ReshapeGenerator):
if input_comm_action is not None:
communication_action_mapping["input"] = input_comm_action
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
strategy_list.append(strategy)
return strategy_list
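For split, the key move above is popping the shard off the split dim and remembering its mesh axes as recover_dims, which then drive the GATHER_FWD_SPLIT_BWD action. A minimal sketch with hypothetical values:

split_info = (2, 0)  # (split_size, split_dim), as packed by SplitHandler
split_size, split_dim = split_info
dim_partition_dict_for_input = {0: [0], 1: [1]}
recover_dims = None
if split_dim in dim_partition_dict_for_input:
    # the split dim cannot stay sharded; these mesh axes are gathered in the
    # forward pass and the activation grad is re-split in the backward pass
    recover_dims = dim_partition_dict_for_input.pop(split_dim)
# -> input spec becomes {1: [1]}, recover_dims == [0]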
......@@ -341,16 +353,17 @@ class DefaultReshapeGenerator(ReshapeGenerator):
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=total_mesh_dim_list,
comm_type=CommType.BEFORE,
arg_index=0)
arg_index=0,
)
input_comm_action.comm_spec.gather_dim = total_mesh_dim_list
input_comm_action.comm_spec.shard_dim = total_mesh_dim_list
elif len(total_mesh_dim_list) >= 2:
source_spec = sharding_spec_mapping["input"]
target_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=source_spec.entire_shape,
dim_partition_dict={})
comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
target_spec = ShardingSpec(
device_mesh=self.device_mesh, entire_shape=source_spec.entire_shape, dim_partition_dict={}
)
comm_spec = {"src_spec": source_spec, "tgt_spec": target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
else:
......@@ -358,9 +371,11 @@ class DefaultReshapeGenerator(ReshapeGenerator):
if input_comm_action is not None:
communication_action_mapping["input"] = input_comm_action
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
strategy_list.append(strategy)
return strategy_list
......@@ -4,21 +4,9 @@ from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import (
check_keep_sharding_status,
detect_reshape_mapping,
infer_output_dim_partition_dict,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
__all__ = ['SoftmaxGenerator']
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
__all__ = ["SoftmaxGenerator"]
class SoftmaxGenerator(FollowingStrategyGenerator):
......@@ -30,11 +18,11 @@ class SoftmaxGenerator(FollowingStrategyGenerator):
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
'''
"""
Compute the computation cost per device with this specific strategy.
'''
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
"""
sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device()
input_size_product = reduce(operator.mul, sharded_input_shape)
output_size_product = reduce(operator.mul, sharded_output_shape)
......@@ -45,12 +33,12 @@ class SoftmaxGenerator(FollowingStrategyGenerator):
strategy.compute_cost = compute_cost
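The compute-cost model here is deliberately coarse: the forward cost scales with the sharded input element count and the backward cost with the sharded output element count (the fwd/bwd split mirrors how SumGenerator below builds its TrainCycleItem). A sketch with hypothetical per-device shapes:

import operator
from functools import reduce

sharded_input_shape = (2, 512)   # hypothetical shape per device
sharded_output_shape = (2, 512)
input_size_product = reduce(operator.mul, sharded_input_shape)    # 1024
output_size_product = reduce(operator.mul, sharded_output_shape)  # 1024
fwd_compute_cost = input_size_product
bwd_compute_cost = output_size_product
total_compute_cost = fwd_compute_cost + bwd_compute_cost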
def update_memory_cost(self, strategy: ShardingStrategy):
'''
"""
Compute the memory cost per device with this specific strategy.
'''
"""
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
"input": self._compute_size_in_bytes(strategy, "input"),
"output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
......@@ -68,8 +56,9 @@ class SoftmaxGenerator(FollowingStrategyGenerator):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
total_mem_cost = MemoryCost(
activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
......@@ -80,10 +69,10 @@ class SoftmaxGenerator(FollowingStrategyGenerator):
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
softmax_dim = self.op_data['softmax_dim'].data
softmax_dim = self.op_data["softmax_dim"].data
if softmax_dim in dim_partition_dict_for_input:
recover_dims = dim_partition_dict_for_input.pop(softmax_dim)
dim_partition_dict_for_input.pop(softmax_dim)
dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input)
dim_partition_dict_mapping = {
......@@ -96,9 +85,11 @@ class SoftmaxGenerator(FollowingStrategyGenerator):
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
strategy_list.append(strategy)
return strategy_list
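Softmax follows the same recovery pattern as split: the softmax dim must be whole on each device, so any shard on it is dropped from both specs and recovered with a gather. A sketch with hypothetical placements:

# softmax over dim 1 of a 2-D input sharded as {0: [0], 1: [1]}
softmax_dim = 1
dim_partition_dict_for_input = {0: [0], 1: [1]}
if softmax_dim in dim_partition_dict_for_input:
    dim_partition_dict_for_input.pop(softmax_dim)
dim_partition_dict_for_output = dict(dim_partition_dict_for_input)
# -> {0: [0]} for both input and output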
......@@ -39,7 +39,7 @@ class StrategyGenerator(ABC):
"""
A utility method to check for the existence of bias operand for convenience.
"""
return 'bias' in self.op_data
return "bias" in self.op_data
def is_param(self, op_data_name):
other_data = self.op_data[op_data_name]
......@@ -49,8 +49,12 @@ class StrategyGenerator(ABC):
other_data = self.op_data[op_data_name]
return other_data.type == OperationDataType.BUFFER
def get_sharding_strategy(self, name: str, sharding_spec_mapping: Dict[str, ShardingSpec],
communication_action_mapping: Dict[str, CommSpec]):
def get_sharding_strategy(
self,
name: str,
sharding_spec_mapping: Dict[str, ShardingSpec],
communication_action_mapping: Dict[str, CommSpec],
):
"""
A factory method to produce a ShardingStrategy object.
......@@ -80,24 +84,28 @@ class StrategyGenerator(ABC):
op_data = self.op_data[op_data_name]
def _to_sharding_spec(
data: any, logical_shape: any,
dim_partition_dict: Dict[int, List[int]]) -> Union[ShardingSpec, List[ShardingSpec], None]:
data: any, logical_shape: any, dim_partition_dict: Dict[int, List[int]]
) -> Union[ShardingSpec, List[ShardingSpec], None]:
"""
This is a recursive function to convert the dim partition dict to a ShardingSpec object.
"""
if isinstance(data, torch.Tensor):
dim_size = len(logical_shape)
dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict)
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=logical_shape,
dim_partition_dict=dim_partition_dict)
sharding_spec = ShardingSpec(
device_mesh=self.device_mesh,
entire_shape=logical_shape,
dim_partition_dict=dim_partition_dict,
)
return sharding_spec
elif isinstance(data, (list, tuple)):
sharding_spec = []
for data_element, logical_shape_element, dim_partition_dict_element in zip(
data, logical_shape, dim_partition_dict):
data, logical_shape, dim_partition_dict
):
sharding_spec.append(
_to_sharding_spec(data_element, logical_shape_element, dim_partition_dict_element))
_to_sharding_spec(data_element, logical_shape_element, dim_partition_dict_element)
)
return sharding_spec
else:
return None
......@@ -116,31 +124,41 @@ class StrategyGenerator(ABC):
results[op_data] = v
return results
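_to_sharding_spec recurses so that tuple outputs (for example the chunks returned by split) get a list of specs, one per element. A simplified, self-contained stand-in for the recursion (the real function builds ShardingSpec objects against the device mesh):

def to_spec(data, logical_shape, dim_partition_dict):
    # lists/tuples recurse elementwise; tensors map to a single spec; anything
    # else (ints, None, ...) carries no sharding and maps to None
    if isinstance(data, (list, tuple)):
        return [
            to_spec(d, s, p)
            for d, s, p in zip(data, logical_shape, dim_partition_dict)
        ]
    if hasattr(data, "shape"):  # stands in for isinstance(data, torch.Tensor)
        return ("ShardingSpec", tuple(logical_shape), dict(dim_partition_dict))
    return None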
def get_communication_spec(self, sharding_spec: ShardingSpec, communication_pattern: CollectiveCommPattern,
logical_process_axis: Union[int, List[int]]):
def get_communication_spec(
self,
sharding_spec: ShardingSpec,
communication_pattern: CollectiveCommPattern,
logical_process_axis: Union[int, List[int]],
):
"""
A factory method to produce a CommSpec object.
"""
return CommSpec(comm_pattern=communication_pattern,
sharding_spec=sharding_spec,
logical_process_axis=logical_process_axis)
def get_communication_action(self,
sharding_spec: ShardingSpec,
communication_pattern: CollectiveCommPattern,
logical_process_axis: Union[int, List[int]],
comm_type: CommType,
arg_index: int = -1,
key_for_kwarg: any = None) -> CommAction:
return CommSpec(
comm_pattern=communication_pattern, sharding_spec=sharding_spec, logical_process_axis=logical_process_axis
)
def get_communication_action(
self,
sharding_spec: ShardingSpec,
communication_pattern: CollectiveCommPattern,
logical_process_axis: Union[int, List[int]],
comm_type: CommType,
arg_index: int = -1,
key_for_kwarg: any = None,
) -> CommAction:
"""
A factory method to produce a CommAction object.
"""
return CommAction(comm_spec=self.get_communication_spec(sharding_spec=sharding_spec,
communication_pattern=communication_pattern,
logical_process_axis=logical_process_axis),
comm_type=comm_type,
arg_index=arg_index,
key_for_kwarg=key_for_kwarg)
return CommAction(
comm_spec=self.get_communication_spec(
sharding_spec=sharding_spec,
communication_pattern=communication_pattern,
logical_process_axis=logical_process_axis,
),
comm_type=comm_type,
arg_index=arg_index,
key_for_kwarg=key_for_kwarg,
)
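How these factories compose can be shown with simplified stand-ins for CommSpec, CommAction and CommType (hypothetical mirrors of the real classes, just to show the shape of the object the generators mutate afterwards):

from dataclasses import dataclass
from enum import Enum
from typing import List, Optional

class CommType(Enum):
    BEFORE = 0
    AFTER = 1

@dataclass
class CommSpec:
    comm_pattern: str
    logical_process_axis: List[int]
    gather_dim: Optional[int] = None
    shard_dim: Optional[int] = None

@dataclass
class CommAction:
    comm_spec: CommSpec
    comm_type: CommType
    arg_index: int = -1

# gather the first positional arg before the op runs; its activation grad is
# re-split along the same dim in the backward pass
action = CommAction(
    comm_spec=CommSpec("GATHER_FWD_SPLIT_BWD", logical_process_axis=[0]),
    comm_type=CommType.BEFORE,
    arg_index=0,
)
action.comm_spec.gather_dim = 0
action.comm_spec.shard_dim = 0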
def update_communication_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
"""
......@@ -155,9 +173,9 @@ class StrategyGenerator(ABC):
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
for phase, cost in num_ele_in_comm.items():
num_ele_in_comm[phase] = num_ele_in_comm[phase] * size_per_elem_bytes
comm_cost.fwd += num_ele_in_comm['forward']
comm_cost.bwd += num_ele_in_comm['backward']
comm_cost.total += num_ele_in_comm['total']
comm_cost.fwd += num_ele_in_comm["forward"]
comm_cost.bwd += num_ele_in_comm["backward"]
comm_cost.total += num_ele_in_comm["total"]
# check if communication action exists
# if so, loop over each action and compute the cost of each action
......@@ -169,8 +187,8 @@ class StrategyGenerator(ABC):
# this condition branch will be removed after all the handler updated.
comm_spec = comm_action
if isinstance(comm_spec, dict):
src_spec = comm_spec['src_spec']
tgt_spec = comm_spec['tgt_spec']
src_spec = comm_spec["src_spec"]
tgt_spec = comm_spec["tgt_spec"]
shape_consistency_manager = ShapeConsistencyManager()
_, comm_action_sequence, _ = shape_consistency_manager.shape_consistency(src_spec, tgt_spec)
for comm_spec_ in comm_action_sequence:
......@@ -187,14 +205,12 @@ class StrategyGenerator(ABC):
"""
Customize this method to compute the computation flops.
"""
pass
@abstractmethod
def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
"""
Customize this method to compute the memory cost in bytes.
"""
pass
def _compute_size_in_bytes(self, strategy: ShardingStrategy, key: str):
"""
......@@ -212,20 +228,21 @@ class StrategyGenerator(ABC):
num_elements = 1
else:
num_elements = reduce(operator.mul, sharded_shape)
dtype = getattr(meta_data, 'dtype')
dtype = getattr(meta_data, "dtype")
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
return num_elements * size_per_elem_bytes
if isinstance(op_data.data, tuple):
assert isinstance(strategy.sharding_specs[op_data], list), \
'sharding_spec of op_data should be a list of sharding specs if op_data.data is a tuple.'
assert isinstance(
strategy.sharding_specs[op_data], list
), "sharding_spec of op_data should be a list of sharding specs if op_data.data is a tuple."
total_bytes = 0
for index, sharding_spec in enumerate(strategy.sharding_specs[op_data]):
meta_data = op_data.data[index]
if isinstance(meta_data, torch.Tensor):
element_bytes = _compute_size_in_bytes_helper(sharding_spec, meta_data)
else:
# if meta_data is not a tensor, we count the memroy as 0
# if meta_data is not a tensor, we count the memory as 0
element_bytes = 0
total_bytes += element_bytes
......@@ -233,7 +250,7 @@ class StrategyGenerator(ABC):
if isinstance(op_data.data, torch.Tensor):
total_bytes = _compute_size_in_bytes_helper(strategy.sharding_specs[op_data], op_data.data)
else:
# if op_data.data is not a tensor, we count the memroy as 0
# if op_data.data is not a tensor, we count the memory as 0
total_bytes = 0
return total_bytes
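The byte accounting bottoms out in the helper above; the element size is read off an empty tensor of the right dtype. A self-contained sketch with a hypothetical fp16 shard:

import operator
from functools import reduce

import torch

sharded_shape = (2, 512)  # hypothetical shape per device
dtype = torch.float16
num_elements = reduce(operator.mul, sharded_shape)                  # 1024
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()  # 2
total_bytes = num_elements * size_per_elem_bytes                    # 2048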
......@@ -270,7 +287,6 @@ class StrategyGenerator(ABC):
Validate if the operands are of desired shape.
If True, means this generator can be used for the current operation.
"""
pass
class FollowingStrategyGenerator(StrategyGenerator):
......@@ -280,8 +296,9 @@ class FollowingStrategyGenerator(StrategyGenerator):
TODO: remove the original strategy_generator.py after refactoring
"""
def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
predecessor_node: Node):
def __init__(
self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, predecessor_node: Node
):
self.op_data = operation_data_mapping
self.device_mesh = device_mesh
self.predecessor_node = predecessor_node
......@@ -292,7 +309,8 @@ class OutputStrategyGenerator(StrategyGenerator):
OutputStrategyGenerator is used to generate the sharding strategies for Output Node.
"""
def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
predecessor_nodes: List[Node]):
def __init__(
self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, predecessor_nodes: List[Node]
):
super().__init__(operation_data_mapping, device_mesh)
self.predecessor_nodes = predecessor_nodes
......@@ -4,22 +4,9 @@ from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import (
check_keep_sharding_status,
detect_reshape_mapping,
infer_output_dim_partition_dict,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = ['SumGenerator']
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
__all__ = ["SumGenerator"]
class SumGenerator(FollowingStrategyGenerator):
......@@ -31,24 +18,24 @@ class SumGenerator(FollowingStrategyGenerator):
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device()
input_size_product = reduce(operator.mul, sharded_input_shape)
output_size_product = reduce(operator.mul, sharded_output_shape)
compute_cost = TrainCycleItem(fwd=input_size_product,
bwd=output_size_product,
total=input_size_product + output_size_product)
compute_cost = TrainCycleItem(
fwd=input_size_product, bwd=output_size_product, total=input_size_product + output_size_product
)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
"""
Compute the memory cost per device with this specific strategy.
'''
"""
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
"input": self._compute_size_in_bytes(strategy, "input"),
"output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
......@@ -66,8 +53,9 @@ class SumGenerator(FollowingStrategyGenerator):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
total_mem_cost = MemoryCost(
activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
......@@ -78,7 +66,7 @@ class SumGenerator(FollowingStrategyGenerator):
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
sum_dims, sum_mapping_dict = self.op_data['sum_info'].data
sum_dims, sum_mapping_dict = self.op_data["sum_info"].data
# TODO: a better way to handle the distributed sum is sum all the data on chip and then do all reduce
# among all the shard groups
......@@ -90,7 +78,7 @@ class SumGenerator(FollowingStrategyGenerator):
elif dim in sum_mapping_dict:
dim_partition_dict_for_output[sum_mapping_dict[dim]] = dim_partition_dict_for_input[dim]
else:
raise RuntimeError(f'dim {dim} is not in sum_mapping_dict or sum_dims')
raise RuntimeError(f"dim {dim} is not in sum_mapping_dict or sum_dims")
for dim in recover_dims:
dim_partition_dict_for_input.pop(dim)
......@@ -105,9 +93,11 @@ class SumGenerator(FollowingStrategyGenerator):
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
strategy_list.append(strategy)
return strategy_list
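The sum_info unpacked above carries which dims are reduced away and how the surviving dims renumber; shards on reduced dims become recover_dims. A sketch for a hypothetical torch.sum(x, dim=1) on a 3-D input:

sum_dims, sum_mapping_dict = (1,), {0: 0, 2: 1}  # packed as sum_info by the handler
dim_partition_dict_for_input = {0: [0], 1: [1]}  # dim 1 (the reduced dim) is sharded
dim_partition_dict_for_output = {}
recover_dims = []
for dim in dim_partition_dict_for_input:
    if dim in sum_dims:
        recover_dims.append(dim)  # a shard on a reduced dim must be gathered first
    elif dim in sum_mapping_dict:
        dim_partition_dict_for_output[sum_mapping_dict[dim]] = dim_partition_dict_for_input[dim]
    else:
        raise RuntimeError(f"dim {dim} is not in sum_mapping_dict or sum_dims")
for dim in recover_dims:
    dim_partition_dict_for_input.pop(dim)
# -> output keeps the dim-0 shard as {0: [0]}; input's dim-1 shard is recovered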
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from .strategy_generator import StrategyGenerator
__all__ = ['TensorConstructorGenerator']
__all__ = ["TensorConstructorGenerator"]
class TensorConstructorGenerator(StrategyGenerator):
......@@ -30,10 +21,10 @@ class TensorConstructorGenerator(StrategyGenerator):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
"""
Compute the memory cost per device with this specific strategy.
'''
forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")}
"""
forward_size_mapping = {"output": self._compute_size_in_bytes(strategy, "output")}
# compute fwd cost incurred
# fwd_cost = input + output
......@@ -57,11 +48,13 @@ class TensorConstructorGenerator(StrategyGenerator):
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
name = 'Replica Tensor Constructor'
name = "Replica Tensor Constructor"
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
strategy_list.append(strategy)
return strategy_list
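A tensor-constructor node (e.g. torch.zeros) has no input operand to follow, so the only strategy emitted keeps the output fully replicated; a sketch of the mapping it builds:

# an empty dim_partition_dict means no tensor dim is sharded on any mesh axis,
# i.e. the output is replicated ("R") on every device
dim_partition_dict_mapping = {"output": {}}
communication_action_mapping = {}  # nothing to gather or scatter
name = "Replica Tensor Constructor"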