Commit e532679c authored by oahzxl

Merge branch 'main' of https://github.com/oahzxl/ColossalAI into chunk

parents c1492e50 7d5640b9
from .batch_norm_generator import BatchNormStrategyGenerator
from .binary_elementwise_generator import BinaryElementwiseStrategyGenerator
from .conv_strategy_generator import ConvStrategyGenerator
from .getitem_generator import (GetItemStrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator)
from .embedding_generator import EmbeddingStrategyGenerator
from .getattr_generator import GetattrGenerator
from .getitem_generator import GetItemStrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator
from .layer_norm_generator import LayerNormGenerator
from .matmul_strategy_generator import (BatchedMatMulStrategyGenerator, DotProductStrategyGenerator,
LinearProjectionStrategyGenerator, MatVecStrategyGenerator)
from .matmul_strategy_generator import (
BatchedMatMulStrategyGenerator,
DotProductStrategyGenerator,
LinearProjectionStrategyGenerator,
MatVecStrategyGenerator,
)
from .normal_pooling_generator import NormalPoolStrategyGenerator
from .output_generator import OutputGenerator
from .placeholder_generator import PlaceholderGenerator
from .reshape_generator import ReshapeGenerator
from .softmax_generator import SoftmaxGenerator
from .strategy_generator import StrategyGenerator
from .sum_generator import SumGenerator
from .tensor_constructor_generator import TensorConstructorGenerator
from .unary_elementwise_generator import UnaryElementwiseGenerator
from .where_generator import WhereGenerator
......@@ -17,5 +27,6 @@ __all__ = [
'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator', 'UnaryElementwiseGenerator',
'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator',
'LayerNormGenerator', 'ReshapeGenerator', 'PlaceholderGenerator', 'OutputGenerator', 'WhereGenerator',
'ReshapeGenerator', 'NormalPoolStrategyGenerator'
'ReshapeGenerator', 'NormalPoolStrategyGenerator', 'BinaryElementwiseStrategyGenerator', 'GetattrGenerator',
'TensorConstructorGenerator', 'EmbeddingStrategyGenerator', 'SumGenerator', 'SoftmaxGenerator'
]
......@@ -3,7 +3,13 @@ import operator
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
......@@ -98,6 +104,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@ignore_sharding_exception
def split_input_channel(self, mesh_dim_0):
name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
dim_partition_dict_mapping = {
......@@ -129,6 +136,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'
dim_partition_dict_mapping = {
......@@ -160,6 +168,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def non_split(self):
name = f'RR = RR x R'
dim_partition_dict_mapping = {
......@@ -181,6 +190,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_input_batch(self, mesh_dim_0):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'
dim_partition_dict_mapping = {
......@@ -204,17 +214,21 @@ class BatchNormStrategyGenerator(StrategyGenerator):
# For SyncBN case, we don't need to do communication for weight and bias.
# TODO: the communication happens internally in the SyncBN operation. We need to replace the BN operation
# with a SyncBN operation instead of inserting a communication node.
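# (Illustrative note, not part of this diff: in plain PyTorch the analogous replacement is
# torch.nn.SyncBatchNorm.convert_sync_batchnorm(model), which makes batch norm exchange its
# statistics across the process group internally instead of relying on an explicit communication node.)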
output_comm_spec = self.get_communication_spec(
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0)
logical_process_axis=mesh_dim_0,
comm_type=CommType.IMPLICIT)
communication_action_mapping = {"output": output_comm_spec}
# TODO: this temporary solution has no communication cost;
# the action above should be added back once the SyncBN replacement pass is completed.
communication_action_mapping = {}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'
dim_partition_dict_mapping = {
......@@ -238,17 +252,21 @@ class BatchNormStrategyGenerator(StrategyGenerator):
# For SyncBN case, we don't need to do communication for gradients of weight and bias.
# TODO: the communication happens internally in the SyncBN operation. We need to replace the BN operation
# with a SyncBN operation instead of inserting a communication node.
output_comm_spec = self.get_communication_spec(
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.IMPLICIT)
communication_action_mapping = {"output": output_comm_spec}
# TODO: this temporary solution has no communication cost;
# the action above should be added back once the SyncBN replacement pass is completed.
communication_action_mapping = {}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'
dim_partition_dict_mapping = {
......@@ -282,12 +300,15 @@ class BatchNormStrategyGenerator(StrategyGenerator):
# For SyncBN case, we don't need to do communication for gradients of weight and bias.
# TODO: the communication happens internally in the SyncBN operation. We need to replace the BN operation
# with a SyncBN operation instead of inserting a communication node.
output_comm_spec = self.get_communication_spec(
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0])
logical_process_axis=[mesh_dim_0],
comm_type=CommType.IMPLICIT)
communication_action_mapping = {"output": output_comm_spec}
# TODO: this temporary solution has no communication cost;
# the action above should be added back once the SyncBN replacement pass is completed.
communication_action_mapping = {}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
......@@ -316,14 +337,14 @@ class BatchNormStrategyGenerator(StrategyGenerator):
# TODO: The strategies below should be uncommented once the runtime
# passes are ready.
# SR = SR x R WITH SYNC_BN
# strategy_list.append(self.split_input_batch(0))
# strategy_list.append(self.split_input_batch(1))
strategy_list.append(self.split_input_batch(0))
strategy_list.append(self.split_input_batch(1))
# SS = SS x S WITH SYNC_BN
# strategy_list.append(self.split_input_both_dim(0, 1))
# strategy_list.append(self.split_input_both_dim(1, 0))
strategy_list.append(self.split_input_both_dim(0, 1))
strategy_list.append(self.split_input_both_dim(1, 0))
# S01R = S01R x R WITH SYNC_BN
# strategy_list.append(self.split_input_batch_1d(0, 1))
strategy_list.append(self.split_input_batch_1d(0, 1))
return strategy_list
import operator
from functools import reduce
from typing import List
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from colossalai.auto_parallel.tensor_shard.utils import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
ignore_sharding_exception,
)
from colossalai.tensor.sharding_spec import ShardingSpecException
from .strategy_generator import StrategyGenerator
__all__ = ['BinaryElementwiseStrategyGenerator']
class BinaryElementwiseStrategyGenerator(StrategyGenerator):
"""
A BinaryElementwiseStrategyGenerator is a node handler which deals with elementwise operations
that take two operands and may involve broadcasting, such as torch.add.
The logical shape for this operation will be `input <op> other`.
"""
def validate(self) -> bool:
assert len(self.op_data) == 3, \
f'BinaryElementwiseStrategyGenerator only accepts three operation data (input, other and output), but got {len(self.op_data)}'
for name, op_data in self.op_data.items():
if not isinstance(op_data.data, (torch.Tensor, int, float)):
raise TypeError(f'The operation data {name} is not a torch.Tensor/int/float.')
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
# since elementwise ops are not compute-intensive,
# we approximate the backward compute cost
# to be twice the fwd compute cost
fwd_compute_cost = reduce(operator.mul, shape)
bwd_compute_cost = fwd_compute_cost * 2
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
# input, other and output all have the same shape
shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
# compute fwd memory cost in bytes
# as the elementwise ops are not memory-intensive
# we approximate the fwd memory cost to be the output
# and the backward memory cost to be the grads of input and other
input_bytes = self._compute_size_in_bytes(strategy, 'input')
other_bytes = self._compute_size_in_bytes(strategy, 'other')
output_bytes = self._compute_size_in_bytes(strategy, 'output')
fwd_memory_cost = MemoryCost(activation=output_bytes)
bwd_memory_cost = MemoryCost(activation=input_bytes + other_bytes)
total_memory_cost = MemoryCost(activation=input_bytes + other_bytes + output_bytes)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_memory_cost)
strategy.memory_cost = memory_cost
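# Rough numeric illustration (assumed figures, not from this commit): for three fp32 tensors
# of shape (1024, 1024) sharded evenly over 4 devices, each tensor occupies
# 1024 * 1024 * 4 / 4 = 1 MiB per device, so under this approximation the fwd activation
# memory is ~1 MiB (output) and the bwd activation memory is ~2 MiB (grads of input and other).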
@ignore_sharding_exception
def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
# we check for the output logical shape to get the number of dimensions
dim_partition_list = []
dim_size = len(self.op_data['output'].logical_shape)
# enumerate all the 2D sharding cases
sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
dim_partition_list.extend(sharding_list_2d)
# enumerate all the 1D sharding cases
sharding_list_1d_on_dim_0 = enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size)
dim_partition_list.extend(sharding_list_1d_on_dim_0)
sharding_list_1d_on_dim_1 = enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size)
dim_partition_list.extend(sharding_list_1d_on_dim_1)
# add empty dict for fully replicated case
dim_partition_list.append({})
# sharding strategy bookkeeping
strategy_list = []
# convert these dim partition dict to sharding strategy
for dim_partition_dict in dim_partition_list:
dim_partition_dict_mapping = dict(input=dim_partition_dict,
other=dim_partition_dict,
output=dim_partition_dict)
try:
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
communication_action_mapping = {}
# get name
sharding_seq = sharding_spec_mapping['input'].sharding_sequence
name = f'{sharding_seq} = {sharding_seq} <binary-elementwise-op> {sharding_seq}'
sharding_strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(sharding_strategy)
except ShardingSpecException:
continue
return strategy_list
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = self.enumerate_all_possible_output(0, 1)
return strategy_list
......@@ -4,7 +4,6 @@ import warnings
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
......@@ -12,10 +11,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import \
ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
......@@ -135,7 +131,8 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE)
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping = {"input": input_comm_action}
if self.is_param("other"):
......@@ -144,14 +141,31 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
else:
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping["other"] = other_comm_action
if self.has_bias and self.is_param("bias"):
if self.has_bias:
if self.is_param('bias'):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
key_for_kwarg='bias')
communication_action_mapping["bias"] = bias_comm_action
return self.get_sharding_strategy(name=name,
......@@ -183,14 +197,31 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
else:
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping["other"] = other_comm_action
if self.has_bias and self.is_param("bias"):
if self.has_bias:
if self.is_param('bias'):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
key_for_kwarg='bias')
communication_action_mapping["bias"] = bias_comm_action
return self.get_sharding_strategy(name=name,
......@@ -223,8 +254,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.AFTER,
arg_index=0)
comm_type=CommType.AFTER)
communication_action_mapping = {"output": output_comm_action}
......@@ -234,14 +264,29 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
communication_action_mapping["other"] = other_comm_action
if self.has_bias and self.is_param("bias"):
else:
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping["other"] = other_comm_action
if self.has_bias:
if self.is_param("bias"):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
key_for_kwarg='bias')
communication_action_mapping["bias"] = bias_comm_action
return self.get_sharding_strategy(name=name,
......@@ -277,12 +322,11 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.AFTER,
arg_index=0)
comm_type=CommType.AFTER)
input_comm_action = self.get_communication_action(
sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
arg_index=0)
......@@ -316,8 +360,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.AFTER,
arg_index=0)
comm_type=CommType.AFTER)
communication_action_mapping = {"output": output_comm_action}
......@@ -351,7 +394,8 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE)
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping = {"input": input_comm_action}
......@@ -404,14 +448,30 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.HOOK)
else:
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping["other"] = other_comm_action
if self.has_bias and self.is_param("bias"):
if self.has_bias:
if self.is_param("bias"):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.HOOK)
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
key_for_kwarg='bias')
communication_action_mapping["bias"] = bias_comm_action
return self.get_sharding_strategy(name=name,
......@@ -441,8 +501,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.AFTER,
arg_index=0)
comm_type=CommType.AFTER)
communication_action_mapping = {"output": output_comm_action}
......
import copy
import operator
import warnings
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
class EmbeddingStrategyGenerator(StrategyGenerator):
"""
EmbeddingStrategyGenerator is a generic class to generate strategies for nn.Embedding or F.embedding.
The operation data is defined as `output = input x other`.
"""
def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
'''
Compute the computation cost per device with this specific strategy.
Note: The computation cost for the embedding handler is currently estimated as a dense computation,
so it may not be accurate.
'''
# TODO: estimate the embedding computation cost as a sparse operation
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device()
sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
input_size_product = reduce(operator.mul, sharded_input_shape)
other_size_product = reduce(operator.mul, sharded_other_shape)
output_size_product = reduce(operator.mul, sharded_output_shape)
forward_compute_cost = input_size_product * other_size_product
backward_activation_cost = other_size_product * output_size_product / sharded_output_shape[-1]
backward_weight_cost = input_size_product * other_size_product
backward_compute_cost = backward_weight_cost + backward_activation_cost
total_compute_cost = forward_compute_cost + backward_compute_cost
compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
strategy.compute_cost = compute_cost
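# Why this overestimates (illustrative reasoning, not from this commit): with an input of N
# indices and a weight of shape [V, D], the dense estimate above is N * V * D (as if a one-hot
# matrix were multiplied by the weight), while an actual embedding lookup only touches roughly
# N * D elements in the forward pass.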
def update_memory_cost(self, strategy: ShardingStrategy):
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'other': self._compute_size_in_bytes(strategy, "other"),
'output': self._compute_size_in_bytes(strategy, "output")
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
# compute fwd cost incurred
# fwd_cost = input + other + output
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
# compute bwd cost incurred
# bwd_cost = input_grad + other_grad
bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@ignore_sharding_exception
def non_split(self):
name = f'RR = R x RR'
dim_partition_dict_mapping = {
"input": {},
"other": {},
"output": {},
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
@ignore_sharding_exception
def split_input(self, mesh_dim_0):
name = f'S{mesh_dim_0}R = S{mesh_dim_0} x RR'
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0]
},
"other": {},
"output": {
0: [mesh_dim_0],
},
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
communication_action_mapping = {}
if self.is_param("other"):
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
else:
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping["other"] = other_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_input_and_embedding_dim(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0} x RS{mesh_dim_1}'
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0],
},
"other": {
1: [mesh_dim_1],
},
"output": {
0: [mesh_dim_0],
1: [mesh_dim_1],
},
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# set communication action
input_comm_action = self.get_communication_action(
sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping = {"input": input_comm_action}
if self.is_param("other"):
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
else:
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping["other"] = other_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_1d_parallel_on_input(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1} x RR'
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0, mesh_dim_1]
},
"other": {},
"output": {
0: [mesh_dim_0, mesh_dim_1],
},
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# set communication action
communication_action_mapping = {}
if self.is_param("other"):
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.HOOK)
else:
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping["other"] = other_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_embedding_dim(self, mesh_dim_0):
name = f'RS{mesh_dim_0} = R x RS{mesh_dim_0}'
dim_partition_dict_mapping = {
"input": {},
"other": {
1: [mesh_dim_0],
},
"output": {
1: [mesh_dim_0],
},
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# set communication action
input_comm_action = self.get_communication_action(
sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping = {"input": input_comm_action}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_1d_parallel_on_embedding_dim(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = R x RS{mesh_dim_0}{mesh_dim_1}'
dim_partition_dict_mapping = {
"input": {},
"other": {
1: [mesh_dim_0, mesh_dim_1],
},
"output": {
1: [mesh_dim_0, mesh_dim_1],
},
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# set communication action
input_comm_action = self.get_communication_action(
sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping = {"input": input_comm_action}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def collate_strategies(self) -> List[ShardingStrategy]:
strategies = []
# RR = R x RR
strategies.append(self.non_split())
# SR = S x RR
strategies.append(self.split_input(0))
strategies.append(self.split_input(1))
# SS = S x RS
strategies.append(self.split_input_and_embedding_dim(0, 1))
strategies.append(self.split_input_and_embedding_dim(1, 0))
# S01R = S01 x RR
strategies.append(self.split_1d_parallel_on_input(0, 1))
# RS = R x RS
strategies.append(self.split_embedding_dim(0))
strategies.append(self.split_embedding_dim(1))
# RS01 = R x RS01
strategies.append(self.split_1d_parallel_on_embedding_dim(0, 1))
return strategies
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from colossalai.auto_parallel.tensor_shard.utils import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
ignore_sharding_exception,
)
from colossalai.tensor.sharding_spec import ShardingSpecException
from .strategy_generator import StrategyGenerator
__all__ = ['GetattrGenerator']
class GetattrGenerator(StrategyGenerator):
"""
GetattrGenerator is a generic class to generate strategies for getattr nodes.
"""
def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
Compute the memory cost per device with this specific strategy.
'''
forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")}
# compute fwd cost incurred
# fwd_cost = output
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items()])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=0)
bwd_mem_cost = MemoryCost(activation=0, parameter=0)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=0)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@ignore_sharding_exception
def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
# we check for the output logical shape to get the number of dimensions
dim_partition_list = []
dim_size = len(self.op_data['output'].logical_shape)
# enumerate all the 2D sharding cases
sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
dim_partition_list.extend(sharding_list_2d)
# enumerate all the 1D sharding cases
sharding_list_1d_on_dim_0 = enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size)
dim_partition_list.extend(sharding_list_1d_on_dim_0)
sharding_list_1d_on_dim_1 = enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size)
dim_partition_list.extend(sharding_list_1d_on_dim_1)
# add empty dict for fully replicated case
dim_partition_list.append({})
# sharding strategy bookkeeping
strategy_list = []
# convert these dim partition dict to sharding strategy
for dim_partition_dict in dim_partition_list:
dim_partition_dict_mapping = dict(output=dim_partition_dict)
try:
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
communication_action_mapping = {}
# get name
name = f"get_attr {sharding_spec_mapping['output'].sharding_sequence}"
sharding_strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(sharding_strategy)
except ShardingSpecException:
continue
return strategy_list
def collate_strategies(self) -> List[ShardingStrategy]:
return self.enumerate_all_possible_output(0, 1)
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.logging import get_dist_logger
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpecException
from .strategy_generator import FollowingStrategyGenerator
......@@ -64,37 +71,61 @@ class TensorStrategyGenerator(GetItemStrategyGenerator):
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
for strategy in self.predecessor_node.strategies_vector:
getitem_index = self.op_data['index'].data
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
try:
logger = get_dist_logger()
dim_partition_dict_mapping = {}
communication_action_mapping = {}
dim_partition_dict_for_input = strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict
dim_partition_dict_for_input = copy.deepcopy(
strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict)
int_index = False
if isinstance(getitem_index, int):
int_index = True
getitem_dims = [
0,
]
shift_length = 1
elif isinstance(getitem_index, slice):
getitem_dims = [
0,
]
else:
getitem_dims = [i for i in range(len(getitem_index))]
if isinstance(getitem_index[0], int):
int_index = True
shift_length = len(getitem_index)
gather_dims = []
for dim in getitem_dims:
if dim in dim_partition_dict_for_input:
gather_dims.append(dim)
for dim in gather_dims:
dim_partition_dict_for_input.pop(dim)
dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input)
gather_input = 0 in dim_partition_dict_for_input
if gather_input:
logical_process_axis = dim_partition_dict_for_output.pop(0)
if int_index:
shift_dim_partition_dict_for_output = {}
for dim, mesh_dim_list in dim_partition_dict_for_output.items():
shift_dim_partition_dict_for_output[dim - 1] = mesh_dim_list
shift_dim_partition_dict_for_output[dim - shift_length] = mesh_dim_list
dim_partition_dict_for_output = shift_dim_partition_dict_for_output
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
if gather_input:
input_communication_spec = self.get_communication_spec(
sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=logical_process_axis)
communication_action_mapping["input"] = input_communication_spec
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
except ShardingSpecException as e:
logger.debug(e)
continue
strategy_list.append(strategy)
for strategy in strategy_list:
......@@ -114,7 +145,7 @@ class TensorTupleStrategyGenerator(GetItemStrategyGenerator):
strategy_list = []
index = self.op_data["index"].data
for strategy in self.predecessor_node.strategies_vector:
for strategy_index, strategy in enumerate(self.predecessor_node.strategies_vector):
# the sharding spec for input in this case is a tuple of ShardingSpec.
sharding_spec_for_input = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_output = sharding_spec_for_input[index].dim_partition_dict
......@@ -125,8 +156,11 @@ class TensorTupleStrategyGenerator(GetItemStrategyGenerator):
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
sharding_spec_mapping["input"] = sharding_spec_for_input
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'
input_sharding_info = f"get the {index} element from ("
for sharding_spec in sharding_spec_for_input:
input_sharding_info += f'{sharding_spec.sharding_sequence}, '
input_sharding_info += ")"
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {input_sharding_info}_{strategy_index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
......
......@@ -3,9 +3,17 @@ import operator
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
ignore_sharding_exception,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
......@@ -87,6 +95,7 @@ class LayerNormGenerator(StrategyGenerator):
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@ignore_sharding_exception
def _generate_strategy_with_dim_partition(self, dim_partition):
dim_partition_dict_mapping = {
"input": dim_partition,
......@@ -107,18 +116,20 @@ class LayerNormGenerator(StrategyGenerator):
total_mesh_dim_list = total_mesh_dim_list[0]
communication_action_mapping = {}
other_comm_spec = self.get_communication_spec(
other_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=total_mesh_dim_list)
communication_action_mapping["other"] = other_comm_spec
logical_process_axis=total_mesh_dim_list,
comm_type=CommType.HOOK)
communication_action_mapping["other"] = other_comm_action
if self.has_bias:
bias_comm_spec = self.get_communication_spec(
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=total_mesh_dim_list)
communication_action_mapping["bias"] = bias_comm_spec
logical_process_axis=total_mesh_dim_list,
comm_type=CommType.HOOK)
communication_action_mapping["bias"] = bias_comm_action
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
......@@ -142,6 +153,7 @@ class LayerNormGenerator(StrategyGenerator):
strategy_list.append(strategy)
return strategy_list
@ignore_sharding_exception
def non_split(self):
name = f'RR = RR x R'
dim_partition_dict_mapping = {
......
import operator
from ast import arg
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception
from colossalai.tensor.shape_consistency import CollectiveCommPattern
......@@ -54,12 +60,13 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
fwd_compute_cost = sharded_input_shape[0]
bwd_compute_cost = sharded_input_shape * 2
bwd_compute_cost = fwd_compute_cost * 2
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
return compute_cost
@ignore_sharding_exception
def no_split(self):
name = f'R = R dot R'
dim_partition_dict = {"input": {}, "other": {}, "output": {}, 'bias': {}}
......@@ -69,6 +76,7 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_one_dim(self, mesh_dim):
name = f'R = S{mesh_dim} dot S{mesh_dim}'
......@@ -77,16 +85,17 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
# get communication action
output_comm_spec = self.get_communication_spec(
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['output'],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim)
communication_action_mapping = {"output": output_comm_spec}
logical_process_axis=mesh_dim,
comm_type=CommType.AFTER)
communication_action_mapping = {"output": output_comm_action}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def generate(self) -> List[ShardingStrategy]:
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
# do not split dimensions for dot product
......@@ -106,38 +115,86 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
def validate(self) -> bool:
input_op_data = self.op_data['input']
other_op_data = self.op_data['other']
assert input_op_data.data.dim() > 1 and other_op_data.data.dim() == 1
assert input_op_data.data.dim() == 2 and other_op_data.data.dim() == 1
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
fwd_compute_cost = sharded_input_shape[0]
bwd_compute_cost = fwd_compute_cost * 2
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
return compute_cost
@ignore_sharding_exception
def no_split(self):
name = "R = R x R"
dim_partition_dict = {"input": {}, "other": {}, "output": {}, "bias": {}}
dim_partition_dict = {"input": {}, "other": {}, "output": {}}
if self.has_bias:
dim_partition_dict['bias'] = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
@ignore_sharding_exception
def split_input_batch(self, mesh_dim):
name = f'S{mesh_dim}R = S{mesh_dim}R x R'
# get sharding spec
dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}, "bias": {}}
dim_partition_dict = {
"input": {
0: [mesh_dim]
},
"other": {},
"output": {
0: [mesh_dim]
},
}
if self.has_bias:
dim_partition_dict['bias'] = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
# get communication action
other_comm_spec = self.get_communication_spec(
communication_action_mapping = {}
if self.is_param('other'):
other_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['other'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim)
bias_comm_spec = self.get_communication_spec(
logical_process_axis=mesh_dim,
comm_type=CommType.HOOK)
else:
other_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['other'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping['other'] = other_comm_action
if self.has_bias:
if self.is_param('bias'):
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.HOOK)
else:
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim)
communication_action_mapping = {'other': other_comm_spec, 'bias': bias_comm_spec}
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
arg_index=2)
communication_action_mapping['bias'] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def generate(self) -> List[ShardingStrategy]:
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
# no split
......@@ -152,6 +209,10 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
def __init__(self, operation_data_mapping, device_mesh, linear_projection_type='linear'):
super().__init__(operation_data_mapping, device_mesh)
self.linear_projection_type = linear_projection_type
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
# C = AB
# C: [M, N], A: [M, P], B: [P, N]
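# Worked sketch of the estimate (assumed notation, not from this commit): with C = A B,
# A: [M, P], B: [P, N], the forward cost is roughly M * P * N multiply-accumulates per device
# after sharding, and the backward pass (dA = dC B^T and dB = A^T dC) costs roughly twice that.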
......@@ -202,6 +263,9 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# RS01 = RR x RS01
strategies.append(self.split_rhs_2nd_dim_1d(0, 1))
# RR = RR x RR
strategies.append(self.non_split())
return strategies
@ignore_sharding_exception
......@@ -215,36 +279,66 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
"other": {
-1: [mesh_dim_1]
},
"bias": {
-1: [mesh_dim_1]
},
"output": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
},
}
# linear bias only has one dimension, but addmm bias logically has the same
# dimensions as the output.
if self.linear_projection_type == 'linear':
dim_partition_dict_mapping['bias'] = {-1: [mesh_dim_1]}
elif self.linear_projection_type == 'addmm':
dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0], -1: [mesh_dim_1]}
else:
raise ValueError('Unsupported linear projection type')
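# Hedged shape illustration (assumed signatures, not from this commit):
#   F.linear(input[M, P], weight[N, P], bias[N])   -> bias has a single dim
#   torch.addmm(bias, mat1[M, P], mat2[P, N])      -> bias broadcasts to the output shape [M, N]
# which is why the addmm bias reuses the output's dim_partition_dict while the linear bias
# only follows the last (-1) dimension.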
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# set communication action
communication_action_mapping = {}
input_comm_spec = self.get_communication_spec(
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1)
other_comm_spec = self.get_communication_spec(
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['input'] = input_comm_spec
communication_action_mapping['other'] = other_comm_spec
if self.has_bias:
bias_comm_spec = self.get_communication_spec(
if self.is_param('other'):
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
else:
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping['input'] = input_comm_action
communication_action_mapping['other'] = other_comm_action
# we only add allreduce comm action for linear bias, because
# allreduce comm action for addmm bias will be considered in post processing
if self.has_bias and self.linear_projection_type == 'linear':
if self.is_param('bias'):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
communication_action_mapping['bias'] = bias_comm_spec
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
key_for_kwarg='bias')
communication_action_mapping['bias'] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
......@@ -269,28 +363,61 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
0: [mesh_dim_0]
},
}
# linear bias only has one dimension, but addmm bias logically has the same
# dimensions as the output.
if self.linear_projection_type == 'linear':
dim_partition_dict_mapping['bias'] = {}
elif self.linear_projection_type == 'addmm':
dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0]}
else:
raise ValueError('Unsupported linear projection type')
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action mapping
communication_action_mapping = {}
input_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
output_comm_spec = self.get_communication_spec(
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1)
logical_process_axis=mesh_dim_1,
comm_type=CommType.AFTER)
communication_action_mapping['input'] = input_comm_spec
communication_action_mapping['output'] = output_comm_spec
if self.has_bias:
bias_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1)
communication_action_mapping['bias'] = bias_comm_spec
if self.is_param('other'):
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
else:
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping['other'] = other_comm_action
communication_action_mapping['output'] = output_comm_action
# we only add allreduce comm action for linear bias, because
# allreduce comm action for addmm bias will be considered in post processing
if self.has_bias and self.linear_projection_type == 'linear':
if self.is_param('bias'):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
key_for_kwarg='bias')
communication_action_mapping['bias'] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
......@@ -316,20 +443,27 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
-1: [mesh_dim_1]
},
}
# We don't have to do anything special for bias here, because
# the bias already has the same sharding spec as the output.
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication actions
communication_action_mapping = {}
output_comm_spec = self.get_communication_spec(
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['output'],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0)
input_comm_spec = self.get_communication_spec(
logical_process_axis=mesh_dim_0,
comm_type=CommType.AFTER)
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['input'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1)
communication_action_mapping["input"] = input_comm_spec
communication_action_mapping['output'] = output_comm_spec
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping["input"] = input_comm_action
communication_action_mapping['output'] = output_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
......@@ -349,17 +483,19 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
"bias": {},
"output": {},
}
# We don't have to do anything special for bias here, because
# the bias already has the same sharding spec as the output.
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action
communication_action_mapping = {}
output_comm_spec = self.get_communication_spec(
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['output'],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim)
logical_process_axis=mesh_dim,
comm_type=CommType.AFTER)
communication_action_mapping['output'] = output_comm_spec
communication_action_mapping['output'] = output_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
......@@ -381,17 +517,20 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
-1: [mesh_dim]
},
}
# We don't have to do anything special for bias here, because
# the bias already has the same sharding spec as the output.
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication actions
communication_action_mapping = {}
input_comm_spec = self.get_communication_spec(
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['input'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim)
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['input'] = input_comm_spec
communication_action_mapping['input'] = input_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
......@@ -410,22 +549,52 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
0: [mesh_dim_0, mesh_dim_1]
},
}
# linear bias only has one dimension, but addmm bias logically has the same
# dimensions as the output.
if self.linear_projection_type == 'linear':
dim_partition_dict_mapping['bias'] = {}
elif self.linear_projection_type == 'addmm':
dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0, mesh_dim_1]}
else:
raise ValueError('Unsupported linear projection type')
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action
communication_action_mapping = {}
other_comm_spec = self.get_communication_spec(
if self.is_param('other'):
other_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['other'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mapping['other'] = other_comm_spec
if self.has_bias:
bias_comm_spec = self.get_communication_spec(
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.HOOK)
else:
other_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['other'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping['other'] = other_comm_action
# we only add allreduce comm action for linear bias, because
# allreduce comm action for addmm bias will be considered in post processing
if self.has_bias and self.linear_projection_type == 'linear':
if self.is_param('bias'):
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.HOOK)
else:
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mapping['bias'] = bias_comm_spec
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
key_for_kwarg='bias')
communication_action_mapping['bias'] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
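For reference, the bias distinction handled above can be checked directly: F.linear uses a 1D bias that broadcasts over the output, while torch.addmm takes a bias whose logical shape matches the output, which is why the addmm bias is sharded like the output here. A small shape check (illustrative only, with made-up sizes):

import torch
import torch.nn.functional as F

x = torch.randn(4, 8)      # input
w = torch.randn(16, 8)     # weight for F.linear (out_features, in_features)
b1d = torch.randn(16)      # linear bias: one dimension
b2d = torch.randn(4, 16)   # addmm bias: logically shaped like the output

out_linear = F.linear(x, w, b1d)          # (4, 16)
out_addmm = torch.addmm(b2d, x, w.t())    # (4, 16)
assert out_linear.shape == out_addmm.shape == (4, 16)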
......@@ -445,15 +614,19 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
"bias": {},
"output": {},
}
# We don't have to do anything special for bias here, because
# the bias is already the same sharding spec as the output.
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action
communication_action_mapping = {}
output_comm_spec = self.get_communication_spec(
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['output'],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mapping['output'] = output_comm_spec
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.AFTER)
communication_action_mapping['output'] = output_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
......@@ -476,15 +649,43 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
-1: [mesh_dim_0, mesh_dim_1]
},
}
# We don't have to do anything special for bias here, because
# the bias is already the same sharding spec as the output.
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action
communication_action_mapping = {}
input_comm_spec = self.get_communication_spec(
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['input'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mapping['input'] = input_comm_spec
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['input'] = input_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def non_split(self):
name = 'RR = RR x RR'
# get sharding spec
dim_partition_dict_mapping = {
"input": {},
"other": {},
"bias": {},
"output": {},
}
# We don't have to do anything special for bias here, because
# the bias is already the same sharding spec as the output.
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action
communication_action_mapping = {}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
......@@ -500,10 +701,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
assert input_data.data.dim() > 0 and other_data.data.dim() == 2
assert other_data.logical_shape[0] == input_data.logical_shape[-1]
# check if the bias has a valid dim
has_bias = "bias" in self.op_data
if has_bias:
if self.has_bias:
bias_data = self.op_data['bias']
assert bias_data.logical_shape[-1] == other_data.logical_shape[-1]
......@@ -516,8 +714,13 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
[b, i, k] x [b, k, j] -> [b, i, j]
The bias term is considered to have a 2D logical shape.
Note: This class will be used to generate strategies for torch.bmm
and torch.addbmm. However, the result for torch.addbmm is not yet correct;
some extra runtime apply actions are required to preserve numerical correctness.
"""
# TODO: torch.addbmm correctness issue need to be fixed.
def __init__(self, *args, **kwargs):
self.squeeze_batch_dim = False
super().__init__(*args, **kwargs)
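The logical shapes in the docstring can be verified with a quick check; for torch.addbmm the batch dimension of the product is additionally reduced into the 2D bias (illustrative sketch with made-up sizes):

import torch

b, i, k, j = 4, 3, 5, 2
batch1 = torch.randn(b, i, k)
batch2 = torch.randn(b, k, j)
bias = torch.randn(i, j)

assert torch.bmm(batch1, batch2).shape == (b, i, j)
# addbmm sums the b matmul results and adds the 2D bias
assert torch.addbmm(bias, batch1, batch2).shape == (i, j)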
......@@ -537,7 +740,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
def validate(self) -> bool:
input_op_data = self.op_data['input']
other_op_data = self.op_data['other']
assert input_op_data.data.dim() == 3 or other_op_data.data.dim() == 3
assert len(input_op_data.logical_shape) == 3 or len(other_op_data.logical_shape) == 3
if 'bias' in self.op_data:
bias_op_data = self.op_data['bias']
......@@ -566,16 +769,16 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
self._pop_batch_dim_sharding_for_output(dim_partition_dict)
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
print(sharding_spec_mapping)
# get communication actions
communication_action_mapping = {}
if self.has_bias:
bias_comm_spec = self.get_communication_spec(
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim)
communication_action_mapping['bias'] = bias_comm_spec
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['bias'] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
......@@ -602,11 +805,13 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
# get communication actions
communication_action_mapping = {}
if self.has_bias:
bias_comm_spec = self.get_communication_spec(
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mapping['bias'] = bias_comm_spec
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['bias'] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
......@@ -637,18 +842,24 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
# get communication actions
communication_action_mapping = {}
other_comm_spec = self.get_communication_spec(
other_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['other'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1)
communication_action_mapping['other'] = other_comm_spec
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping['other'] = other_comm_action
if self.has_bias:
bias_comm_spec = self.get_communication_spec(
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mapping['bias'] = bias_comm_spec
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['bias'] = bias_comm_action
# for the addbmm case, other (batch2) is the third argument instead of the second.
communication_action_mapping['other'].arg_index += 1
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
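The arg_index adjustment above appears to track the positional index of each operand in the node's argument list: in torch.bmm the matrices are arguments 0 and 1, while in torch.addbmm the bias comes first and shifts both matrices one position to the right. A hedged sketch of the argument layout (made-up sizes):

import torch

b, i, k, j = 2, 3, 4, 5
batch1 = torch.randn(b, i, k)   # the 'input' operand
batch2 = torch.randn(b, k, j)   # the 'other' operand
bias = torch.randn(i, j)

# torch.bmm:    args = (batch1, batch2)        -> other at arg_index 1
# torch.addbmm: args = (bias, batch1, batch2)  -> other at arg_index 2
_ = torch.bmm(batch1, batch2)
_ = torch.addbmm(bias, batch1, batch2)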
......@@ -679,18 +890,23 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
# get communication actions
communication_action_mapping = {}
input_comm_spec = self.get_communication_spec(
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['input'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1)
communication_action_mapping['input'] = input_comm_spec
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['input'] = input_comm_action
if self.has_bias:
bias_comm_spec = self.get_communication_spec(
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
communication_action_mapping['bias'] = bias_comm_spec
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE)
communication_action_mapping['bias'] = bias_comm_action
# for the addbmm case, input (batch1) is the second argument instead of the first.
communication_action_mapping['input'].arg_index += 1
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
......@@ -702,11 +918,11 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
dim_partition_dict = {
"input": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
2: [mesh_dim_1]
},
"other": {
0: [mesh_dim_0],
-2: [mesh_dim_1]
1: [mesh_dim_1]
},
"bias": {},
"output": {
......@@ -719,18 +935,21 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
# get communication actions
communication_action_mapping = {}
output_comm_spec = self.get_communication_spec(
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['output'],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1)
communication_action_mapping['output'] = output_comm_spec
logical_process_axis=mesh_dim_1,
comm_type=CommType.AFTER)
communication_action_mapping['output'] = output_comm_action
if self.has_bias:
bias_comm_spec = self.get_communication_spec(
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
communication_action_mapping['bias'] = bias_comm_spec
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['bias'] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
......@@ -771,6 +990,5 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
# split two batch dim
strategy_list.append(self.split_two_batch_dim(0, 1))
strategy_list.append(self.split_two_batch_dim(1, 0))
return strategy_list
......@@ -3,9 +3,12 @@ import operator
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from colossalai.auto_parallel.tensor_shard.utils import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
ignore_sharding_exception,
)
from .strategy_generator import StrategyGenerator
......@@ -79,6 +82,7 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@ignore_sharding_exception
def _generate_strategy_with_dim_partition(self, dim_partition):
dim_partition_dict_mapping = {"input": dim_partition, "output": dim_partition}
......
from typing import List
from typing import Dict, List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from torch.fx import Node
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.device.device_mesh import DeviceMesh
from .strategy_generator import OutputStrategyGenerator
......@@ -12,6 +20,11 @@ class OutputGenerator(OutputStrategyGenerator):
OutputGenerator is a generic class to generate strategies for Output Node.
"""
def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
predecessor_nodes: List[Node], output_option: str):
super().__init__(operation_data_mapping, device_mesh, predecessor_nodes)
self.output_option = output_option
def validate(self) -> bool:
return super().validate()
......@@ -32,21 +45,77 @@ class OutputGenerator(OutputStrategyGenerator):
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
def collate_strategies(self) -> List[ShardingStrategy]:
def replica_strategy(self) -> List[ShardingStrategy]:
"""
Generate replica strategy for output node.
"""
dim_partition_dict_mapping = {}
dim_partition_dict_for_output = []
for index, _ in enumerate(self.predecessor_nodes):
mapping_name = f"input_{index}"
if isinstance(self.op_data[mapping_name].data, (tuple, list)):
dim_partition_dict_for_input = [{} for _ in range(len(self.op_data[mapping_name].data))]
else:
dim_partition_dict_for_input = {}
dim_partition_dict_mapping[mapping_name] = dim_partition_dict_for_input
dim_partition_dict_for_output.append(dim_partition_dict_for_input)
if len(dim_partition_dict_for_output) == 1:
dim_partition_dict_for_output = dim_partition_dict_for_output[0]
else:
dim_partition_dict_for_output = tuple(dim_partition_dict_for_output)
dim_partition_dict_mapping['output'] = dim_partition_dict_for_output
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
name = 'Replica Output'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return strategy
def distributed_strategy(self, mesh_list: List[List[int]] = None) -> List[ShardingStrategy]:
"""
Generate distributed strategy for output node.
"""
# TODO: need to take care of the case when only the first element of the output needs to be sharded.
output_op_data = self.op_data['output']
if isinstance(output_op_data.data, tuple):
length = len(output_op_data.data)
dim_partition_dict_mapping = {
"output": [{
0: mesh_list
}] * length,
}
else:
dim_partition_dict_mapping = {
"output": {},
"output": {
0: mesh_list
},
}
for index, _ in enumerate(self.predecessor_nodes):
mapping_name = f"input_{index}"
dim_partition_dict_mapping[mapping_name] = {}
dim_partition_dict_mapping[mapping_name] = {0: mesh_list}
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
name = 'Replica Output'
name = 'Distributed Output'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return strategy
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
mesh_list = [0, 1]
if self.output_option == 'replicated':
strategy_list.append(self.replica_strategy())
elif self.output_option == 'distributed':
strategy_list.append(self.distributed_strategy(mesh_list))
return [strategy]
return strategy_list
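The two branches above correspond to the two supported output_option values: 'replicated' keeps every output replicated (empty dim partition dicts), while 'distributed' shards dim 0 of every output over both mesh axes. As a rough illustration of what a {0: [0, 1]} partition means on a 2 x 2 device mesh (the mesh and tensor sizes below are assumptions, not taken from the source):

# hypothetical sizes for illustration
mesh_shape = (2, 2)                 # a DeviceMesh with two logical axes
global_output_shape = (32, 768)

# dim_partition_dict = {0: [0, 1]} shards tensor dim 0 over both mesh axes
shard_factor = mesh_shape[0] * mesh_shape[1]
per_device_shape = (global_output_shape[0] // shard_factor, global_output_shape[1])
assert per_device_shape == (8, 768)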
from typing import List
from typing import Dict, List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.device.device_mesh import DeviceMesh
from .strategy_generator import StrategyGenerator
......@@ -12,6 +18,11 @@ class PlaceholderGenerator(StrategyGenerator):
PlaceholderGenerator is a generic class to generate strategies for placeholder node.
"""
def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
placeholder_option: str):
super().__init__(operation_data_mapping, device_mesh)
self.placeholder_option = placeholder_option
def validate(self) -> bool:
return super().validate()
......@@ -37,7 +48,10 @@ class PlaceholderGenerator(StrategyGenerator):
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
def collate_strategies(self) -> List[ShardingStrategy]:
def replica_placeholder(self) -> ShardingStrategy:
"""
Generate replica strategy for placeholder node.
"""
dim_partition_dict_mapping = {
"output": {},
}
......@@ -50,4 +64,37 @@ class PlaceholderGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return [strategy]
return strategy
def distributed_placeholder(self, mesh_list) -> ShardingStrategy:
"""
Generate distributed strategy for placeholder node.
"""
dim_partition_dict_mapping = {
"output": {
0: mesh_list
},
}
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
name = 'Distributed Placeholder'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return strategy
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
if self.placeholder_option == 'distributed':
mesh_list = [0, 1]
distributed_strategy = self.distributed_placeholder(mesh_list)
strategy_list.append(distributed_strategy)
else:
assert self.placeholder_option == 'replicated', f'placeholder_option {self.placeholder_option} is not supported'
replicated_strategy = self.replica_placeholder()
strategy_list.append(replicated_strategy)
return strategy_list
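Analogously to the output node, a 'distributed' placeholder shards dim 0 of the input over the whole mesh, which is effectively a data-parallel layout for model inputs. A small sketch of that layout (world size and batch shape are assumptions for illustration):

import torch

world_size = 4                      # e.g. a 2 x 2 mesh flattened
global_batch = torch.randn(64, 512)

# sharding dim 0 across all devices gives each device an equal slice of the batch
local_batches = torch.chunk(global_batch, world_size, dim=0)
assert all(t.shape == (16, 512) for t in local_batches)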
......@@ -96,7 +96,7 @@ class ReshapeGenerator(FollowingStrategyGenerator):
arg_index=0)
input_comm_action.comm_spec.gather_dim = total_mesh_dim_list
else:
elif len(total_mesh_dim_list) >= 2:
source_spec = sharding_spec_mapping["input"]
target_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=source_spec.entire_shape,
......@@ -104,6 +104,10 @@ class ReshapeGenerator(FollowingStrategyGenerator):
comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
else:
input_comm_action = None
if input_comm_action is not None:
communication_action_mapping["input"] = input_comm_action
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
......
import copy
import operator
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import (
check_keep_sharding_status,
detect_reshape_mapping,
infer_output_dim_partition_dict,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
__all__ = ['SoftmaxGenerator']
class SoftmaxGenerator(FollowingStrategyGenerator):
"""
SoftmaxGenerator is used to generate strategies for torch.nn.Softmax or F.softmax.
"""
def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
'''
Compute the computation cost per device with this specific strategy.
'''
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
input_size_product = reduce(operator.mul, sharded_input_shape)
output_size_product = reduce(operator.mul, sharded_output_shape)
forward_compute_cost = output_size_product * 2
backward_compute_cost = input_size_product
total_compute_cost = forward_compute_cost + backward_compute_cost
compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
Compute the memory cost per device with this specific strategy.
'''
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
# compute fwd cost incurred
# fwd_cost = input + output
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
# compute bwd cost incurred
# bwd_cost = input_grad
bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
softmax_dim = self.op_data['softmax_dim'].data
if softmax_dim in dim_partition_dict_for_input:
recover_dims = dim_partition_dict_for_input.pop(softmax_dim)
dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input)
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# add the index to the name to pass the duplicate check
# we keep the same strategies under different names for node merging; this does not increase the search space,
# because the solver merges this node into other nodes and will not create a new variable for it.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
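The loop above drops any sharding along the softmax dimension, since softmax normalizes across that dimension and each device would otherwise need the full axis. A small illustration of the partition-dict transformation (the partition dicts below are hypothetical):

# hypothetical input partition: dim 0 sharded on mesh axis 0, dim 2 on mesh axis 1
dim_partition_dict_for_input = {0: [0], 2: [1]}
softmax_dim = 2

# sharding along the softmax dim must be recovered (removed)
dim_partition_dict_for_input.pop(softmax_dim, None)
dim_partition_dict_for_output = dict(dim_partition_dict_for_input)

assert dim_partition_dict_for_input == {0: [0]}
assert dim_partition_dict_for_output == {0: [0]}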
......@@ -17,6 +17,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.tensor.utils import convert_dim_partition_dict
class StrategyGenerator(ABC):
......@@ -67,21 +68,41 @@ class StrategyGenerator(ABC):
Args:
mapping (Dict[str, Dict[int, List[int]]]): the key of the mapping is the operation data name and the value is a dim partition dictionary.
Notes:
The op_data.data is commonly of type torch.Tensor or torch.nn.Parameter, so the sharding spec is easy to create from the shape of the data.
However, if op_data.data is of another non-iterable type, such as float or int, we should return None. If op_data.data is of an iterable type, such as
list or tuple, we should return a list of ShardingSpec objects following the same rule as mentioned above.
"""
results = {}
for op_data_name, dim_partition_dict in mapping.items():
if op_data_name in self.op_data:
op_data = self.op_data[op_data_name]
if isinstance(op_data.data, tuple) and isinstance(op_data.data[0], torch.Tensor):
sharding_spec = []
for output, dim_partition_dict_element in zip(op_data.data, dim_partition_dict):
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=output.shape,
dim_partition_dict=dim_partition_dict_element)
else:
def _to_sharding_spec(
data: any, logical_shape: any,
dim_partition_dict: Dict[int, List[int]]) -> Union[ShardingSpec, List[ShardingSpec], None]:
"""
This is a recursive function to convert the dim partition dict to a ShardingSpec object.
"""
if isinstance(data, torch.Tensor):
dim_size = len(logical_shape)
dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict)
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=op_data.logical_shape,
entire_shape=logical_shape,
dim_partition_dict=dim_partition_dict)
return sharding_spec
elif isinstance(data, (list, tuple)):
sharding_spec = []
for data_element, logical_shape_element, dim_partition_dict_element in zip(
data, logical_shape, dim_partition_dict):
sharding_spec.append(
_to_sharding_spec(data_element, logical_shape_element, dim_partition_dict_element))
return sharding_spec
else:
return None
sharding_spec = _to_sharding_spec(op_data.data, op_data.logical_shape, dim_partition_dict)
results[op_data_name] = sharding_spec
return results
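The recursion above mirrors the structure of op_data.data: tensors map to a single spec, nested lists/tuples map element-wise, and scalars map to None. A simplified stand-alone sketch of the same traversal pattern (using plain shape tuples instead of real ShardingSpec objects):

from typing import Any, List, Union

import torch


def to_spec_like(data: Any, logical_shape: Any) -> Union[tuple, List, None]:
    """Toy version of _to_sharding_spec: tensor -> shape tuple, sequence -> list, scalar -> None."""
    if isinstance(data, torch.Tensor):
        return tuple(logical_shape)
    if isinstance(data, (list, tuple)):
        return [to_spec_like(d, s) for d, s in zip(data, logical_shape)]
    return None


assert to_spec_like(torch.empty(2, 3), (2, 3)) == (2, 3)
assert to_spec_like((torch.empty(2), 1.0), ((2,), None)) == [(2,), None]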
......@@ -109,7 +130,8 @@ class StrategyGenerator(ABC):
communication_pattern: CollectiveCommPattern,
logical_process_axis: Union[int, List[int]],
comm_type: CommType,
arg_index: int = -1) -> CommAction:
arg_index: int = -1,
key_for_kwarg: any = None) -> CommAction:
"""
A factory method to produce a CommAction object.
"""
......@@ -117,7 +139,8 @@ class StrategyGenerator(ABC):
communication_pattern=communication_pattern,
logical_process_axis=logical_process_axis),
comm_type=comm_type,
arg_index=arg_index)
arg_index=arg_index,
key_for_kwarg=key_for_kwarg)
def update_communication_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
"""
......@@ -180,13 +203,40 @@ class StrategyGenerator(ABC):
Args:
strategy (ShardingStrategy): the ShardingStrategy generated.
key (str): the name of the operation data defined by the generator.
"""
op_data = self.op_data[key]
sharded_shape = strategy.sharding_specs[op_data].get_sharded_shape_per_device()
dtype = self.op_data[key].data.dtype
def _compute_size_in_bytes_helper(sharding_spec, meta_data):
sharded_shape = sharding_spec.get_sharded_shape_per_device()
if len(sharded_shape) == 0:
num_elements = 1
else:
num_elements = reduce(operator.mul, sharded_shape)
dtype = getattr(meta_data, 'dtype')
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
return reduce(operator.mul, sharded_shape) * size_per_elem_bytes
return num_elements * size_per_elem_bytes
if isinstance(op_data.data, tuple):
assert isinstance(strategy.sharding_specs[op_data], list), \
'sharding_spec of op_data should be a list of sharding specs if op_data.data is a tuple.'
total_bytes = 0
for index, sharding_spec in enumerate(strategy.sharding_specs[op_data]):
meta_data = op_data.data[index]
if isinstance(meta_data, torch.Tensor):
element_bytes = _compute_size_in_bytes_helper(sharding_spec, meta_data)
else:
# if meta_data is not a tensor, we count the memory as 0
element_bytes = 0
total_bytes += element_bytes
else:
if isinstance(op_data.data, torch.Tensor):
total_bytes = _compute_size_in_bytes_helper(strategy.sharding_specs[op_data], op_data.data)
else:
# if op_data.data is not a tensor, we count the memory as 0
total_bytes = 0
return total_bytes
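The helper above boils down to num_elements * element_size for each sharded tensor. A quick numeric check of that arithmetic (the sharded shape is an assumed example):

import operator
from functools import reduce

import torch

sharded_shape = (1024, 256)                         # per-device shape after sharding
num_elements = reduce(operator.mul, sharded_shape)  # 262144
elem_bytes = torch.tensor([], dtype=torch.float32).element_size()  # 4
assert num_elements * elem_bytes == 1024 * 256 * 4  # 1 MiB per device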
def generate(self) -> List[ShardingStrategy]:
"""
......@@ -244,6 +294,5 @@ class OutputStrategyGenerator(StrategyGenerator):
def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
predecessor_nodes: List[Node]):
self.op_data = operation_data_mapping
self.device_mesh = device_mesh
super().__init__(operation_data_mapping, device_mesh)
self.predecessor_nodes = predecessor_nodes
import copy
import operator
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import (
check_keep_sharding_status,
detect_reshape_mapping,
infer_output_dim_partition_dict,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = ['SumGenerator']
class SumGenerator(FollowingStrategyGenerator):
"""
SumGenerator deals with the sharding strategies of torch.sum op.
"""
def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
input_size_product = reduce(operator.mul, sharded_input_shape)
output_size_product = reduce(operator.mul, sharded_output_shape)
compute_cost = TrainCycleItem(fwd=input_size_product,
bwd=output_size_product,
total=input_size_product + output_size_product)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
Compute the memory cost per device with this specific strategy.
'''
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
# compute fwd cost incurred
# fwd_cost = input + output
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
# compute bwd cost incurred
# bwd_cost = input_grad
bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
sum_dims, sum_mapping_dict = self.op_data['sum_info'].data
# TODO: a better way to handle the distributed sum is to sum all the data on chip and then do an all-reduce
# among all the shard groups
recover_dims = []
dim_partition_dict_for_output = {}
for dim in dim_partition_dict_for_input:
if dim in sum_dims:
recover_dims.append(dim)
elif dim in sum_mapping_dict:
dim_partition_dict_for_output[sum_mapping_dict[dim]] = dim_partition_dict_for_input[dim]
else:
raise RuntimeError(f'dim {dim} is not in sum_mapping_dict or sum_dims')
for dim in recover_dims:
dim_partition_dict_for_input.pop(dim)
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# add the index to the name to pass the duplicate check
# we keep the same strategies under different names for node merging; this does not increase the search space,
# because the solver merges this node into other nodes and will not create a new variable for it.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
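To make the output partition construction above concrete: a sharded dim that is summed over is recovered (dropped from the input partition), while a surviving dim is renamed through sum_mapping_dict. A small hand-worked example (the partition dicts are hypothetical):

# input of rank 4, summed over dims (0, 2); surviving dims map as {1: 0, 3: 1}
sum_dims = (0, 2)
sum_mapping_dict = {1: 0, 3: 1}

dim_partition_dict_for_input = {0: [0], 3: [1]}   # dim 0 on mesh axis 0, dim 3 on mesh axis 1
dim_partition_dict_for_output = {}
for dim, mesh_axes in dict(dim_partition_dict_for_input).items():
    if dim in sum_dims:
        dim_partition_dict_for_input.pop(dim)      # sharding along a summed dim is recovered
    else:
        dim_partition_dict_for_output[sum_mapping_dict[dim]] = mesh_axes

assert dim_partition_dict_for_input == {3: [1]}
assert dim_partition_dict_for_output == {1: [1]}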
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpec
from .strategy_generator import StrategyGenerator
__all__ = ['TensorConstructorGenerator']
class TensorConstructorGenerator(StrategyGenerator):
"""
TensorConstructorGenerator which deals with
the sharding strategies for tensor constructor operations, such as torch.arange.
"""
def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
Compute the memory cost per device with this specific strategy.
'''
forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")}
# compute fwd cost incurred
# fwd_cost = input + output
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
# compute bwd cost incurred
bwd_mem_cost = MemoryCost(activation=0, parameter=0)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
dim_partition_dict_mapping = {
"output": {},
}
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
name = 'Replica Tensor Constructor'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from colossalai.auto_parallel.tensor_shard.utils import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
ignore_sharding_exception,
)
from .strategy_generator import StrategyGenerator
......@@ -50,6 +53,7 @@ class WhereGenerator(StrategyGenerator):
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@ignore_sharding_exception
def _generate_strategy_with_dim_partition(self, dim_partition):
dim_partition_dict_mapping = {
"condition": dim_partition,
......
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator, SumGenerator
__all__ = ['SumHandler']
@operator_registry.register(torch.Tensor.sum)
@operator_registry.register(torch.sum)
class SumHandler(NodeHandler):
"""
A SumHandler which deals with the sharding strategies for torch.sum or torch.Tensor.sum.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(SumGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# check if the input operand is a parameter
if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
input_data = self.node.args[0]._meta_data
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
if len(self.node.args) > 1:
sum_dims = self.node.args[1]
else:
sum_dims = tuple(range(self.node.args[0]._meta_data.dim()))
if isinstance(sum_dims, int):
sum_dims = (sum_dims,)
# use a list so that negative dim values can be rewritten in place
sum_dims = list(sum_dims)
# recover negative values to positive
num_dims = self.node.args[0]._meta_data.dim()
for i in range(len(sum_dims)):
if sum_dims[i] < 0:
sum_dims[i] += num_dims
# mapping the input dims to output dims
# For examples:
# input: torch.rand(2, 3, 4, 5)
# output: torch.sum(input, (0, 2))
# sum_mapping_dict = {1: 0, 3: 1}
# sum_mapping_dict[1] = 0 means the 0th dim of output is the 1st dim of input
# sum_mapping_dict[3] = 1 means the 1st dim of output is the 3rd dim of input
sum_mapping_dict = {}
if 'keepdim' in self.node.kwargs and self.node.kwargs['keepdim']:
for i in range(num_dims):
sum_mapping_dict.update({i: i})
else:
output_index = 0
for i in range(num_dims):
if i not in sum_dims:
sum_mapping_dict.update({i: output_index})
output_index += 1
assert output_index == self.node._meta_data.dim()
sum_info = (sum_dims, sum_mapping_dict)
physical_shape_operand = OperationData(name='sum_info', type=OperationDataType.ARG, data=sum_info)
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
mapping = {
"input": physical_input_operand,
"sum_info": physical_shape_operand,
"output": physical_output_operand
}
return mapping
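The mapping logic can be sanity-checked against the worked example in the comment above:

import torch

x = torch.rand(2, 3, 4, 5)
out = x.sum(dim=(0, 2))

# input dims 1 and 3 survive and become output dims 0 and 1
assert out.shape == (3, 5)
# with keepdim=True every input dim maps to itself
assert x.sum(dim=(0, 2), keepdim=True).shape == (1, 3, 1, 5)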
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator
from .strategy.tensor_constructor_generator import TensorConstructorGenerator
__all__ = ['TensorConstructorHandler']
@operator_registry.register(torch.arange)
class TensorConstructorHandler(NodeHandler):
"""
A TensorConstructorHandler which deals with the sharding strategies for tensor constructor operations, such as torch.arange.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(TensorConstructorGenerator(op_data_mapping, self.device_mesh))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
mapping = {"output": physical_output_operand}
return mapping