"vllm/vscode:/vscode.git/clone" did not exist on "42b06117ddc8eba95f978a10990659670121b488"
Unverified commit 079bf3cb, authored by Hongxin Liu and committed by GitHub
Browse files

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
parent 3c6b831c
......@@ -7,7 +7,7 @@ from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import SoftmaxGenerator, StrategyGenerator
__all__ = ['SoftmaxHandler']
__all__ = ["SoftmaxHandler"]
@operator_registry.register(torch.nn.Softmax)
......@@ -34,14 +34,14 @@ class SoftmaxHandler(NodeHandler):
input_data = self.node.args[0]._meta_data
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
softmax_dim = self.node.kwargs['dim']
softmax_dim = self.node.kwargs["dim"]
num_dims = self.node.args[0]._meta_data.dim()
# recover negative value to positive
if softmax_dim < 0:
softmax_dim += num_dims
physical_dim_operand = OperationData(name='softmax_dim', type=OperationDataType.ARG, data=softmax_dim)
physical_dim_operand = OperationData(name="softmax_dim", type=OperationDataType.ARG, data=softmax_dim)
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
......@@ -49,7 +49,7 @@ class SoftmaxHandler(NodeHandler):
mapping = {
"input": physical_input_operand,
"softmax_dim": physical_dim_operand,
"output": physical_output_operand
"output": physical_output_operand,
}
return mapping
......@@ -7,7 +7,7 @@ from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import SplitGenerator, StrategyGenerator
__all__ = ['SplitHandler']
__all__ = ["SplitHandler"]
@operator_registry.register(torch.Tensor.split)
......@@ -38,7 +38,7 @@ class SplitHandler(NodeHandler):
split_dim = self.node.args[2]
else:
if self.node.kwargs:
split_dim = self.node.kwargs['dim']
split_dim = self.node.kwargs["dim"]
else:
split_dim = 0
......@@ -48,7 +48,7 @@ class SplitHandler(NodeHandler):
split_dim += num_dims
split_info = (split_size, split_dim)
physical_shape_operand = OperationData(name='split_info', type=OperationDataType.ARG, data=split_info)
physical_shape_operand = OperationData(name="split_info", type=OperationDataType.ARG, data=split_info)
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
......@@ -56,7 +56,7 @@ class SplitHandler(NodeHandler):
mapping = {
"input": physical_input_operand,
"split_info": physical_shape_operand,
"output": physical_output_operand
"output": physical_output_operand,
}
return mapping
......@@ -29,11 +29,31 @@ from .unary_elementwise_generator import UnaryElementwiseGenerator
from .where_generator import WhereGenerator
__all__ = [
'StrategyGenerator', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator', 'LinearProjectionStrategyGenerator',
'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator', 'UnaryElementwiseGenerator',
'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator',
'LayerNormGenerator', 'PlaceholderGenerator', 'OutputGenerator', 'WhereGenerator', 'NormalPoolStrategyGenerator',
'BinaryElementwiseStrategyGenerator', 'GetattrGenerator', 'TensorConstructorGenerator',
'EmbeddingStrategyGenerator', 'SumGenerator', 'SoftmaxGenerator', 'ViewGenerator', 'PermuteGenerator',
'TransposeGenerator', 'SplitGenerator', 'DefaultReshapeGenerator'
"StrategyGenerator",
"DotProductStrategyGenerator",
"MatVecStrategyGenerator",
"LinearProjectionStrategyGenerator",
"BatchedMatMulStrategyGenerator",
"ConvStrategyGenerator",
"UnaryElementwiseGenerator",
"BatchNormStrategyGenerator",
"GetItemStrategyGenerator",
"TensorStrategyGenerator",
"TensorTupleStrategyGenerator",
"LayerNormGenerator",
"PlaceholderGenerator",
"OutputGenerator",
"WhereGenerator",
"NormalPoolStrategyGenerator",
"BinaryElementwiseStrategyGenerator",
"GetattrGenerator",
"TensorConstructorGenerator",
"EmbeddingStrategyGenerator",
"SumGenerator",
"SoftmaxGenerator",
"ViewGenerator",
"PermuteGenerator",
"TransposeGenerator",
"SplitGenerator",
"DefaultReshapeGenerator",
]
......@@ -14,7 +14,7 @@ from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
__all__ = ['BatchNormStrategyGenerator']
__all__ = ["BatchNormStrategyGenerator"]
class BatchNormStrategyGenerator(StrategyGenerator):
......@@ -30,28 +30,31 @@ class BatchNormStrategyGenerator(StrategyGenerator):
"""
def validate(self) -> bool:
'''
"""
In sanity check, we need make sure the input data having correct dimension size.
For BatchNorm1d, the dim of input data should be 3([N, C, L]).
For BatchNorm2d, the dim of input data should be 4([N, C, H, W]).
For BatchNorm3d, the dim of input data should be 5([N, C, H, W, D]).
'''
input_op_data = self.op_data['input']
"""
input_op_data = self.op_data["input"]
assert input_op_data.data.dim() in (
3, 4, 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'
3,
4,
5,
), f"We suppose the dim of input fed into conv op should in range of [3, 5]."
def update_compute_cost(self, strategy: ShardingStrategy):
'''
"""
Compute the computation cost per device with this specific strategy.
Note: compute_cost need to be divided by TFLOPS, now it just shows the computation size.
'''
"""
# TODO: a constant coefficient need to be added.
# 1D: (L) * N * Cin
# 2D: (H * W) * N * Cin
# 3D: (H * W * D) * N * Cin
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device()
if self.has_bias:
# bias add is an element wise operation, so the cost is equal to product of output shape.
bias_compute_cost = reduce(operator.mul, sharded_output_shape)
......@@ -69,23 +72,24 @@ class BatchNormStrategyGenerator(StrategyGenerator):
def update_memory_cost(self, strategy: ShardingStrategy):
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'other': self._compute_size_in_bytes(strategy, "other"),
'output': self._compute_size_in_bytes(strategy, "output"),
'running_mean': self._compute_size_in_bytes(strategy, "running_mean"),
'running_var': self._compute_size_in_bytes(strategy, "running_var"),
"input": self._compute_size_in_bytes(strategy, "input"),
"other": self._compute_size_in_bytes(strategy, "other"),
"output": self._compute_size_in_bytes(strategy, "output"),
"running_mean": self._compute_size_in_bytes(strategy, "running_mean"),
"running_var": self._compute_size_in_bytes(strategy, "running_var"),
}
if self.has_bias:
bias_size = self._compute_size_in_bytes(strategy, "bias")
forward_size_mapping['bias'] = bias_size
forward_size_mapping["bias"] = bias_size
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
# compute fwd cost incurred
# fwd_cost = input + other + bias + output
fwd_activation_cost = sum(
[v for k, v in forward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)])
[v for k, v in forward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)]
)
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
fwd_buffer_cost = sum([v for k, v in forward_size_mapping.items() if self.is_buffer(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost, buffer=fwd_buffer_cost)
......@@ -93,36 +97,29 @@ class BatchNormStrategyGenerator(StrategyGenerator):
# compute bwd cost incurred
# bwd_cost = input_grad + other_grad + bias_grad
bwd_activation_cost = sum(
[v for k, v in backward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)])
[v for k, v in backward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)]
)
bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost,
buffer=fwd_buffer_cost)
total_mem_cost = MemoryCost(
activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost,
buffer=fwd_buffer_cost,
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@ignore_sharding_exception
def split_input_channel(self, mesh_dim_0):
name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
name = f"RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}"
dim_partition_dict_mapping = {
"input": {
1: [mesh_dim_0]
},
"other": {
0: [mesh_dim_0]
},
"output": {
1: [mesh_dim_0]
},
"running_mean": {
0: [mesh_dim_0]
},
"running_var": {
0: [mesh_dim_0]
},
"input": {1: [mesh_dim_0]},
"other": {0: [mesh_dim_0]},
"output": {1: [mesh_dim_0]},
"running_mean": {0: [mesh_dim_0]},
"running_var": {0: [mesh_dim_0]},
"num_batches_tracked": {},
}
if self.has_bias:
......@@ -132,29 +129,21 @@ class BatchNormStrategyGenerator(StrategyGenerator):
communication_action_mapping = {}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'
name = f"RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}"
dim_partition_dict_mapping = {
"input": {
1: [mesh_dim_0, mesh_dim_1]
},
"other": {
0: [mesh_dim_0, mesh_dim_1]
},
"output": {
1: [mesh_dim_0, mesh_dim_1]
},
"running_mean": {
0: [mesh_dim_0, mesh_dim_1]
},
"running_var": {
0: [mesh_dim_0, mesh_dim_1]
},
"input": {1: [mesh_dim_0, mesh_dim_1]},
"other": {0: [mesh_dim_0, mesh_dim_1]},
"output": {1: [mesh_dim_0, mesh_dim_1]},
"running_mean": {0: [mesh_dim_0, mesh_dim_1]},
"running_var": {0: [mesh_dim_0, mesh_dim_1]},
"num_batches_tracked": {},
}
if self.has_bias:
......@@ -164,13 +153,15 @@ class BatchNormStrategyGenerator(StrategyGenerator):
communication_action_mapping = {}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def non_split(self):
name = f'RR = RR x R'
name = f"RR = RR x R"
dim_partition_dict_mapping = {
"input": {},
"other": {},
......@@ -186,21 +177,19 @@ class BatchNormStrategyGenerator(StrategyGenerator):
communication_action_mapping = {}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_input_batch(self, mesh_dim_0):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'
name = f"S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN"
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0]
},
"input": {0: [mesh_dim_0]},
"other": {},
"output": {
0: [mesh_dim_0]
},
"output": {0: [mesh_dim_0]},
"running_mean": {},
"running_var": {},
"num_batches_tracked": {},
......@@ -218,27 +207,26 @@ class BatchNormStrategyGenerator(StrategyGenerator):
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.IMPLICIT)
comm_type=CommType.IMPLICIT,
)
# TODO: Temporary solution has no communication cost,
# above action should be added after the SyncBN replace pass completed.
communication_action_mapping = {}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'
name = f"S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN"
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0, mesh_dim_1]
},
"input": {0: [mesh_dim_0, mesh_dim_1]},
"other": {},
"output": {
0: [mesh_dim_0, mesh_dim_1]
},
"output": {0: [mesh_dim_0, mesh_dim_1]},
"running_mean": {},
"running_var": {},
"num_batches_tracked": {},
......@@ -256,19 +244,22 @@ class BatchNormStrategyGenerator(StrategyGenerator):
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.IMPLICIT)
comm_type=CommType.IMPLICIT,
)
# TODO: Temporary solution has no communication cost,
# above action should be added after the SyncBN replace pass completed.
communication_action_mapping = {}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'
name = f"S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN"
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0],
......@@ -304,20 +295,23 @@ class BatchNormStrategyGenerator(StrategyGenerator):
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0],
comm_type=CommType.IMPLICIT)
comm_type=CommType.IMPLICIT,
)
# TODO: Temporary solution has no communication cost,
# above action should be added after the SyncBN replace pass completed.
communication_action_mapping = {}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
def collate_strategies(self) -> List[ShardingStrategy]:
'''
"""
Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector.
'''
"""
strategy_list = []
# RS = RS x S
......
......@@ -14,7 +14,7 @@ from colossalai.tensor.sharding_spec import ShardingSpecException
from .strategy_generator import StrategyGenerator
__all__ = ['BinaryElementwiseStrategyGenerator']
__all__ = ["BinaryElementwiseStrategyGenerator"]
class BinaryElementwiseStrategyGenerator(StrategyGenerator):
......@@ -26,36 +26,37 @@ class BinaryElementwiseStrategyGenerator(StrategyGenerator):
"""
def validate(self) -> bool:
assert len(self.op_data) == 3, \
f'BinaryElementwiseStrategyGenerator only accepts three operation data (input, other and output), but got {len(self.op_data)}'
assert (
len(self.op_data) == 3
), f"BinaryElementwiseStrategyGenerator only accepts three operation data (input, other and output), but got {len(self.op_data)}"
for name, op_data in self.op_data.items():
if not isinstance(op_data.data, (torch.Tensor, int, float)):
raise TypeError(f'The operation data {name} is not a torch.Tensor/int/float.')
raise TypeError(f"The operation data {name} is not a torch.Tensor/int/float.")
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
# since elementwise ops are not compute-intensive,
# we approximate the backward compute cost
# to be twice the fwd compute cost
fwd_compute_cost = reduce(operator.mul, shape)
bwd_compute_cost = fwd_compute_cost * 2
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
compute_cost = TrainCycleItem(
fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
# all input, output and outputs have the same shape
shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
# compute fwd memory cost in bytes
# as the elementwise ops are not memory-intensive
# we approximate the fwd memory cost to be the output
# and the backward memory cost to be grad of input and other
input_bytes = self._compute_size_in_bytes(strategy, 'input')
other_bytes = self._compute_size_in_bytes(strategy, 'other')
output_bytes = self._compute_size_in_bytes(strategy, 'output')
input_bytes = self._compute_size_in_bytes(strategy, "input")
other_bytes = self._compute_size_in_bytes(strategy, "other")
output_bytes = self._compute_size_in_bytes(strategy, "output")
fwd_memory_cost = MemoryCost(activation=output_bytes)
bwd_memory_cost = MemoryCost(activation=input_bytes + other_bytes)
total_memory_cost = MemoryCost(activation=input_bytes + other_bytes + output_bytes)
......@@ -66,7 +67,7 @@ class BinaryElementwiseStrategyGenerator(StrategyGenerator):
def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
# we check for the output logical shape to get the number of dimensions
dim_partition_list = []
dim_size = len(self.op_data['output'].logical_shape)
dim_size = len(self.op_data["output"].logical_shape)
# enumerate all the 2D sharding cases
sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
......@@ -86,21 +87,22 @@ class BinaryElementwiseStrategyGenerator(StrategyGenerator):
# convert these dim partition dict to sharding strategy
for dim_partition_dict in dim_partition_list:
dim_partition_dict_mapping = dict(input=dim_partition_dict,
other=dim_partition_dict,
output=dim_partition_dict)
dim_partition_dict_mapping = dict(
input=dim_partition_dict, other=dim_partition_dict, output=dim_partition_dict
)
try:
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
communication_action_mapping = {}
# get name
sharding_seq = sharding_spec_mapping['input'].sharding_sequence
name = f'{sharding_seq} = {sharding_seq} <binary-elementwise-op> {sharding_seq}'
sharding_seq = sharding_spec_mapping["input"].sharding_sequence
name = f"{sharding_seq} = {sharding_seq} <binary-elementwise-op> {sharding_seq}"
sharding_strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
communication_action_mapping=communication_action_mapping,
)
strategy_list.append(sharding_strategy)
except ShardingSpecException:
continue
......
import copy
import operator
import warnings
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
MemoryCost,
ShardingStrategy,
......@@ -24,29 +22,32 @@ class ConvStrategyGenerator(StrategyGenerator):
"""
def validate(self) -> bool:
'''
"""
In sanity check, we need make sure the input data having correct dimension size.
For Conv1d, the dim of input data should be 3([N, C, L]).
For Conv2d, the dim of input data should be 4([N, C, H, W]).
For Conv3d, the dim of input data should be 5([N, C, H, W, D]).
'''
input_op_data = self.op_data['input']
"""
input_op_data = self.op_data["input"]
assert input_op_data.data.dim() in (
3, 4, 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'
3,
4,
5,
), f"We suppose the dim of input fed into conv op should in range of [3, 5]."
def update_compute_cost(self, strategy: ShardingStrategy):
'''
"""
Compute the computation cost per device with this specific strategy.
Note: compute_cost need to be divided by TFLOPS, now it just shows the computation size.
'''
"""
# TODO: compute_cost need to be divided by TFLOPS, now it just shows the computation size.
# 1D: (L) * N * Cout * Cin * kernel
# 2D: (H * W) * N * Cout * Cin * kernel
# 3D: (H * W * D) * N * Cout * Cin * kernel
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device()
sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
sharded_other_shape = strategy.sharding_specs[self.op_data["other"]].get_sharded_shape_per_device()
sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device()
if self.has_bias:
# bias add is an element wise operation, so the cost is equal to product of output shape.
bias_compute_cost = reduce(operator.mul, sharded_output_shape)
......@@ -76,14 +77,14 @@ class ConvStrategyGenerator(StrategyGenerator):
def update_memory_cost(self, strategy: ShardingStrategy):
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'other': self._compute_size_in_bytes(strategy, "other"),
'output': self._compute_size_in_bytes(strategy, "output")
"input": self._compute_size_in_bytes(strategy, "input"),
"other": self._compute_size_in_bytes(strategy, "other"),
"output": self._compute_size_in_bytes(strategy, "output"),
}
if self.has_bias:
bias_size = self._compute_size_in_bytes(strategy, "bias")
forward_size_mapping['bias'] = bias_size
forward_size_mapping["bias"] = bias_size
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
......@@ -100,26 +101,20 @@ class ConvStrategyGenerator(StrategyGenerator):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
total_mem_cost = MemoryCost(
activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@ignore_sharding_exception
def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
name = f"S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}"
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0]
},
"other": {
1: [mesh_dim_1]
},
"output": {
0: [mesh_dim_0],
1: [mesh_dim_1]
},
"input": {0: [mesh_dim_0]},
"other": {1: [mesh_dim_1]},
"output": {0: [mesh_dim_0], 1: [mesh_dim_1]},
}
if self.has_bias:
dim_partition_dict_mapping["bias"] = {0: [mesh_dim_1]}
......@@ -132,7 +127,8 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
arg_index=0)
arg_index=0,
)
communication_action_mapping = {"input": input_comm_action}
if self.is_param("other"):
......@@ -140,7 +136,8 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
comm_type=CommType.HOOK,
)
else:
other_comm_action = self.get_communication_action(
......@@ -148,38 +145,41 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=1)
arg_index=1,
)
communication_action_mapping["other"] = other_comm_action
if self.has_bias:
if self.is_param('bias'):
if self.is_param("bias"):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
comm_type=CommType.HOOK,
)
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
key_for_kwarg='bias')
key_for_kwarg="bias",
)
communication_action_mapping["bias"] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_input_batch(self, mesh_dim_0):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'
name = f"S{mesh_dim_0}R = S{mesh_dim_0}R x RR"
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0]
},
"input": {0: [mesh_dim_0]},
"other": {},
"output": {
0: [mesh_dim_0],
......@@ -196,7 +196,8 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
comm_type=CommType.HOOK,
)
else:
other_comm_action = self.get_communication_action(
......@@ -204,42 +205,45 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=1)
arg_index=1,
)
communication_action_mapping["other"] = other_comm_action
if self.has_bias:
if self.is_param('bias'):
if self.is_param("bias"):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
comm_type=CommType.HOOK,
)
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
key_for_kwarg='bias')
key_for_kwarg="bias",
)
communication_action_mapping["bias"] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
name = f"S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R"
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0],
1: [mesh_dim_1],
},
"other": {
0: [mesh_dim_1]
},
"other": {0: [mesh_dim_1]},
"output": {
0: [mesh_dim_0],
},
......@@ -254,7 +258,8 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.AFTER)
comm_type=CommType.AFTER,
)
communication_action_mapping = {"output": output_comm_action}
......@@ -263,7 +268,8 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
comm_type=CommType.HOOK,
)
else:
other_comm_action = self.get_communication_action(
......@@ -271,7 +277,8 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=1)
arg_index=1,
)
communication_action_mapping["other"] = other_comm_action
if self.has_bias:
if self.is_param("bias"):
......@@ -279,23 +286,27 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
comm_type=CommType.HOOK,
)
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
key_for_kwarg='bias')
key_for_kwarg="bias",
)
communication_action_mapping["bias"] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
name = f"RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}"
dim_partition_dict_mapping = {
"input": {
......@@ -322,23 +333,27 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.AFTER)
comm_type=CommType.AFTER,
)
input_comm_action = self.get_communication_action(
sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
arg_index=0)
arg_index=0,
)
communication_action_mapping = {"output": output_comm_action, "input": input_comm_action}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'
name = f"RR = RS{mesh_dim_0} x S{mesh_dim_0}R"
dim_partition_dict_mapping = {
"input": {
......@@ -360,17 +375,20 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.AFTER)
comm_type=CommType.AFTER,
)
communication_action_mapping = {"output": output_comm_action}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_weight_out_channel(self, mesh_dim_0):
name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'
name = f"RS{mesh_dim_0} = RR x RS{mesh_dim_0}"
dim_partition_dict_mapping = {
"input": {},
......@@ -395,17 +413,20 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=0)
arg_index=0,
)
communication_action_mapping = {"input": input_comm_action}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def non_split(self):
name = f'RR = RR x RR'
name = f"RR = RR x RR"
dim_partition_dict_mapping = {
"input": {},
......@@ -418,13 +439,13 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
return self.get_sharding_strategy(
name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping={}
)
@ignore_sharding_exception
def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
name = f"S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR"
dim_partition_dict_mapping = {
"input": {
......@@ -447,14 +468,16 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.HOOK)
comm_type=CommType.HOOK,
)
else:
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
arg_index=1)
arg_index=1,
)
communication_action_mapping["other"] = other_comm_action
......@@ -464,23 +487,27 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.HOOK)
comm_type=CommType.HOOK,
)
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
key_for_kwarg='bias')
key_for_kwarg="bias",
)
communication_action_mapping["bias"] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
name = f"RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R"
dim_partition_dict_mapping = {
"input": {
1: [mesh_dim_0, mesh_dim_1],
......@@ -501,17 +528,20 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.AFTER)
comm_type=CommType.AFTER,
)
communication_action_mapping = {"output": output_comm_action}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_1d_parallel_on_out_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
name = f"RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}"
dim_partition_dict_mapping = {
"input": {},
"other": {
......@@ -535,13 +565,16 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
arg_index=0)
arg_index=0,
)
communication_action_mapping = {"input": input_comm_action}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
def collate_strategies(self) -> List[ShardingStrategy]:
strategies = []
......
import copy
import operator
import warnings
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
MemoryCost,
ShardingStrategy,
......@@ -27,16 +25,16 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
'''
"""
Compute the computation cost per device with this specific strategy.
Note: The computation cost for the embedding handler is estimated as dense computing now.
It may not be accurate.
'''
"""
# TODO: estimate the embedding computation cost as sparse operation
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device()
sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
sharded_other_shape = strategy.sharding_specs[self.op_data["other"]].get_sharded_shape_per_device()
sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device()
input_size_product = reduce(operator.mul, sharded_input_shape)
other_size_product = reduce(operator.mul, sharded_other_shape)
......@@ -55,9 +53,9 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
def update_memory_cost(self, strategy: ShardingStrategy):
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'other': self._compute_size_in_bytes(strategy, "other"),
'output': self._compute_size_in_bytes(strategy, "output")
"input": self._compute_size_in_bytes(strategy, "input"),
"other": self._compute_size_in_bytes(strategy, "other"),
"output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
......@@ -75,14 +73,15 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
total_mem_cost = MemoryCost(
activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@ignore_sharding_exception
def non_split(self):
name = f'RR = R x RR'
name = f"RR = R x RR"
dim_partition_dict_mapping = {
"input": {},
......@@ -92,18 +91,16 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
return self.get_sharding_strategy(
name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping={}
)
@ignore_sharding_exception
def split_input(self, mesh_dim_0):
name = f'S{mesh_dim_0}R = S{mesh_dim_0} x RR'
name = f"S{mesh_dim_0}R = S{mesh_dim_0} x RR"
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0]
},
"input": {0: [mesh_dim_0]},
"other": {},
"output": {
0: [mesh_dim_0],
......@@ -118,7 +115,8 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
comm_type=CommType.HOOK,
)
else:
other_comm_action = self.get_communication_action(
......@@ -126,17 +124,20 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=1)
arg_index=1,
)
communication_action_mapping["other"] = other_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_input_and_embedding_dim(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0} x RS{mesh_dim_1}'
name = f"S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0} x RS{mesh_dim_1}"
dim_partition_dict_mapping = {
"input": {
......@@ -159,7 +160,8 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
arg_index=0)
arg_index=0,
)
communication_action_mapping = {"input": input_comm_action}
if self.is_param("other"):
......@@ -167,7 +169,8 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
comm_type=CommType.HOOK,
)
else:
other_comm_action = self.get_communication_action(
......@@ -175,22 +178,23 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=1)
arg_index=1,
)
communication_action_mapping["other"] = other_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_1d_parallel_on_input(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1} x RR'
name = f"S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1} x RR"
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0, mesh_dim_1]
},
"input": {0: [mesh_dim_0, mesh_dim_1]},
"other": {},
"output": {
0: [mesh_dim_0, mesh_dim_1],
......@@ -207,7 +211,8 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.HOOK)
comm_type=CommType.HOOK,
)
else:
other_comm_action = self.get_communication_action(
......@@ -215,17 +220,20 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
arg_index=1)
arg_index=1,
)
communication_action_mapping["other"] = other_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_embedding_dim(self, mesh_dim_0):
name = f'RS{mesh_dim_0} = R x RS{mesh_dim_0}'
name = f"RS{mesh_dim_0} = R x RS{mesh_dim_0}"
dim_partition_dict_mapping = {
"input": {},
......@@ -245,17 +253,20 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=0)
arg_index=0,
)
communication_action_mapping = {"input": input_comm_action}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_1d_parallel_on_embedding_dim(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = R x RS{mesh_dim_0}{mesh_dim_1}'
name = f"RS{mesh_dim_0}{mesh_dim_1} = R x RS{mesh_dim_0}{mesh_dim_1}"
dim_partition_dict_mapping = {
"input": {},
......@@ -275,13 +286,16 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
arg_index=0)
arg_index=0,
)
communication_action_mapping = {"input": input_comm_action}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
def collate_strategies(self) -> List[ShardingStrategy]:
strategies = []
......
......@@ -10,7 +10,7 @@ from colossalai.tensor.sharding_spec import ShardingSpecException
from .strategy_generator import StrategyGenerator
__all__ = ['GetattrGenerator']
__all__ = ["GetattrGenerator"]
class GetattrGenerator(StrategyGenerator):
......@@ -26,10 +26,10 @@ class GetattrGenerator(StrategyGenerator):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
"""
Compute the memory cost per device with this specific strategy.
'''
forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")}
"""
forward_size_mapping = {"output": self._compute_size_in_bytes(strategy, "output")}
# compute fwd cost incurred
# fwd_cost = output
......@@ -47,7 +47,7 @@ class GetattrGenerator(StrategyGenerator):
def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
# we check for the output logical shape to get the number of dimensions
dim_partition_list = []
dim_size = len(self.op_data['output'].logical_shape)
dim_size = len(self.op_data["output"].logical_shape)
# enumerate all the 2D sharding cases
sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
......@@ -78,7 +78,8 @@ class GetattrGenerator(StrategyGenerator):
sharding_strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
communication_action_mapping=communication_action_mapping,
)
strategy_list.append(sharding_strategy)
except ShardingSpecException:
continue
......
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from colossalai.logging import get_dist_logger
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpecException
from .strategy_generator import FollowingStrategyGenerator
__all__ = ['GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator']
__all__ = ["GetItemStrategyGenerator", "TensorStrategyGenerator", "TensorTupleStrategyGenerator"]
class GetItemStrategyGenerator(FollowingStrategyGenerator):
......@@ -35,12 +29,12 @@ class GetItemStrategyGenerator(FollowingStrategyGenerator):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
"""
Compute the memory cost per device with this specific strategy.
'''
"""
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
"input": self._compute_size_in_bytes(strategy, "input"),
"output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
......@@ -58,27 +52,29 @@ class GetItemStrategyGenerator(FollowingStrategyGenerator):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
total_mem_cost = MemoryCost(
activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
class TensorStrategyGenerator(GetItemStrategyGenerator):
'''
"""
Deal with case 1 and 2.
'''
"""
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
getitem_index = self.op_data['index'].data
getitem_index = self.op_data["index"].data
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
try:
logger = get_dist_logger()
dim_partition_dict_mapping = {}
communication_action_mapping = {}
dim_partition_dict_for_input = copy.deepcopy(
strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict)
strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict
)
int_index = False
if isinstance(getitem_index, int):
......@@ -120,9 +116,11 @@ class TensorStrategyGenerator(GetItemStrategyGenerator):
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
except ShardingSpecException as e:
logger.debug(e)
continue
......@@ -137,9 +135,9 @@ class TensorStrategyGenerator(GetItemStrategyGenerator):
class TensorTupleStrategyGenerator(GetItemStrategyGenerator):
'''
"""
Deal with case 3.
'''
"""
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
......@@ -158,13 +156,15 @@ class TensorTupleStrategyGenerator(GetItemStrategyGenerator):
sharding_spec_mapping["input"] = sharding_spec_for_input
input_sharding_info = f"get the {index} element from ("
for sharding_spec in sharding_spec_for_input:
input_sharding_info += f'{sharding_spec.sharding_sequence}, '
input_sharding_info += f"{sharding_spec.sharding_sequence}, "
input_sharding_info += ")"
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {input_sharding_info}_{strategy_index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
strategy_list.append(strategy)
......
......@@ -18,7 +18,7 @@ from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
__all__ = ['LayerNormGenerator']
__all__ = ["LayerNormGenerator"]
class LayerNormGenerator(StrategyGenerator):
......@@ -31,21 +31,21 @@ class LayerNormGenerator(StrategyGenerator):
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
'''
"""
Compute the computation cost per device with this specific strategy.
Note: compute_cost need to be divided by TFLOPS, now it just shows the computation size.
'''
"""
# TODO: compute_cost need to be divided by TFLOPS, now it just shows the computation size.
# TODO: a constant coefficient need to be added.
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
sharded_weight_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device()
sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
sharded_weight_shape = strategy.sharding_specs[self.op_data["other"]].get_sharded_shape_per_device()
if self.has_bias:
# bias add is an element wise operation, so the cost is equal to product of output shape.
bias_compute_cost = reduce(operator.mul, sharded_weight_shape)
# in LayerNorm context, batch dimensions mean all the dimensions do not join the normalization.
input_batch_shape = sharded_input_shape[:-len(sharded_weight_shape)]
input_batch_shape = sharded_input_shape[: -len(sharded_weight_shape)]
input_batch_product = reduce(operator.mul, input_batch_shape, 1)
norm_kernel_product = reduce(operator.mul, sharded_weight_shape, 1)
forward_compute_cost = input_batch_product * norm_kernel_product
......@@ -62,18 +62,18 @@ class LayerNormGenerator(StrategyGenerator):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
"""
Compute the memory cost per device with this specific strategy.
'''
"""
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'other': self._compute_size_in_bytes(strategy, "other"),
'output': self._compute_size_in_bytes(strategy, "output")
"input": self._compute_size_in_bytes(strategy, "input"),
"other": self._compute_size_in_bytes(strategy, "other"),
"output": self._compute_size_in_bytes(strategy, "output"),
}
if self.has_bias:
bias_size = self._compute_size_in_bytes(strategy, "bias")
forward_size_mapping['bias'] = bias_size
forward_size_mapping["bias"] = bias_size
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
......@@ -90,8 +90,9 @@ class LayerNormGenerator(StrategyGenerator):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
total_mem_cost = MemoryCost(
activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
......@@ -120,7 +121,8 @@ class LayerNormGenerator(StrategyGenerator):
sharding_spec=sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=total_mesh_dim_list,
comm_type=CommType.HOOK)
comm_type=CommType.HOOK,
)
communication_action_mapping["other"] = other_comm_action
if self.has_bias:
......@@ -128,12 +130,15 @@ class LayerNormGenerator(StrategyGenerator):
sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=total_mesh_dim_list,
comm_type=CommType.HOOK)
comm_type=CommType.HOOK,
)
communication_action_mapping["bias"] = bias_comm_action
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
return strategy
......@@ -155,7 +160,7 @@ class LayerNormGenerator(StrategyGenerator):
@ignore_sharding_exception
def non_split(self):
name = f'RR = RR x R'
name = f"RR = RR x R"
dim_partition_dict_mapping = {
"input": {},
"other": {},
......@@ -168,14 +173,16 @@ class LayerNormGenerator(StrategyGenerator):
communication_action_mapping = {}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
def collate_strategies(self) -> List[ShardingStrategy]:
'''
"""
Generate every possible strategies for a LayerNorm node, and record all strategies into the strategies_vector.
'''
"""
strategy_list = []
input_data_dim = len(self.op_data["input"].logical_shape)
weight_data_dim = len(self.op_data["other"].logical_shape)
......
import operator
from ast import arg
from functools import reduce
from typing import List
......@@ -24,14 +23,14 @@ class MatMulStrategyGenerator(StrategyGenerator):
def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'other': self._compute_size_in_bytes(strategy, "other"),
'output': self._compute_size_in_bytes(strategy, "output")
"input": self._compute_size_in_bytes(strategy, "input"),
"other": self._compute_size_in_bytes(strategy, "other"),
"output": self._compute_size_in_bytes(strategy, "output"),
}
if self.has_bias:
bias_size = self._compute_size_in_bytes(strategy, "bias")
size_mapping['bias'] = bias_size
size_mapping["bias"] = bias_size
# compute fwd cost incurred
# fwd_cost = input + other + bias + output
......@@ -41,45 +40,47 @@ class MatMulStrategyGenerator(StrategyGenerator):
# compute bwd cost incurred
# bwd_cost = input_grad + bias_grad
bwd_activation_cost = sum([v for k, v in size_mapping.items() if k in ['input', 'other', 'bias']])
bwd_activation_cost = sum([v for k, v in size_mapping.items() if k in ["input", "other", "bias"]])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=0)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + 0)
total_mem_cost = MemoryCost(
activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + 0
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
class DotProductStrategyGenerator(MatMulStrategyGenerator):
def validate(self) -> bool:
input_op_data = self.op_data['input']
other_op_data = self.op_data['other']
input_op_data = self.op_data["input"]
other_op_data = self.op_data["other"]
assert input_op_data.data.dim() == 1 and other_op_data.data.dim() == 1
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
fwd_compute_cost = sharded_input_shape[0]
bwd_compute_cost = fwd_compute_cost * 2
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
compute_cost = TrainCycleItem(
fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
)
return compute_cost
@ignore_sharding_exception
def no_split(self):
name = f'R = R dot R'
dim_partition_dict = {"input": {}, "other": {}, "output": {}, 'bias': {}}
name = f"R = R dot R"
dim_partition_dict = {"input": {}, "other": {}, "output": {}, "bias": {}}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
communication_action_mapping = {}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_one_dim(self, mesh_dim):
name = f'R = S{mesh_dim} dot S{mesh_dim}'
name = f"R = S{mesh_dim} dot S{mesh_dim}"
# get sharding spec
dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "output": {}, "bias": {0: [mesh_dim]}}
......@@ -87,14 +88,17 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
# get communication action
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['output'],
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.AFTER)
comm_type=CommType.AFTER,
)
communication_action_mapping = {"output": output_comm_action}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
......@@ -112,19 +116,18 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
class MatVecStrategyGenerator(MatMulStrategyGenerator):
def validate(self) -> bool:
input_op_data = self.op_data['input']
other_op_data = self.op_data['other']
input_op_data = self.op_data["input"]
other_op_data = self.op_data["other"]
assert input_op_data.data.dim() == 2 and other_op_data.data.dim() == 1
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
fwd_compute_cost = sharded_input_shape[0]
bwd_compute_cost = fwd_compute_cost * 2
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
compute_cost = TrainCycleItem(
fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
)
return compute_cost
@ignore_sharding_exception
......@@ -133,67 +136,69 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
dim_partition_dict = {"input": {}, "other": {}, "output": {}}
if self.has_bias:
dim_partition_dict['bias'] = {}
dim_partition_dict["bias"] = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
return self.get_sharding_strategy(
name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping={}
)
@ignore_sharding_exception
def split_input_batch(self, mesh_dim):
    """Build the strategy ``S{mesh_dim}R = S{mesh_dim}R x R``.

    Dimension 0 of ``"input"`` and ``"output"`` is partitioned along
    ``mesh_dim``; ``"other"`` (and ``"bias"`` when ``self.has_bias``) gets an
    empty partition dict.

    Args:
        mesh_dim: the device-mesh axis used both for the partition and as the
            ``logical_process_axis`` of every communication action.

    Returns:
        Whatever ``self.get_sharding_strategy`` returns for the assembled name,
        sharding-spec mapping and communication-action mapping.
    """
    name = f"S{mesh_dim}R = S{mesh_dim}R x R"

    # get sharding spec
    dim_partition_dict = {
        "input": {0: [mesh_dim]},
        "other": {},
        "output": {0: [mesh_dim]},
    }
    if self.has_bias:
        dim_partition_dict["bias"] = {}
    sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

    # get communication action
    # Each entry is (operand name, positional index used when the operand is
    # not a parameter). "other" is always handled; "bias" only when present.
    operands = [("other", 1)]
    if self.has_bias:
        operands.append(("bias", 2))

    communication_action_mapping = {}
    for operand, arg_index in operands:
        if self.is_param(operand):
            # parameters use a HOOK action and carry no argument index
            placement = {"comm_type": CommType.HOOK}
        else:
            # non-parameters use a BEFORE action bound to their argument index
            placement = {"comm_type": CommType.BEFORE, "arg_index": arg_index}
        communication_action_mapping[operand] = self.get_communication_action(
            sharding_spec=sharding_spec_mapping[operand],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=mesh_dim,
            **placement,
        )

    return self.get_sharding_strategy(
        name=name,
        sharding_spec_mapping=sharding_spec_mapping,
        communication_action_mapping=communication_action_mapping,
    )
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
......@@ -209,12 +214,13 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
def __init__(self,
operation_data_mapping,
device_mesh,
linear_projection_type='linear',
solver_perference=SolverPerference.STANDARD):
def __init__(
self,
operation_data_mapping,
device_mesh,
linear_projection_type="linear",
solver_perference=SolverPerference.STANDARD,
):
super().__init__(operation_data_mapping, device_mesh)
self.linear_projection_type = linear_projection_type
self.solver_perference = solver_perference
......@@ -224,17 +230,17 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# C: [M, N], A: [M, P], B: [P, N]
# fwd cost = MNP (only count mul)
# bwd: 2 x fwd_cost
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device()
sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
sharded_other_shape = strategy.sharding_specs[self.op_data["other"]].get_sharded_shape_per_device()
dim_m_val = reduce(operator.mul, sharded_input_shape[:-1])
dim_n_val = sharded_other_shape[-1]
dim_p_val = sharded_other_shape[0]
fwd_compute_cost = dim_m_val * dim_n_val * dim_p_val
bwd_compute_cost = fwd_compute_cost * 2
compute_cost = TrainCycleItem(fwd=bwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
compute_cost = TrainCycleItem(
fwd=bwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
)
strategy.compute_cost = compute_cost
def dp_strategies(self) -> List[ShardingStrategy]:
......@@ -301,28 +307,21 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
@ignore_sharding_exception
def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
# handle case SS = SR x RS
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
name = f"S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}"
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0]
},
"other": {
-1: [mesh_dim_1]
},
"output": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
},
"input": {0: [mesh_dim_0]},
"other": {-1: [mesh_dim_1]},
"output": {0: [mesh_dim_0], -1: [mesh_dim_1]},
}
# linear bias only has one dimension, but addmm bias has same dimensions
# as the output logically.
if self.linear_projection_type == 'linear':
dim_partition_dict_mapping['bias'] = {-1: [mesh_dim_1]}
elif self.linear_projection_type == 'addmm':
dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0], -1: [mesh_dim_1]}
if self.linear_projection_type == "linear":
dim_partition_dict_mapping["bias"] = {-1: [mesh_dim_1]}
elif self.linear_projection_type == "addmm":
dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0], -1: [mesh_dim_1]}
else:
raise ('Unsupported linear projection type')
raise ("Unsupported linear projection type")
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
......@@ -333,75 +332,75 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
arg_index=0)
arg_index=0,
)
if self.is_param('other'):
if self.is_param("other"):
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
comm_type=CommType.HOOK,
)
else:
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=1)
arg_index=1,
)
communication_action_mapping['input'] = input_comm_action
communication_action_mapping['other'] = other_comm_action
communication_action_mapping["input"] = input_comm_action
communication_action_mapping["other"] = other_comm_action
# we only add allreduce comm action for linear bias, because
# allreduce comm action for addmm bias will be considered in post processing
if self.has_bias and self.linear_projection_type == 'linear':
if self.is_param('bias'):
if self.has_bias and self.linear_projection_type == "linear":
if self.is_param("bias"):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
comm_type=CommType.HOOK,
)
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
key_for_kwarg='bias')
communication_action_mapping['bias'] = bias_comm_action
key_for_kwarg="bias",
)
communication_action_mapping["bias"] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
# handle the case SR = SS x SR
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
name = f"S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R"
# get sharding spec mapping
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
},
"other": {
0: [mesh_dim_1]
},
"input": {0: [mesh_dim_0], -1: [mesh_dim_1]},
"other": {0: [mesh_dim_1]},
"bias": {},
"output": {
0: [mesh_dim_0]
},
"output": {0: [mesh_dim_0]},
}
# linear bias only has one dimension, but addmm bias has same dimensions
# as the output logically.
if self.linear_projection_type == 'linear':
dim_partition_dict_mapping['bias'] = {}
elif self.linear_projection_type == 'addmm':
dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0]}
if self.linear_projection_type == "linear":
dim_partition_dict_mapping["bias"] = {}
elif self.linear_projection_type == "addmm":
dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0]}
else:
raise ('Unsupported linear projection type')
raise ("Unsupported linear projection type")
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
......@@ -412,66 +411,64 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.AFTER)
comm_type=CommType.AFTER,
)
if self.is_param('other'):
if self.is_param("other"):
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
comm_type=CommType.HOOK,
)
else:
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=1)
arg_index=1,
)
communication_action_mapping['other'] = other_comm_action
communication_action_mapping['output'] = output_comm_action
communication_action_mapping["other"] = other_comm_action
communication_action_mapping["output"] = output_comm_action
# we only add allreduce comm action for linear bias, because
# allreduce comm action for addmm bias will be considered in post processing
if self.has_bias and self.linear_projection_type == 'linear':
if self.is_param('bias'):
if self.has_bias and self.linear_projection_type == "linear":
if self.is_param("bias"):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
comm_type=CommType.HOOK,
)
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
key_for_kwarg='bias')
communication_action_mapping['bias'] = bias_comm_action
key_for_kwarg="bias",
)
communication_action_mapping["bias"] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
name = f"RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}"
# get sharding specs
dim_partition_dict_mapping = {
"input": {
-1: [mesh_dim_0]
},
"other": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
},
"bias": {
-1: [mesh_dim_1]
},
"output": {
-1: [mesh_dim_1]
},
"input": {-1: [mesh_dim_0]},
"other": {0: [mesh_dim_0], -1: [mesh_dim_1]},
"bias": {-1: [mesh_dim_1]},
"output": {-1: [mesh_dim_1]},
}
# We don't have to do anything special for bias here, because
......@@ -482,34 +479,34 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# get communication actions
communication_action_mapping = {}
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['output'],
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.AFTER)
comm_type=CommType.AFTER,
)
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['input'],
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
arg_index=0)
arg_index=0,
)
communication_action_mapping["input"] = input_comm_action
communication_action_mapping['output'] = output_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
communication_action_mapping["output"] = output_comm_action
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def recompute_split_both_contract(self, mesh_dim):
name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
name = f"RR = RS{mesh_dim} x S{mesh_dim}R"
# get sharding spec
dim_partition_dict_mapping = {
"input": {
-1: [mesh_dim]
},
"other": {
0: [mesh_dim]
},
"input": {-1: [mesh_dim]},
"other": {0: [mesh_dim]},
"bias": {},
"output": {},
}
......@@ -520,32 +517,29 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# get communication action
communication_action_mapping = {}
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['output'],
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.AFTER)
comm_type=CommType.AFTER,
)
communication_action_mapping['output'] = output_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
communication_action_mapping["output"] = output_comm_action
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_rhs_space_only(self, mesh_dim):
name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
name = f"RS{mesh_dim} = RR x RS{mesh_dim}"
# get sharding spec
dim_partition_dict_mapping = {
"input": {},
"other": {
-1: [mesh_dim]
},
"bias": {
-1: [mesh_dim]
},
"output": {
-1: [mesh_dim]
},
"other": {-1: [mesh_dim]},
"bias": {-1: [mesh_dim]},
"output": {-1: [mesh_dim]},
}
# We don't have to do anything special for bias here, because
# the bias is already the same sharding spec as the output.
......@@ -554,93 +548,94 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# get communication actions
communication_action_mapping = {}
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['input'],
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
arg_index=0)
arg_index=0,
)
communication_action_mapping['input'] = input_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
communication_action_mapping["input"] = input_comm_action
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
    """Build the strategy ``S{d0}{d1}R = S{d0}{d1}R x RR``.

    Dimension 0 of ``"input"`` and ``"output"`` is partitioned along both mesh
    axes; ``"other"`` is left unpartitioned. The bias partition depends on
    ``self.linear_projection_type``: empty for ``"linear"``, the same as the
    output's dimension 0 for ``"addmm"``.

    Args:
        mesh_dim_0: first device-mesh axis.
        mesh_dim_1: second device-mesh axis.

    Returns:
        Whatever ``self.get_sharding_strategy`` returns for the assembled name,
        sharding-spec mapping and communication-action mapping.

    Raises:
        ValueError: if ``self.linear_projection_type`` is neither ``"linear"``
            nor ``"addmm"``. (The previous ``raise ("...")`` raised a plain
            string, which Python rejects with an unrelated ``TypeError``.)
    """
    name = f"S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR"

    # get sharding spec
    dim_partition_dict_mapping = {
        "input": {0: [mesh_dim_0, mesh_dim_1]},
        "other": {},
        "bias": {},
        "output": {0: [mesh_dim_0, mesh_dim_1]},
    }

    # linear bias only has one dimension, but addmm bias has same dimensions
    # as the output logically.
    if self.linear_projection_type == "linear":
        dim_partition_dict_mapping["bias"] = {}
    elif self.linear_projection_type == "addmm":
        dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0, mesh_dim_1]}
    else:
        raise ValueError(f"Unsupported linear projection type: {self.linear_projection_type!r}")

    sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

    # get communication action
    communication_action_mapping = {}
    if self.is_param("other"):
        other_comm_action = self.get_communication_action(
            sharding_spec=sharding_spec_mapping["other"],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=[mesh_dim_0, mesh_dim_1],
            comm_type=CommType.HOOK,
        )
    else:
        other_comm_action = self.get_communication_action(
            sharding_spec=sharding_spec_mapping["other"],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=[mesh_dim_0, mesh_dim_1],
            comm_type=CommType.BEFORE,
            arg_index=1,
        )
    communication_action_mapping["other"] = other_comm_action

    # we only add allreduce comm action for linear bias, because
    # allreduce comm action for addmm bias will be considered in post processing
    if self.has_bias and self.linear_projection_type == "linear":
        if self.is_param("bias"):
            bias_comm_action = self.get_communication_action(
                sharding_spec=sharding_spec_mapping["bias"],
                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                logical_process_axis=[mesh_dim_0, mesh_dim_1],
                comm_type=CommType.HOOK,
            )
        else:
            # a non-parameter bias is addressed by keyword rather than position
            bias_comm_action = self.get_communication_action(
                sharding_spec=sharding_spec_mapping["bias"],
                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                logical_process_axis=[mesh_dim_0, mesh_dim_1],
                comm_type=CommType.BEFORE,
                key_for_kwarg="bias",
            )
        communication_action_mapping["bias"] = bias_comm_action

    return self.get_sharding_strategy(
        name=name,
        sharding_spec_mapping=sharding_spec_mapping,
        communication_action_mapping=communication_action_mapping,
    )
@ignore_sharding_exception
def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
name = f"RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R"
# get sharding spec
dim_partition_dict_mapping = {
"input": {
-1: [mesh_dim_0, mesh_dim_1]
},
"other": {
0: [mesh_dim_0, mesh_dim_1]
},
"input": {-1: [mesh_dim_0, mesh_dim_1]},
"other": {0: [mesh_dim_0, mesh_dim_1]},
"bias": {},
"output": {},
}
......@@ -652,32 +647,29 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# get communication action
communication_action_mapping = {}
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['output'],
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.AFTER)
communication_action_mapping['output'] = output_comm_action
comm_type=CommType.AFTER,
)
communication_action_mapping["output"] = output_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
name = f"RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}"
# get sharding spec
dim_partition_dict_mapping = {
"input": {},
"other": {
-1: [mesh_dim_0, mesh_dim_1]
},
"bias": {
-1: [mesh_dim_0, mesh_dim_1]
},
"output": {
-1: [mesh_dim_0, mesh_dim_1]
},
"other": {-1: [mesh_dim_0, mesh_dim_1]},
"bias": {-1: [mesh_dim_0, mesh_dim_1]},
"output": {-1: [mesh_dim_0, mesh_dim_1]},
}
# We don't have to do anything special for bias here, because
......@@ -687,20 +679,23 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# get communication action
communication_action_mapping = {}
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['input'],
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['input'] = input_comm_action
arg_index=0,
)
communication_action_mapping["input"] = input_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def non_split(self):
name = f'RR = RR x RR'
name = f"RR = RR x RR"
# get sharding spec
dim_partition_dict_mapping = {
......@@ -717,22 +712,24 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# get communication action
communication_action_mapping = {}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
def validate(self) -> bool:
assert "input" in self.op_data
assert "other" in self.op_data
# make sure the other has 2 dim
input_data = self.op_data['input']
other_data = self.op_data['other']
input_data = self.op_data["input"]
other_data = self.op_data["other"]
assert input_data.data.dim() > 0 and other_data.data.dim() == 2
assert other_data.logical_shape[0] == input_data.logical_shape[-1]
if self.has_bias:
bias_data = self.op_data['bias']
bias_data = self.op_data["bias"]
assert bias_data.logical_shape[-1] == other_data.logical_shape[-1]
......@@ -757,37 +754,38 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
def _pop_batch_dim_sharding_for_output(self, dim_partition_dict):
    """Drop dimension 0 from the output partition and shift the rest down by one.

    Mutates ``dim_partition_dict["output"]`` in place: the entry for key ``0``
    is removed if present, and every remaining key ``k`` is re-keyed as
    ``k - 1`` with its value unchanged. Other entries of ``dim_partition_dict``
    are not touched.

    Args:
        dim_partition_dict: mapping with an ``"output"`` entry that is itself a
            dict keyed by integer dimension index.

    Returns:
        None.
    """
    output_partition = dim_partition_dict["output"]

    # remove partition dict for dim 0
    output_partition.pop(0, None)

    # decrease the remaining dim index by 1; build the shifted mapping first so
    # that old and new keys cannot collide while the dict is being rewritten
    shifted = {dim - 1: mesh_dims for dim, mesh_dims in output_partition.items()}
    output_partition.clear()
    output_partition.update(shifted)
def validate(self) -> bool:
    """Check the operand shapes recorded in ``self.op_data``.

    Asserts that at least one of ``"input"`` / ``"other"`` has a logical shape
    of length 3. When a ``"bias"`` entry exists, additionally asserts that its
    tensor has fewer than 3 dimensions and its logical shape has length 2.

    Raises:
        AssertionError: if any of the checks above fails.

    NOTE(review): annotated ``-> bool`` but no value is returned explicitly, so
    a successful call evaluates to ``None``. Left unchanged -- confirm what
    callers expect before altering it.
    """
    input_op_data = self.op_data["input"]
    other_op_data = self.op_data["other"]
    assert len(input_op_data.logical_shape) == 3 or len(other_op_data.logical_shape) == 3

    if "bias" in self.op_data:
        bias_op_data = self.op_data["bias"]
        assert bias_op_data.data.dim() < 3 and len(bias_op_data.logical_shape) == 2
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
    """Compute the batched-matmul cost and store it on ``strategy.compute_cost``.

    The forward cost is the last dimension of ``self.op_data["input"].data``
    multiplied by the product of all dimensions of
    ``self.op_data["output"].data``; the backward cost is twice the forward
    cost. The shapes are read from ``self.op_data``; the strategy's sharding
    specs are not consulted here.

    Args:
        strategy: the sharding strategy whose ``compute_cost`` attribute is set.

    Returns:
        None; the strategy is updated in place (the ``-> ShardingStrategy``
        annotation is kept from the original signature).
    """
    contracted_dim = self.op_data["input"].data.shape[-1]
    output_numel = reduce(operator.mul, self.op_data["output"].data.shape)
    fwd_compute_cost = contracted_dim * output_numel
    bwd_compute_cost = fwd_compute_cost * 2
    compute_cost = TrainCycleItem(
        fwd=fwd_compute_cost,
        bwd=bwd_compute_cost,
        total=fwd_compute_cost + bwd_compute_cost,
    )
    strategy.compute_cost = compute_cost
@ignore_sharding_exception
def split_one_batch_dim(self, mesh_dim):
name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
name = f"Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}"
# get sharding_spec
dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "bias": {}, "output": {0: [mesh_dim]}}
......@@ -799,30 +797,27 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
communication_action_mapping = {}
if self.has_bias:
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['bias'] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
arg_index=0,
)
communication_action_mapping["bias"] = bias_comm_action
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1):
name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}'
name = f"Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}"
dim_partition_dict = {
"input": {
0: [mesh_dim_0, mesh_dim_1]
},
"other": {
0: [mesh_dim_0, mesh_dim_1]
},
"input": {0: [mesh_dim_0, mesh_dim_1]},
"other": {0: [mesh_dim_0, mesh_dim_1]},
"bias": {},
"output": {
0: [mesh_dim_0, mesh_dim_1]
}
"output": {0: [mesh_dim_0, mesh_dim_1]},
}
if self.squeeze_batch_dim:
self._pop_batch_dim_sharding_for_output(dim_partition_dict)
......@@ -832,35 +827,28 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
communication_action_mapping = {}
if self.has_bias:
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['bias'] = bias_comm_action
arg_index=0,
)
communication_action_mapping["bias"] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1):
name = f'Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}'
name = f"Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}"
dim_partition_dict = {
"input": {
0: [mesh_dim_0],
1: [mesh_dim_1]
},
"other": {
0: [mesh_dim_0]
},
"bias": {
0: [mesh_dim_1]
},
"output": {
0: [mesh_dim_0],
1: [mesh_dim_1]
}
"input": {0: [mesh_dim_0], 1: [mesh_dim_1]},
"other": {0: [mesh_dim_0]},
"bias": {0: [mesh_dim_1]},
"output": {0: [mesh_dim_0], 1: [mesh_dim_1]},
}
if self.squeeze_batch_dim:
self._pop_batch_dim_sharding_for_output(dim_partition_dict)
......@@ -869,46 +857,40 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
# get communication actions
communication_action_mapping = {}
other_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['other'],
sharding_spec=sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping['other'] = other_comm_action
arg_index=1,
)
communication_action_mapping["other"] = other_comm_action
if self.has_bias:
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['bias'] = bias_comm_action
arg_index=0,
)
communication_action_mapping["bias"] = bias_comm_action
# for addbmm case, other is the third argument instead of second.
communication_action_mapping['other'].arg_index += 1
communication_action_mapping["other"].arg_index += 1
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1):
name = f'Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}'
name = f"Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}"
dim_partition_dict = {
"input": {
0: [mesh_dim_0]
},
"other": {
0: [mesh_dim_0],
2: [mesh_dim_1]
},
"bias": {
1: [mesh_dim_1]
},
"output": {
0: [mesh_dim_0],
2: [mesh_dim_1]
}
"input": {0: [mesh_dim_0]},
"other": {0: [mesh_dim_0], 2: [mesh_dim_1]},
"bias": {1: [mesh_dim_1]},
"output": {0: [mesh_dim_0], 2: [mesh_dim_1]},
}
if self.squeeze_batch_dim:
self._pop_batch_dim_sharding_for_output(dim_partition_dict)
......@@ -917,43 +899,41 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
# get communication actions
communication_action_mapping = {}
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['input'],
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['input'] = input_comm_action
arg_index=0,
)
communication_action_mapping["input"] = input_comm_action
if self.has_bias:
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE)
communication_action_mapping['bias'] = bias_comm_action
comm_type=CommType.BEFORE,
)
communication_action_mapping["bias"] = bias_comm_action
# for addbmm case, other is the second argument instead of first.
communication_action_mapping['input'].arg_index += 1
communication_action_mapping["input"].arg_index += 1
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
@ignore_sharding_exception
def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1):
name = f'Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}'
name = f"Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}"
dim_partition_dict = {
"input": {
0: [mesh_dim_0],
2: [mesh_dim_1]
},
"other": {
0: [mesh_dim_0],
1: [mesh_dim_1]
},
"input": {0: [mesh_dim_0], 2: [mesh_dim_1]},
"other": {0: [mesh_dim_0], 1: [mesh_dim_1]},
"bias": {},
"output": {
0: [mesh_dim_0],
}
},
}
if self.squeeze_batch_dim:
self._pop_batch_dim_sharding_for_output(dim_partition_dict)
......@@ -962,24 +942,28 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
# get communication actions
communication_action_mapping = {}
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['output'],
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.AFTER)
communication_action_mapping['output'] = output_comm_action
comm_type=CommType.AFTER,
)
communication_action_mapping["output"] = output_comm_action
if self.has_bias:
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['bias'] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
arg_index=0,
)
communication_action_mapping["bias"] = bias_comm_action
return self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
......
......@@ -21,28 +21,31 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
"""
def validate(self) -> bool:
'''
"""
In sanity check, we need make sure the input data having correct dimension size.
For Pool1d, the dim of input data should be 3([N, C, L]).
For Pool2d, the dim of input data should be 4([N, C, H, W]).
For Pool3d, the dim of input data should be 5([N, C, H, W, D]).
'''
input_op_data = self.op_data['input']
"""
input_op_data = self.op_data["input"]
assert input_op_data.data.dim() in (
3, 4, 5), f'We suppose the dim of input fed into Pool op should in range of [3, 5].'
3,
4,
5,
), f"We suppose the dim of input fed into Pool op should in range of [3, 5]."
def update_compute_cost(self, strategy: ShardingStrategy) -> TrainCycleItem:
'''
"""
Compute the computation cost per device with this specific strategy.
Note: compute_cost need to be divided by TFLOPS, now it just shows the computation size.
'''
"""
# TODO: compute_cost need to be divided by TFLOPS, now it just shows the computation size.
# 1D: (Lout) * N * C * kernel
# 2D: (H * W) * N * Cout * Cin * kernel
# 3D: (H * W * D) * N * Cout * Cin * kernel
sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device()
sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
kernel_size = self.op_data["other"].data
if isinstance(kernel_size, int):
......@@ -61,8 +64,8 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
"input": self._compute_size_in_bytes(strategy, "input"),
"output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
......@@ -88,12 +91,16 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'
name = (
f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'
)
communication_action_mapping = {}
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
return strategy
......
......@@ -12,7 +12,7 @@ from colossalai.device.device_mesh import DeviceMesh
from .strategy_generator import OutputStrategyGenerator
__all__ = ['OutputGenerator']
__all__ = ["OutputGenerator"]
class OutputGenerator(OutputStrategyGenerator):
......@@ -20,8 +20,13 @@ class OutputGenerator(OutputStrategyGenerator):
OutputGenerator is a generic class to generate strategies for Output Node.
"""
def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
predecessor_nodes: List[Node], output_option: str):
def __init__(
self,
operation_data_mapping: Dict[str, OperationData],
device_mesh: DeviceMesh,
predecessor_nodes: List[Node],
output_option: str,
):
super().__init__(operation_data_mapping, device_mesh, predecessor_nodes)
self.output_option = output_option
......@@ -33,9 +38,9 @@ class OutputGenerator(OutputStrategyGenerator):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
"""
Compute the memory cost per device with this specific strategy.
'''
"""
fwd_mem_cost = MemoryCost(activation=0, parameter=0)
bwd_mem_cost = MemoryCost(activation=0, parameter=0)
......@@ -65,16 +70,18 @@ class OutputGenerator(OutputStrategyGenerator):
else:
dim_partition_dict_for_output = tuple(dim_partition_dict_for_output)
dim_partition_dict_mapping['output'] = dim_partition_dict_for_output
dim_partition_dict_mapping["output"] = dim_partition_dict_for_output
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
name = 'Replica Output'
name = "Replica Output"
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
return strategy
def distributed_strategy(self, mesh_list: List[List[int]] = None) -> List[ShardingStrategy]:
......@@ -82,19 +89,15 @@ class OutputGenerator(OutputStrategyGenerator):
Generate distributed strategy for output node.
"""
# TODO: need to take care of the case when the first element of output only need to be sharded.
output_op_data = self.op_data['output']
output_op_data = self.op_data["output"]
if isinstance(output_op_data.data, tuple):
length = len(output_op_data.data)
dim_partition_dict_mapping = {
"output": [{
0: mesh_list
}] * length,
"output": [{0: mesh_list}] * length,
}
else:
dim_partition_dict_mapping = {
"output": {
0: mesh_list
},
"output": {0: mesh_list},
}
for index, _ in enumerate(self.predecessor_nodes):
mapping_name = f"input_{index}"
......@@ -103,19 +106,21 @@ class OutputGenerator(OutputStrategyGenerator):
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
name = 'Distributed Output'
name = "Distributed Output"
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
return strategy
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
mesh_list = [0, 1]
if self.output_option == 'replicated':
if self.output_option == "replicated":
strategy_list.append(self.replica_strategy())
elif self.output_option == 'distributed':
elif self.output_option == "distributed":
strategy_list.append(self.distributed_strategy(mesh_list))
return strategy_list
......@@ -10,7 +10,7 @@ from colossalai.device.device_mesh import DeviceMesh
from .strategy_generator import StrategyGenerator
__all__ = ['PlaceholderGenerator']
__all__ = ["PlaceholderGenerator"]
class PlaceholderGenerator(StrategyGenerator):
......@@ -18,8 +18,9 @@ class PlaceholderGenerator(StrategyGenerator):
PlaceholderGenerator is a generic class to generate strategies for placeholder node.
"""
def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
placeholder_option: str):
def __init__(
self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, placeholder_option: str
):
super().__init__(operation_data_mapping, device_mesh)
self.placeholder_option = placeholder_option
......@@ -31,10 +32,10 @@ class PlaceholderGenerator(StrategyGenerator):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
"""
Compute the memory cost per device with this specific strategy.
'''
forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")}
"""
forward_size_mapping = {"output": self._compute_size_in_bytes(strategy, "output")}
# compute fwd cost incurred
# fwd_cost = output
......@@ -58,11 +59,13 @@ class PlaceholderGenerator(StrategyGenerator):
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
name = 'Replica Placeholder'
name = "Replica Placeholder"
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
return strategy
......@@ -71,29 +74,31 @@ class PlaceholderGenerator(StrategyGenerator):
Generate distributed strategy for placeholder node.
"""
dim_partition_dict_mapping = {
"output": {
0: mesh_list
},
"output": {0: mesh_list},
}
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
name = 'Distributed Placeholder'
name = "Distributed Placeholder"
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
return strategy
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
if self.placeholder_option == 'distributed':
if self.placeholder_option == "distributed":
mesh_list = [0, 1]
distributed_strategy = self.distributed_placeholder(mesh_list)
strategy_list.append(distributed_strategy)
else:
assert self.placeholder_option == 'replicated', f'placeholder_option {self.placeholder_option} is not supported'
assert (
self.placeholder_option == "replicated"
), f"placeholder_option {self.placeholder_option} is not supported"
replicated_strategy = self.replica_placeholder()
strategy_list.append(replicated_strategy)
......
......@@ -17,7 +17,7 @@ from colossalai.auto_parallel.tensor_shard.utils import (
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = ['ReshapeGenerator', 'ViewGenerator', 'PermuteGenerator', 'TransposeGenerator', 'SplitGenerator']
__all__ = ["ReshapeGenerator", "ViewGenerator", "PermuteGenerator", "TransposeGenerator", "SplitGenerator"]
class ReshapeGenerator(FollowingStrategyGenerator):
......@@ -33,12 +33,12 @@ class ReshapeGenerator(FollowingStrategyGenerator):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
"""
Compute the memory cost per device with this specific strategy.
'''
"""
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
"input": self._compute_size_in_bytes(strategy, "input"),
"output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
......@@ -56,8 +56,9 @@ class ReshapeGenerator(FollowingStrategyGenerator):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
total_mem_cost = MemoryCost(
activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
......@@ -77,8 +78,8 @@ class ViewGenerator(ReshapeGenerator):
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
origin_shape = self.op_data['input'].data.shape
tgt_shape = self.op_data['tgt_shape'].data
origin_shape = self.op_data["input"].data.shape
tgt_shape = self.op_data["tgt_shape"].data
reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape)
......@@ -86,8 +87,9 @@ class ViewGenerator(ReshapeGenerator):
keep_sharding_status = check_keep_sharding_status(dim_partition_dict_for_input, reshape_mapping_dict)
if keep_sharding_status:
dim_partition_dict_for_output = infer_output_dim_partition_dict(dim_partition_dict_for_input,
reshape_mapping_dict)
dim_partition_dict_for_output = infer_output_dim_partition_dict(
dim_partition_dict_for_input, reshape_mapping_dict
)
else:
dim_partition_dict_for_output = {}
......@@ -119,7 +121,8 @@ class ViewGenerator(ReshapeGenerator):
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=total_mesh_dim_list,
comm_type=CommType.BEFORE,
arg_index=0)
arg_index=0,
)
# it will gather the input through gather_dim during forward phase.
input_comm_action.comm_spec.gather_dim = shard_dim
# it will split the input activation grad through shard_dim during backward phase.
......@@ -127,10 +130,10 @@ class ViewGenerator(ReshapeGenerator):
elif len(total_mesh_dim_list) >= 2:
source_spec = sharding_spec_mapping["input"]
target_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=source_spec.entire_shape,
dim_partition_dict={})
comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
target_spec = ShardingSpec(
device_mesh=self.device_mesh, entire_shape=source_spec.entire_shape, dim_partition_dict={}
)
comm_spec = {"src_spec": source_spec, "tgt_spec": target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
else:
......@@ -139,9 +142,11 @@ class ViewGenerator(ReshapeGenerator):
if input_comm_action is not None:
communication_action_mapping["input"] = input_comm_action
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
strategy_list.append(strategy)
return strategy_list
......@@ -159,7 +164,7 @@ class PermuteGenerator(ReshapeGenerator):
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
permute_dims = self.op_data['permute_dims'].data
permute_dims = self.op_data["permute_dims"].data
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
dim_partition_dict_for_output = {}
for dim_index, permute_dim in enumerate(permute_dims):
......@@ -177,9 +182,11 @@ class PermuteGenerator(ReshapeGenerator):
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
strategy_list.append(strategy)
return strategy_list
......@@ -199,7 +206,7 @@ class TransposeGenerator(ReshapeGenerator):
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
dim_partition_dict_for_output = {}
transpose_dims = self.op_data['transpose_dims'].data
transpose_dims = self.op_data["transpose_dims"].data
dim_0 = transpose_dims[0]
dim_1 = transpose_dims[1]
for dim, sharded_dims in dim_partition_dict_for_input.items():
......@@ -221,9 +228,11 @@ class TransposeGenerator(ReshapeGenerator):
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
strategy_list.append(strategy)
return strategy_list
......@@ -242,7 +251,7 @@ class SplitGenerator(ReshapeGenerator):
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
split_size, split_dim = self.op_data['split_info'].data
split_size, split_dim = self.op_data["split_info"].data
if split_dim in dim_partition_dict_for_input:
recover_dims = dim_partition_dict_for_input.pop(split_dim)
......@@ -271,7 +280,8 @@ class SplitGenerator(ReshapeGenerator):
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=recover_dims,
comm_type=CommType.BEFORE,
arg_index=0)
arg_index=0,
)
# it will gather the input through gather_dim during forward phase.
input_comm_action.comm_spec.gather_dim = split_dim
# it will split the input activation grad through split_dim during backward phase.
......@@ -282,7 +292,7 @@ class SplitGenerator(ReshapeGenerator):
source_spec = input_sharding_spec
# target sharding spec
target_spec = sharding_spec_mapping["input"]
comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
comm_spec = {"src_spec": source_spec, "tgt_spec": target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
else:
......@@ -291,9 +301,11 @@ class SplitGenerator(ReshapeGenerator):
if input_comm_action is not None:
communication_action_mapping["input"] = input_comm_action
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
strategy_list.append(strategy)
return strategy_list
......@@ -341,16 +353,17 @@ class DefaultReshapeGenerator(ReshapeGenerator):
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=total_mesh_dim_list,
comm_type=CommType.BEFORE,
arg_index=0)
arg_index=0,
)
input_comm_action.comm_spec.gather_dim = total_mesh_dim_list
input_comm_action.comm_spec.shard_dim = total_mesh_dim_list
elif len(total_mesh_dim_list) >= 2:
source_spec = sharding_spec_mapping["input"]
target_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=source_spec.entire_shape,
dim_partition_dict={})
comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
target_spec = ShardingSpec(
device_mesh=self.device_mesh, entire_shape=source_spec.entire_shape, dim_partition_dict={}
)
comm_spec = {"src_spec": source_spec, "tgt_spec": target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
else:
......@@ -358,9 +371,11 @@ class DefaultReshapeGenerator(ReshapeGenerator):
if input_comm_action is not None:
communication_action_mapping["input"] = input_comm_action
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
strategy_list.append(strategy)
return strategy_list
......@@ -4,21 +4,9 @@ from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import (
check_keep_sharding_status,
detect_reshape_mapping,
infer_output_dim_partition_dict,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
__all__ = ['SoftmaxGenerator']
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
__all__ = ["SoftmaxGenerator"]
class SoftmaxGenerator(FollowingStrategyGenerator):
......@@ -30,11 +18,11 @@ class SoftmaxGenerator(FollowingStrategyGenerator):
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
'''
"""
Compute the computation cost per device with this specific strategy.
'''
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
"""
sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device()
input_size_product = reduce(operator.mul, sharded_input_shape)
output_size_product = reduce(operator.mul, sharded_output_shape)
......@@ -45,12 +33,12 @@ class SoftmaxGenerator(FollowingStrategyGenerator):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
"""
Compute the memory cost per device with this specific strategy.
'''
"""
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
"input": self._compute_size_in_bytes(strategy, "input"),
"output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
......@@ -68,8 +56,9 @@ class SoftmaxGenerator(FollowingStrategyGenerator):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
total_mem_cost = MemoryCost(
activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
......@@ -80,10 +69,10 @@ class SoftmaxGenerator(FollowingStrategyGenerator):
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
softmax_dim = self.op_data['softmax_dim'].data
softmax_dim = self.op_data["softmax_dim"].data
if softmax_dim in dim_partition_dict_for_input:
recover_dims = dim_partition_dict_for_input.pop(softmax_dim)
dim_partition_dict_for_input.pop(softmax_dim)
dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input)
dim_partition_dict_mapping = {
......@@ -96,9 +85,11 @@ class SoftmaxGenerator(FollowingStrategyGenerator):
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
strategy_list.append(strategy)
return strategy_list
......@@ -39,7 +39,7 @@ class StrategyGenerator(ABC):
"""
A utility method to check for the existence of bias operand for convenience.
"""
return 'bias' in self.op_data
return "bias" in self.op_data
def is_param(self, op_data_name):
other_data = self.op_data[op_data_name]
......@@ -49,8 +49,12 @@ class StrategyGenerator(ABC):
other_data = self.op_data[op_data_name]
return other_data.type == OperationDataType.BUFFER
def get_sharding_strategy(self, name: str, sharding_spec_mapping: Dict[str, ShardingSpec],
communication_action_mapping: Dict[str, CommSpec]):
def get_sharding_strategy(
self,
name: str,
sharding_spec_mapping: Dict[str, ShardingSpec],
communication_action_mapping: Dict[str, CommSpec],
):
"""
A factory method to produce a ShardingStrategy object.
......@@ -80,24 +84,28 @@ class StrategyGenerator(ABC):
op_data = self.op_data[op_data_name]
def _to_sharding_spec(
data: any, logical_shape: any,
dim_partition_dict: Dict[int, List[int]]) -> Union[ShardingSpec, List[ShardingSpec], None]:
data: any, logical_shape: any, dim_partition_dict: Dict[int, List[int]]
) -> Union[ShardingSpec, List[ShardingSpec], None]:
"""
This is a recursive function to convert the dim partition dict to a ShardingSpec object.
"""
if isinstance(data, torch.Tensor):
dim_size = len(logical_shape)
dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict)
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=logical_shape,
dim_partition_dict=dim_partition_dict)
sharding_spec = ShardingSpec(
device_mesh=self.device_mesh,
entire_shape=logical_shape,
dim_partition_dict=dim_partition_dict,
)
return sharding_spec
elif isinstance(data, (list, tuple)):
sharding_spec = []
for data_element, logical_shape_element, dim_partition_dict_element in zip(
data, logical_shape, dim_partition_dict):
data, logical_shape, dim_partition_dict
):
sharding_spec.append(
_to_sharding_spec(data_element, logical_shape_element, dim_partition_dict_element))
_to_sharding_spec(data_element, logical_shape_element, dim_partition_dict_element)
)
return sharding_spec
else:
return None
......@@ -116,31 +124,41 @@ class StrategyGenerator(ABC):
results[op_data] = v
return results
def get_communication_spec(self, sharding_spec: ShardingSpec, communication_pattern: CollectiveCommPattern,
logical_process_axis: Union[int, List[int]]):
def get_communication_spec(
self,
sharding_spec: ShardingSpec,
communication_pattern: CollectiveCommPattern,
logical_process_axis: Union[int, List[int]],
):
"""
A factory method to produce a CommSpec object.
"""
return CommSpec(comm_pattern=communication_pattern,
sharding_spec=sharding_spec,
logical_process_axis=logical_process_axis)
def get_communication_action(self,
sharding_spec: ShardingSpec,
communication_pattern: CollectiveCommPattern,
logical_process_axis: Union[int, List[int]],
comm_type: CommType,
arg_index: int = -1,
key_for_kwarg: any = None) -> CommAction:
return CommSpec(
comm_pattern=communication_pattern, sharding_spec=sharding_spec, logical_process_axis=logical_process_axis
)
def get_communication_action(
self,
sharding_spec: ShardingSpec,
communication_pattern: CollectiveCommPattern,
logical_process_axis: Union[int, List[int]],
comm_type: CommType,
arg_index: int = -1,
key_for_kwarg: any = None,
) -> CommAction:
"""
A factory method to produce a CommAction object.
"""
return CommAction(comm_spec=self.get_communication_spec(sharding_spec=sharding_spec,
communication_pattern=communication_pattern,
logical_process_axis=logical_process_axis),
comm_type=comm_type,
arg_index=arg_index,
key_for_kwarg=key_for_kwarg)
return CommAction(
comm_spec=self.get_communication_spec(
sharding_spec=sharding_spec,
communication_pattern=communication_pattern,
logical_process_axis=logical_process_axis,
),
comm_type=comm_type,
arg_index=arg_index,
key_for_kwarg=key_for_kwarg,
)
def update_communication_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
"""
......@@ -155,9 +173,9 @@ class StrategyGenerator(ABC):
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
for phase, cost in num_ele_in_comm.items():
num_ele_in_comm[phase] = num_ele_in_comm[phase] * size_per_elem_bytes
comm_cost.fwd += num_ele_in_comm['forward']
comm_cost.bwd += num_ele_in_comm['backward']
comm_cost.total += num_ele_in_comm['total']
comm_cost.fwd += num_ele_in_comm["forward"]
comm_cost.bwd += num_ele_in_comm["backward"]
comm_cost.total += num_ele_in_comm["total"]
# check if communication action exists
# if so, loop over each action and compute the cost of each action
......@@ -169,8 +187,8 @@ class StrategyGenerator(ABC):
# this condition branch will be removed after all the handler updated.
comm_spec = comm_action
if isinstance(comm_spec, dict):
src_spec = comm_spec['src_spec']
tgt_spec = comm_spec['tgt_spec']
src_spec = comm_spec["src_spec"]
tgt_spec = comm_spec["tgt_spec"]
shape_consistency_manager = ShapeConsistencyManager()
_, comm_action_sequence, _ = shape_consistency_manager.shape_consistency(src_spec, tgt_spec)
for comm_spec_ in comm_action_sequence:
......@@ -187,14 +205,12 @@ class StrategyGenerator(ABC):
"""
Customize this method to compute the computation flops.
"""
pass
@abstractmethod
def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
"""
Customize this method to compute the memory cost in bytes.
"""
pass
def _compute_size_in_bytes(self, strategy: ShardingStrategy, key: str):
"""
......@@ -212,13 +228,14 @@ class StrategyGenerator(ABC):
num_elements = 1
else:
num_elements = reduce(operator.mul, sharded_shape)
dtype = getattr(meta_data, 'dtype')
dtype = getattr(meta_data, "dtype")
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
return num_elements * size_per_elem_bytes
if isinstance(op_data.data, tuple):
assert isinstance(strategy.sharding_specs[op_data], list), \
'sharding_spec of op_data should be a list of sharding specs if op_data.data is a tuple.'
assert isinstance(
strategy.sharding_specs[op_data], list
), "sharding_spec of op_data should be a list of sharding specs if op_data.data is a tuple."
total_bytes = 0
for index, sharding_spec in enumerate(strategy.sharding_specs[op_data]):
meta_data = op_data.data[index]
......@@ -270,7 +287,6 @@ class StrategyGenerator(ABC):
Validate if the operands are of desired shape.
If True, means this generator can be used for the current operation.
"""
pass
class FollowingStrategyGenerator(StrategyGenerator):
......@@ -280,8 +296,9 @@ class FollowingStrategyGenerator(StrategyGenerator):
TODO: remove the original strategy_generator.py after refactoring
"""
def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
predecessor_node: Node):
def __init__(
self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, predecessor_node: Node
):
self.op_data = operation_data_mapping
self.device_mesh = device_mesh
self.predecessor_node = predecessor_node
......@@ -292,7 +309,8 @@ class OutputStrategyGenerator(StrategyGenerator):
OutputStrategyGenerator is used to generate the sharding strategies for Output Node.
"""
def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
predecessor_nodes: List[Node]):
def __init__(
self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, predecessor_nodes: List[Node]
):
super().__init__(operation_data_mapping, device_mesh)
self.predecessor_nodes = predecessor_nodes
......@@ -4,22 +4,9 @@ from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import (
check_keep_sharding_status,
detect_reshape_mapping,
infer_output_dim_partition_dict,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = ['SumGenerator']
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
__all__ = ["SumGenerator"]
class SumGenerator(FollowingStrategyGenerator):
......@@ -31,24 +18,24 @@ class SumGenerator(FollowingStrategyGenerator):
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device()
input_size_product = reduce(operator.mul, sharded_input_shape)
output_size_product = reduce(operator.mul, sharded_output_shape)
compute_cost = TrainCycleItem(fwd=input_size_product,
bwd=output_size_product,
total=input_size_product + output_size_product)
compute_cost = TrainCycleItem(
fwd=input_size_product, bwd=output_size_product, total=input_size_product + output_size_product
)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
"""
Compute the memory cost per device with this specific strategy.
'''
"""
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
"input": self._compute_size_in_bytes(strategy, "input"),
"output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
......@@ -66,8 +53,9 @@ class SumGenerator(FollowingStrategyGenerator):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
total_mem_cost = MemoryCost(
activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
......@@ -78,7 +66,7 @@ class SumGenerator(FollowingStrategyGenerator):
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
sum_dims, sum_mapping_dict = self.op_data['sum_info'].data
sum_dims, sum_mapping_dict = self.op_data["sum_info"].data
# TODO: a better way to handle the distributed sum is to sum all the data on chip and then do an all-reduce
# among all the shard groups
......@@ -90,7 +78,7 @@ class SumGenerator(FollowingStrategyGenerator):
elif dim in sum_mapping_dict:
dim_partition_dict_for_output[sum_mapping_dict[dim]] = dim_partition_dict_for_input[dim]
else:
raise RuntimeError(f'dim {dim} is not in sum_mapping_dict or sum_dims')
raise RuntimeError(f"dim {dim} is not in sum_mapping_dict or sum_dims")
for dim in recover_dims:
dim_partition_dict_for_input.pop(dim)
......@@ -105,9 +93,11 @@ class SumGenerator(FollowingStrategyGenerator):
# because in the solver, this node will be merged into other nodes, and the solver will not create a new variable for this node.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
strategy_list.append(strategy)
return strategy_list
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from .strategy_generator import StrategyGenerator
__all__ = ['TensorConstructorGenerator']
__all__ = ["TensorConstructorGenerator"]
class TensorConstructorGenerator(StrategyGenerator):
......@@ -30,10 +21,10 @@ class TensorConstructorGenerator(StrategyGenerator):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
"""
Compute the memory cost per device with this specific strategy.
'''
forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")}
"""
forward_size_mapping = {"output": self._compute_size_in_bytes(strategy, "output")}
# compute fwd cost incurred
# fwd_cost = input + output
......@@ -57,11 +48,13 @@ class TensorConstructorGenerator(StrategyGenerator):
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
name = 'Replica Tensor Constructor'
name = "Replica Tensor Constructor"
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
strategy_list.append(strategy)
return strategy_list
......@@ -5,7 +5,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost,
from .strategy_generator import FollowingStrategyGenerator
__all__ = ['UnaryElementwiseGenerator']
__all__ = ["UnaryElementwiseGenerator"]
class UnaryElementwiseGenerator(FollowingStrategyGenerator):
......@@ -21,12 +21,12 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
"""
Compute the memory cost per device with this specific strategy.
'''
"""
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
"input": self._compute_size_in_bytes(strategy, "input"),
"output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
......@@ -44,8 +44,9 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
total_mem_cost = MemoryCost(
activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
......@@ -69,9 +70,11 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator):
# we keep the same strategies under different names for node merging; this does not increase the search space,
# because in the solver, this node will be merged into other nodes, and the solver will not create a new variable for this node.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
strategy_list.append(strategy)
return strategy_list
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment