Commit e532679c authored by oahzxl

Merge branch 'main' of https://github.com/oahzxl/ColossalAI into chunk

parents c1492e50 7d5640b9
......@@ -2,8 +2,10 @@ from typing import Dict, List, Union
import torch
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import recover_sharding_spec_for_broadcast_shape
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy
from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import BatchedMatMulStrategyGenerator, StrategyGenerator
......@@ -91,7 +93,15 @@ class AddBMMFunctionHandler(NodeHandler):
bias_physical_shape = bias_op_data.data.shape
bias_logical_shape = bias_op_data.logical_shape
bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name)
bias_sharding_spec = recover_sharding_spec_for_broadcast_shape(bias_sharding_spec, bias_logical_shape,
bias_physical_shape)
bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
bias_sharding_spec, bias_logical_shape, bias_physical_shape)
strategy.sharding_specs[bias_op_data] = bias_sharding_spec
if len(removed_dims) > 0:
comm_action = comm_actions_for_oprands(node=self.node,
removed_dims=removed_dims,
op_data=bias_op_data,
sharding_spec=bias_sharding_spec)
strategy.communication_actions[bias_op_data] = comm_action
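# Note (assumption, not from the original code): a non-empty removed_dims means the bias was
# broadcast along those dimensions, so a communication action is attached here so that the
# bias gradient can be reduced correctly over the removed dimensions during the backward pass.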
return strategy
......@@ -3,9 +3,9 @@ from typing import Dict, List
import torch
import torch.nn.functional as F
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
from ..utils import transpose_partition_dim
from .node_handler import ModuleHandler, NodeHandler
from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler
from .registry import operator_registry
from .strategy import ConvStrategyGenerator, StrategyGenerator
......@@ -15,7 +15,7 @@ __all__ = ['ConvModuleHandler', 'ConvFunctionHandler']
@operator_registry.register(torch.nn.Conv1d)
@operator_registry.register(torch.nn.Conv2d)
@operator_registry.register(torch.nn.Conv3d)
class ConvModuleHandler(ModuleHandler):
class ConvModuleHandler(MetaInfoModuleHandler):
"""
A ConvModuleHandler which deals with the sharding strategies for nn.Convxd module.
"""
......@@ -63,7 +63,7 @@ class ConvModuleHandler(ModuleHandler):
@operator_registry.register(F.conv1d)
@operator_registry.register(F.conv2d)
@operator_registry.register(F.conv3d)
class ConvFunctionHandler(NodeHandler):
class ConvFunctionHandler(MetaInfoNodeHandler):
"""
A ConvFunctionHandler which deals with the sharding strategies for nn.functional.ConvXd functions.
"""
......
from typing import Dict, List, Union
import torch
import torch.nn.functional as F
from colossalai.auto_parallel.tensor_shard.utils import update_partition_dim
from colossalai.logging import get_dist_logger
from colossalai.tensor.sharding_spec import ShardingNotDivisibleError
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from .node_handler import ModuleHandler, NodeHandler
from .registry import operator_registry
from .strategy import EmbeddingStrategyGenerator, StrategyGenerator
__all__ = ['EmbeddingModuleHandler', 'EmbeddingFunctionHandler']
def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy: ShardingStrategy, input_name: str,
output_name: str) -> List[ShardingStrategy]:
"""
This function converts the logical sharding spec to the physical sharding spec for both the input and output
of the embedding operation.
Args:
strategy (ShardingStrategy): the logical strategy generated by the strategy generator.
input_name (str): the name of the OperationData object for the input.
output_name (str): the name of the OperationData object for the output.
"""
# the result will be a list of strategies
sharding_strategies = []
# get operation data
input_op_data = strategy.get_op_data_by_name(input_name)
output_op_data = strategy.get_op_data_by_name(output_name)
input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name)
output_sharding_spec = strategy.get_sharding_spec_by_name(output_op_data.name)
# recover the last logical dimension to physical dimension
last_logical_output_dims = len(output_op_data.logical_shape) - 1
last_physical_output_dims = output_op_data.data.dim() - 1
# get logger for debug message
logger = get_dist_logger()
# The input of the embedding operation can be multi-dimensional. The sharding spec is only generated for the
# logical 1D non-matrix dimension, and this logical non-matrix dimension can map to any of the 0th to Nth dimensions
# of the physical input shape. Thus, we enumerate all possible cases.
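# Illustrative example (not from the original code): for an input of physical shape (4, 8),
# the logical shape is (32,). A strategy that shards logical dim 0 can map to sharding either
# physical dim 0 or physical dim 1, so both candidates are generated below and invalid ones
# are filtered out by the divisibility check.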
if input_sharding_spec.dim_partition_dict:
# if bool(input_sharding_spec.dim_partition_dict) is True, it means that
# the generated sharding strategy does shard the non-matrix dimension;
# in this case, we need to do the enumeration
num_input_dims = input_op_data.data.dim()
for i in range(num_input_dims):
strategy_copy = strategy.clone()
input_sharding_spec = strategy_copy.get_sharding_spec_by_name(input_op_data.name)
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
try:
# replace the 0th dimension in the logical sharding with ith dimension in the physical sharding
update_partition_dim(sharding_spec=input_sharding_spec,
dim_mapping={0: i},
physical_shape=input_op_data.data.shape,
inplace=True)
if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
dim_mapping = {0: i, last_logical_output_dims: last_physical_output_dims}
else:
dim_mapping = {0: i}
update_partition_dim(sharding_spec=output_sharding_spec,
dim_mapping=dim_mapping,
physical_shape=output_op_data.data.shape,
inplace=True)
strategy_copy.name = f'{strategy.name}_{i}'
sharding_strategies.append(strategy_copy)
except ShardingNotDivisibleError as e:
logger.debug(
f'An error occurred when converting the logical sharding spec to the physical one. Error details: {e}'
)
else:
# the generated sharding strategy does not shard the non-matrix dimension,
# in this case, we don't need to do enumeration
# but instead, we still need to convert the logical shape to physical shape
strategy_copy = strategy.clone()
input_sharding_spec = strategy_copy.get_sharding_spec_by_name(input_op_data.name)
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
# after updating, the logical shape will be replaced by the physical shape
update_partition_dim(sharding_spec=input_sharding_spec,
dim_mapping={},
physical_shape=input_op_data.data.shape,
inplace=True)
if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
dim_mapping = {last_logical_output_dims: last_physical_output_dims}
else:
dim_mapping = {}
update_partition_dim(sharding_spec=output_sharding_spec,
dim_mapping=dim_mapping,
physical_shape=output_op_data.data.shape,
inplace=True)
sharding_strategies.append(strategy_copy)
return sharding_strategies
@operator_registry.register(torch.nn.Embedding)
class EmbeddingModuleHandler(ModuleHandler):
"""
An EmbeddingModuleHandler which deals with the sharding strategies for the nn.Embedding module.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(EmbeddingStrategyGenerator(op_data_mapping, self.device_mesh))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# In nn.Embedding operation, all the dimensions of input will be treated as the batch dimension,
# and then the sharding spec will be generated based on the logical 1D tensor.
# After that, the logical sharding info will be enumerated among all the physical dimensions.
# Finally, the input will be transformed back to its original shape in self.post_process
input_meta_data = self.node.args[0]._meta_data
input_logical_shape = input_meta_data.view(-1).shape
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=input_meta_data,
logical_shape=input_logical_shape)
physical_other_operand = OperationData(name="weight",
type=OperationDataType.PARAM,
data=self.named_parameters['weight'])
# Same as input, in nn.Embedding operation, all the dimensions of output will be treated as
# (batch dimension, embedding dimension), and then the sharding spec will be generated based
# on the logical 2D tensor.
# After that, the logical sharding info of batch dimension will be enumerated among all the physical dimensions.
# Finally, the output will be transformed back to its original shape in self.post_process
output_meta_data = self.node._meta_data
output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
physical_output = OperationData(name=str(self.node),
type=OperationDataType.OUTPUT,
data=output_meta_data,
logical_shape=output_logical_shape)
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
return mapping
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
"""
Convert the sharding spec from the logical shape to the physical shape.
"""
# create multiple sharding strategies for the inputs
# as the input can be multi-dimensional and the partition dim is only 2D,
# we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output
strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy,
input_name=str(
self.node.args[0]),
output_name=str(self.node))
return strategies
@operator_registry.register(F.embedding)
class EmbeddingFunctionHandler(NodeHandler):
"""
An EmbeddingFunctionHandler which deals with the sharding strategies for F.embedding.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(EmbeddingStrategyGenerator(op_data_mapping, self.device_mesh))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# In F.embedding operation, all the dimensions of input will be treated as the batch dimension,
# and then the sharding spec will be generated based on the logical 1D tensor.
# After that, the logical sharding info will be enumerated among all the physical dimensions.
# Finally, the input will be transformed back to its original shape in self.post_process
input_meta_data = self.node.args[0]._meta_data
input_logical_shape = input_meta_data.view(-1).shape
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data,
logical_shape=input_logical_shape)
# check if the other operand is a parameter
if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
physical_other_operand = OperationData(name=str(self.node.args[1]),
type=data_type,
data=self.node.args[1]._meta_data)
# Same as input, in F.embedding operation, all the dimensions of output will be treated as
# (batch dimension, embedding dimension), and then the sharding spec will be generated based
# on the logical 2D tensor.
# After that, the logical sharding info of batch dimension will be enumerated among all the physical dimensions.
# Finally, the output will be transformed back to its original shape in self.post_process
output_meta_data = self.node._meta_data
output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
physical_output = OperationData(
name=str(self.node),
type=OperationDataType.OUTPUT,
data=self.node._meta_data,
logical_shape=output_logical_shape,
)
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
return mapping
def post_process(self, strategy: ShardingStrategy):
"""
Convert the sharding spec from the logical shape to the physical shape.
"""
# create multiple sharding strategies for the inputs
# as the input can be multi-dimensional and the partition dim is only 2D,
# we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output
strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy,
input_name=str(
self.node.args[0]),
output_name=str(self.node))
return strategies
from .permute_handler import PermuteHandler
from .reshape_generator import PermuteGenerator, SplitGenerator, TransposeGenerator, ViewGenerator
from .split_handler import SplitHandler
from .transpose_handler import TransposeHandler
from .view_handler import ViewHandler
__all__ = [
'ViewGenerator', 'ViewHandler', 'PermuteGenerator', 'PermuteHandler', 'TransposeGenerator', 'TransposeHandler',
'SplitHandler', 'SplitGenerator'
]
from typing import Dict, List
import torch
from ...sharding_strategy import OperationData, OperationDataType
from ..node_handler import NodeHandler
from ..registry import operator_registry
from ..strategy import StrategyGenerator
from .reshape_generator import PermuteGenerator
__all__ = ['PermuteHandler']
@operator_registry.register(torch.Tensor.permute)
@operator_registry.register(torch.permute)
class PermuteHandler(NodeHandler):
"""
A PermuteHandler which deals with the sharding strategies for torch.permute and torch.Tensor.permute.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(PermuteGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# check if the input operand is a parameter
if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
input_data = self.node.args[0]._meta_data
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
permute_dims = []
if self.node.op == 'call_method':
# torch.Tensor.permute (input, *dims)
for arg in self.node.args:
if isinstance(arg, torch.fx.Node):
if isinstance(arg._meta_data, int):
permute_dims.append(arg._meta_data)
else:
assert isinstance(arg, int), 'The argument in the permute node should be either a Node or an int.'
permute_dims.append(arg)
else:
# torch.permute (input, dims)
for arg in self.node.args:
if isinstance(arg, torch.fx.Node):
if isinstance(arg._meta_data, (tuple, list)):
permute_dims.extend(arg._meta_data)
else:
assert isinstance(
arg,
(tuple, list)), 'The argument in the permute node should be a Node, Tuple[int] or List[int].'
permute_dims.extend(arg)
num_dims = self.node._meta_data.dim()
for i in range(num_dims):
# recover negative value to positive
if permute_dims[i] < 0:
permute_dims[i] += num_dims
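# Illustrative example (not from the original code): for a 3D tensor, x.permute(0, -1, 1)
# collects permute_dims = [0, -1, 1], which is normalized to [0, 2, 1] above.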
physical_shape_operand = OperationData(name='permute_dims', type=OperationDataType.ARG, data=list(permute_dims))
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
mapping = {
"input": physical_input_operand,
"permute_dims": physical_shape_operand,
"output": physical_output_operand
}
return mapping
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import (
check_keep_sharding_status,
detect_reshape_mapping,
infer_output_dim_partition_dict,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = ['ReshapeGenerator', 'ViewGenerator', 'PermuteGenerator', 'TransposeGenerator', 'SplitGenerator']
class ReshapeGenerator(FollowingStrategyGenerator):
"""
ReshapeGenerator is the base class for all the reshape operation.
"""
def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
Compute the memory cost per device with this specific strategy.
'''
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
# compute fwd cost incurred
# fwd_cost = input + output
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
# compute bwd cost incurred
# bwd_cost = input_grad
bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
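# Illustrative example (not from the original code): if the input and output each occupy 4 MB
# per device under the given sharding spec, the forward activation cost is 8 MB (input + output),
# the backward activation cost is 4 MB (input grad only), and the total activation cost is 12 MB.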
def collate_strategies(self) -> List[ShardingStrategy]:
return super().collate_strategies()
class ViewGenerator(ReshapeGenerator):
"""
ViewGenerator deals with the sharding strategies of view op.
"""
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
origin_shape = self.op_data['input'].data.shape
tgt_shape = self.op_data['tgt_shape'].data
reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape)
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
keep_sharding_status = check_keep_sharding_status(dim_partition_dict_for_input, reshape_mapping_dict)
if keep_sharding_status:
dim_partition_dict_for_output = infer_output_dim_partition_dict(dim_partition_dict_for_input,
reshape_mapping_dict)
else:
dim_partition_dict_for_output = {}
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# add the index into the name to pass the duplication check
# we keep the same strategies under different names for node merging; this does not increase the search space,
# because the solver merges this node into other nodes and does not create a new variable for it.
if keep_sharding_status:
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
else:
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> FULLY REPLICATED_{index}'
# add comm action for converting input to fully replicated
total_mesh_dim_list = []
for mesh_dim_list in dim_partition_dict_for_input.values():
total_mesh_dim_list.extend(mesh_dim_list)
# if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
if len(total_mesh_dim_list) == 1:
total_mesh_dim_list = total_mesh_dim_list[0]
# the total mesh dim list only has one element, so the shard dim has only one element as well.
shard_dim = list(dim_partition_dict_for_input.keys())[0]
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=total_mesh_dim_list,
comm_type=CommType.BEFORE,
arg_index=0)
# it will gather the input through gather_dim during forward phase.
input_comm_action.comm_spec.gather_dim = shard_dim
# it will split the input activation grad through shard_dim during backward phase.
input_comm_action.comm_spec.shard_dim = shard_dim
elif len(total_mesh_dim_list) >= 2:
source_spec = sharding_spec_mapping["input"]
target_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=source_spec.entire_shape,
dim_partition_dict={})
comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
else:
input_comm_action = None
if input_comm_action is not None:
communication_action_mapping["input"] = input_comm_action
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
class PermuteGenerator(ReshapeGenerator):
"""
PermuteGenerator deals with the sharding strategies of permute op.
"""
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
permute_dims = self.op_data['permute_dims'].data
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
dim_partition_dict_for_output = {}
for dim_index, permute_dim in enumerate(permute_dims):
if permute_dim in dim_partition_dict_for_input:
dim_partition_dict_for_output[dim_index] = dim_partition_dict_for_input[permute_dim]
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# add the index into the name to pass the duplication check
# we keep the same strategies under different names for node merging; this does not increase the search space,
# because the solver merges this node into other nodes and does not create a new variable for it.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
class TransposeGenerator(ReshapeGenerator):
"""
TransposeGenerator deals with the sharding strategies of transpose op.
"""
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
dim_partition_dict_for_output = {}
transpose_dims = self.op_data['transpose_dims'].data
dim_0 = transpose_dims[0]
dim_1 = transpose_dims[1]
for dim, sharded_dims in dim_partition_dict_for_input.items():
if dim == dim_0:
dim_partition_dict_for_output[dim_1] = dim_partition_dict_for_input[dim_0]
elif dim == dim_1:
dim_partition_dict_for_output[dim_0] = dim_partition_dict_for_input[dim_1]
else:
dim_partition_dict_for_output[dim] = sharded_dims
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# add the index into the name to pass the duplication check
# we keep the same strategies under different names for node merging; this does not increase the search space,
# because the solver merges this node into other nodes and does not create a new variable for it.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
class SplitGenerator(ReshapeGenerator):
"""
SplitGenerator deals with the sharding strategies of split op.
"""
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
recover_dims = None
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
split_size, split_dim = self.op_data['split_info'].data
if split_dim in dim_partition_dict_for_input:
recover_dims = dim_partition_dict_for_input.pop(split_dim)
dim_partition_dict_for_output = [
copy.deepcopy(dim_partition_dict_for_input) for _ in range(len(self.op_data["output"].data))
]
assert len(dim_partition_dict_for_output) >= 2
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# add the index into the name to pass the duplication check
# we keep the same strategies under different names for node merging; this does not increase the search space,
# because the solver merges this node into other nodes and does not create a new variable for it.
name = f'{sharding_spec_mapping["input"].sharding_sequence}_{index}'
# add comm action if the input need to be recovered to replica in the split dimension.
if recover_dims:
# if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
if len(recover_dims) == 1:
recover_dims = recover_dims[0]
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=recover_dims,
comm_type=CommType.BEFORE,
arg_index=0)
# it will gather the input through gather_dim during forward phase.
input_comm_action.comm_spec.gather_dim = split_dim
# it will split the input activation grad through split_dim during backward phase.
input_comm_action.comm_spec.shard_dim = split_dim
elif len(recover_dims) >= 2:
# original sharding spec
source_spec = input_sharding_spec
# target sharding spec
target_spec = sharding_spec_mapping["input"]
comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
else:
input_comm_action = None
if input_comm_action is not None:
communication_action_mapping["input"] = input_comm_action
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
from typing import Dict, List
import torch
from ...sharding_strategy import OperationData, OperationDataType
from ..node_handler import NodeHandler
from ..registry import operator_registry
from ..strategy import StrategyGenerator
from .reshape_generator import SplitGenerator
__all__ = ['SplitHandler']
@operator_registry.register(torch.Tensor.split)
@operator_registry.register(torch.split)
class SplitHandler(NodeHandler):
"""
A SplitHandler which deals with the sharding strategies for torch.Tensor.split and torch.split.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(SplitGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# check if the input operand is a parameter
if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
input_data = self.node.args[0]._meta_data
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
split_size = self.node.args[1]
if len(self.node.args) == 3:
# (input, split_size, split_dim)
split_dim = self.node.args[2]
else:
if self.node.kwargs:
split_dim = self.node.kwargs['dim']
else:
split_dim = 0
num_dims = self.node.args[0]._meta_data.dim()
# recover negative value to positive
if split_dim < 0:
split_dim += num_dims
split_info = (split_size, split_dim)
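# Illustrative example (not from the original code): for torch.split(x, 2, dim=-1) on a 3D tensor,
# split_size is 2 and the negative dim is normalized to 2, so split_info == (2, 2).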
physical_shape_operand = OperationData(name='split_info', type=OperationDataType.ARG, data=split_info)
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
mapping = {
"input": physical_input_operand,
"split_info": physical_shape_operand,
"output": physical_output_operand
}
return mapping
from typing import Dict, List
import torch
from ...sharding_strategy import OperationData, OperationDataType
from ..node_handler import NodeHandler
from ..registry import operator_registry
from ..strategy import StrategyGenerator
from .reshape_generator import TransposeGenerator
__all__ = ['TransposeHandler']
@operator_registry.register(torch.Tensor.transpose)
@operator_registry.register(torch.transpose)
class TransposeHandler(NodeHandler):
"""
A TransposeHandler which deals with the sharding strategies for torch.Tensor.transpose and torch.transpose.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(TransposeGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# check if the input operand is a parameter
if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
input_data = self.node.args[0]._meta_data
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
transpose_dims = []
# torch.transpose (input, dim0, dim1)
for arg in self.node.args:
if isinstance(arg, torch.fx.Node):
if isinstance(arg._meta_data, int):
transpose_dims.append(arg._meta_data)
else:
transpose_dims.append(arg)
num_dims = self.node._meta_data.dim()
for i in range(2):
# recover negative value to positive
if transpose_dims[i] < 0:
transpose_dims[i] += num_dims
physical_shape_operand = OperationData(name='transpose_dims',
type=OperationDataType.ARG,
data=list(transpose_dims))
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
mapping = {
"input": physical_input_operand,
"transpose_dims": physical_shape_operand,
"output": physical_output_operand
}
return mapping
from typing import Dict, List
import torch
from ...sharding_strategy import OperationData, OperationDataType
from ..node_handler import NodeHandler
from ..registry import operator_registry
from ..strategy import StrategyGenerator
from .reshape_generator import ViewGenerator
__all__ = ['ViewHandler']
@operator_registry.register(torch.Tensor.reshape)
@operator_registry.register(torch.reshape)
@operator_registry.register(torch.Tensor.view)
class ViewHandler(NodeHandler):
"""
A ViewHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(ViewGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
# check if the input operand is a parameter
if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
input_data = self.node.args[0]._meta_data
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
target_shape = self.node._meta_data.shape
physical_shape_operand = OperationData(name='tgt_shape', type=OperationDataType.ARG, data=target_shape)
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
mapping = {
"input": physical_input_operand,
"tgt_shape": physical_shape_operand,
"output": physical_output_operand
}
return mapping
from typing import Dict, List
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .strategy import GetattrGenerator, StrategyGenerator
__all__ = ['GetattrHandler']
class GetattrHandler(NodeHandler):
"""
A GetattrHandler which deals with the sharding strategies for Getattr Node.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(GetattrGenerator(op_data_mapping, self.device_mesh))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
# There are only two possible types for get_attr node:
# 1. torch.Tensor(torch.nn.Parameters or torch.nn.Buffers)
# 2. torch.nn.Module
# temporarily, we just support first case in Tracer, so we don't have to worry about
# issue related to the node._meta_data type.
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {"output": physical_output}
return mapping
......@@ -6,7 +6,7 @@ import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import (StrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator)
from .strategy import StrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator
__all__ = ['GetItemHandler']
......
......@@ -3,12 +3,16 @@ from typing import Dict, List, Union
import torch
import torch.nn.functional as F
from colossalai.auto_parallel.tensor_shard.utils import transpose_partition_dim, update_partition_dim
from colossalai.auto_parallel.tensor_shard.utils import (
check_sharding_spec_validity,
transpose_partition_dim,
update_partition_dim,
)
from colossalai.logging import get_dist_logger
from colossalai.tensor.sharding_spec import ShardingNotDivisibleError
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from .node_handler import ModuleHandler, NodeHandler
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler
from .registry import operator_registry
from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator
......@@ -28,9 +32,11 @@ def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStr
# switch the dimensions of the transposed weight
sharding_spec = strategy.get_sharding_spec_by_name(weight_name)
op_data = strategy.get_op_data_by_name(weight_name)
assert op_data.logical_shape != op_data.data.shape, \
"Expected the logical and physical shape of the linear operator's weight to be different, but found them to be the same"
transpose_partition_dim(sharding_spec, 0, -1)
assert op_data.logical_shape[0] == op_data.data.shape[1] and \
op_data.logical_shape[1] == op_data.data.shape[0], \
"Expected the logical shape of the linear operator's weight is equal to transposed physical shape"
dim_size = len(op_data.logical_shape)
transpose_partition_dim(sharding_spec, 0, dim_size - 1)
return strategy
......@@ -54,6 +60,23 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
input_op_data = strategy.get_op_data_by_name(input_name)
output_op_data = strategy.get_op_data_by_name(output_name)
input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name)
output_sharding_spec = strategy.get_sharding_spec_by_name(output_op_data.name)
# recover the last logical dimension to physical dimension
last_logical_input_dims = len(input_op_data.logical_shape) - 1
last_logical_output_dims = len(output_op_data.logical_shape) - 1
last_physical_input_dims = input_op_data.data.dim() - 1
last_physical_output_dims = output_op_data.data.dim() - 1
if last_logical_input_dims in input_sharding_spec.dim_partition_dict:
input_last_dim_mapping = {last_logical_input_dims: last_physical_input_dims}
else:
input_last_dim_mapping = {}
if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
output_last_dim_mapping = {last_logical_output_dims: last_physical_output_dims}
else:
output_last_dim_mapping = {}
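# Illustrative example (not from the original code): for F.linear with an input of physical shape
# (B, S, H_in), the logical input shape is (B*S, H_in) and the logical output shape is (B*S, H_out).
# If the last logical dim (dim 1) is sharded, the mappings above move that sharding onto the last
# physical dim (dim 2) of the corresponding tensor.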
# get logger for debug message
logger = get_dist_logger()
......@@ -73,14 +96,21 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
try:
# replace the 0th dimension in the logical sharding with ith dimension in the physical sharding
input_dim_mapping = {0: i}
input_dim_mapping.update(input_last_dim_mapping)
update_partition_dim(sharding_spec=input_sharding_spec,
dim_mapping={0: i},
dim_mapping=input_dim_mapping,
physical_shape=input_op_data.data.shape,
inplace=True)
output_dim_mapping = {0: i}
output_dim_mapping.update(output_last_dim_mapping)
update_partition_dim(sharding_spec=output_sharding_spec,
dim_mapping={0: i},
dim_mapping=output_dim_mapping,
physical_shape=output_op_data.data.shape,
inplace=True)
strategy_copy.name = f'{strategy.name}_{i}'
sharding_strategies.append(strategy_copy)
except ShardingNotDivisibleError as e:
logger.debug(
......@@ -95,12 +125,17 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
# after updating, the logical shape will be replaced by the physical shape
input_dim_mapping = {}
input_dim_mapping.update(input_last_dim_mapping)
update_partition_dim(sharding_spec=input_sharding_spec,
dim_mapping={},
dim_mapping=input_dim_mapping,
physical_shape=input_op_data.data.shape,
inplace=True)
output_dim_mapping = {}
output_dim_mapping.update(output_last_dim_mapping)
update_partition_dim(sharding_spec=output_sharding_spec,
dim_mapping={},
dim_mapping=output_dim_mapping,
physical_shape=output_op_data.data.shape,
inplace=True)
sharding_strategies.append(strategy_copy)
......@@ -108,7 +143,7 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
@operator_registry.register(torch.nn.Linear)
class LinearModuleHandler(ModuleHandler):
class LinearModuleHandler(MetaInfoModuleHandler):
"""
A LinearModuleHandler which deals with the sharding strategies for nn.Linear module.
"""
......@@ -116,7 +151,8 @@ class LinearModuleHandler(ModuleHandler):
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh))
generators.append(
LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear'))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
......@@ -167,15 +203,16 @@ class LinearModuleHandler(ModuleHandler):
@operator_registry.register(F.linear)
class LinearFunctionHandler(NodeHandler):
class LinearFunctionHandler(MetaInfoNodeHandler):
"""
A LinearModuleHandler which deals with the sharding strategies for nn.Linear module.
A LinearFunctionHandler which deals with the sharding strategies for F.linear.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh))
generators.append(
LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear'))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
......@@ -198,27 +235,34 @@ class LinearFunctionHandler(NodeHandler):
type=data_type,
data=self.node.args[1]._meta_data,
logical_shape=self.node.args[1]._meta_data.shape[::-1])
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
output_meta_data = self.node._meta_data
output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
physical_output = OperationData(
name=str(self.node),
type=OperationDataType.OUTPUT,
data=self.node._meta_data,
logical_shape=output_logical_shape,
)
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
if self.node.args[2] is not None:
if 'bias' in self.node.kwargs and self.node.kwargs['bias'] is not None:
# check if the bias operand is a parameter
if isinstance(self.node.args[2]._meta_data, torch.nn.parameter.Parameter):
if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
physical_bias_operand = OperationData(name=str(self.node.args[2]),
physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]),
type=data_type,
data=self.node.args[2]._meta_data)
data=self.node.kwargs["bias"]._meta_data)
mapping['bias'] = physical_bias_operand
return mapping
def post_process(self, strategy: ShardingStrategy):
# switch the dimensions of the transposed weight
strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy,
weight_name=str(self.node.args[1]))
# create multiple sharding strategies for the inputs
# as the input can be multi-dimensional and the partition dim is only 2D,
# we need to map the partition at dim 0 to one of the first few dimensions of the input
......
import operator
from abc import ABC, abstractmethod
from copy import deepcopy
from enum import Enum
from functools import reduce
from typing import Dict, List, Union
import torch
from colossalai.auto_parallel.tensor_shard.utils.broadcast import (
BroadcastType,
get_broadcast_dim_info,
get_broadcast_shape,
)
from colossalai.tensor.sharding_spec import ShardingSpecException
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import (
BatchedMatMulStrategyGenerator,
DotProductStrategyGenerator,
LinearProjectionStrategyGenerator,
MatVecStrategyGenerator,
StrategyGenerator,
)
class MatMulType(Enum):
"""
The MatMulType is categorized into 4 types based on the reference of torch.matmul
in https://pytorch.org/docs/stable/generated/torch.matmul.html.
DOT: dot product, both tensors are 1D, these two tensors need to have the same number of elements
MM: matrix-matrix product, both tensors are 2D or the 1st tensor is 1D and the 2nd tensor is 2D
MV: matrix-vector product: the 1st tensor is 2D and the 2nd tensor is 1D
BMM: batched matrix-matrix multiplication, one tensor is at least 1D and the other is at least 3D
"""
DOT = 0
MM = 1
MV = 2
BMM = 3
def get_matmul_type(input_dim: int, other_dim: int):
"""
Determine which type of matmul operation should be executed for the given tensor dimensions.
Args:
input_dim (int): the number of dimensions for the input tensor
other_dim (int): the number of dimensions for the other tensor
"""
if input_dim == 1 and other_dim == 1:
matmul_type = MatMulType.DOT
elif input_dim in [1, 2] and other_dim == 2:
matmul_type = MatMulType.MM
elif input_dim == 2 and other_dim == 1:
matmul_type = MatMulType.MV
elif input_dim >= 1 and other_dim >= 1 and (input_dim > 2 or other_dim > 2):
matmul_type = MatMulType.BMM
else:
raise ValueError(
f"The input and other tensors are of {input_dim} and {other_dim} which cannot used to execute matmul operation"
)
return matmul_type
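# Illustrative examples (derived from the rules above): get_matmul_type(1, 1) -> DOT,
# get_matmul_type(1, 2) -> MM, get_matmul_type(2, 1) -> MV, get_matmul_type(4, 3) -> BMM.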
class BmmTransform(ABC):
"""
BmmTransform is an abstraction of the shape conversion between logical and physical operation data
during the strategy generation.
"""
@abstractmethod
def apply(self, shape_mapping: Dict[str, List[int]]):
pass
@abstractmethod
def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
pass
class Padder(BmmTransform):
"""
Add padding to the matrix dimensions for batched matrix multiplication.
"""
def __init__(self) -> None:
# keep the padding dim, op_name -> padded_dim
self.padded_dim_mapping = {}
def apply(self, shape_mapping: Dict[str, List[int]]):
mapping_copy = deepcopy(shape_mapping)
input_shape = mapping_copy['input']
other_shape = mapping_copy['other']
if len(input_shape) == 1:
# if the input is a 1D tensor, 1 is prepended to its shape
# and it will be removed afterwards
input_shape.insert(0, 1)
self.padded_dim_mapping['input'] = -2
self.padded_dim_mapping['output'] = -2
elif len(other_shape) == 1:
# if the other is a 1D tensor, 1 is appended to its shape
# and it will be removed afterwards
other_shape.append(1)
self.padded_dim_mapping['other'] = -1
self.padded_dim_mapping['output'] = -1
return mapping_copy
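# Illustrative example (not from the original code): apply({'input': [4], 'other': [2, 4, 8]})
# pads the 1D input to [1, 4] and records padded dim -2 for both 'input' and 'output',
# returning {'input': [1, 4], 'other': [2, 4, 8]}.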
def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
input_op_data = op_data_mapping['input']
other_op_data = op_data_mapping['other']
def _remove_padded_dim(key, strategy):
op_data = op_data_mapping[key]
sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
tensor_shape = list(sharding_spec.entire_shape)
dim_partition_list = [None] * len(tensor_shape)
# padded dim is a negative number as the padded dim must be a matrix dim
padded_dim = self.padded_dim_mapping[key]
# compute the new dim partition
for tensor_dim, mesh_dims in sharding_spec.dim_partition_dict.items():
dim_partition_list[tensor_dim] = mesh_dims
dim_partition_list.pop(padded_dim)
unpadded_dim_partition_list = {k: v for k, v in enumerate(dim_partition_list) if v is not None}
# compute unpadded tensor shape
tensor_shape.pop(padded_dim)
assert tensor_shape == list(op_data.data.shape), f'{tensor_shape} vs {list(op_data.data.shape)}'
# update sharding spec
sharding_spec.__init__(sharding_spec.device_mesh, tensor_shape, unpadded_dim_partition_list)
# enumerate all sharding strategies
strategies = []
try:
strategy_copy = strategy.clone()
# only one of input and other will be padded
if 'input' in self.padded_dim_mapping:
_remove_padded_dim('input', strategy_copy)
_remove_padded_dim('output', strategy_copy)
elif 'other' in self.padded_dim_mapping:
_remove_padded_dim('other', strategy_copy)
_remove_padded_dim('output', strategy_copy)
strategies.append(strategy_copy)
except ShardingSpecException as e:
pass
return strategies
class Broadcaster(BmmTransform):
"""
Broadcast the non-matrix dimensions for batched matrix multiplication.
"""
def __init__(self) -> None:
self.broadcast_dim_info = {}
def apply(self, shape_mapping: Dict[str, List[int]]):
mapping_copy = shape_mapping.copy()
# get shapes
input_shape = mapping_copy['input']
other_shape = mapping_copy['other']
# sanity check
assert len(input_shape) > 1 and len(other_shape) > 1
# broadcast the batch dim and record
bcast_non_matrix_dims = get_broadcast_shape(input_shape[:-2], other_shape[:-2])
# store the broadcast dim info
input_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, input_shape[:-2])
other_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, other_shape[:-2])
self.broadcast_dim_info['input'] = input_broadcast_dim_info
self.broadcast_dim_info['other'] = other_broadcast_dim_info
# create the full logical shape
input_shape = bcast_non_matrix_dims + input_shape[-2:]
other_shape = bcast_non_matrix_dims + other_shape[-2:]
assert len(input_shape) == len(other_shape)
mapping_copy['input'] = input_shape
mapping_copy['other'] = other_shape
return mapping_copy
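# Illustrative example (assuming standard broadcasting in get_broadcast_shape):
# apply({'input': [1, 4, 8], 'other': [6, 8, 5]}) broadcasts the batch dims to [6],
# returning {'input': [6, 4, 8], 'other': [6, 8, 5]}.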
def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
# remove sharding on the broadcast dim
def _remove_sharding_on_broadcast_dim(key, strategy):
op_data = op_data_mapping[key]
sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
tensor_shape = list(sharding_spec.entire_shape)
for dim_idx, broadcast_type in self.broadcast_dim_info[key].items():
if broadcast_type == BroadcastType.MULTIPLE:
# if the dim is originally 1 and multiplied during broadcast
# we set its sharding to R
# e.g. [1, 2, 4] x [4, 4, 8] -> [4, 2, 8]
# the dim 0 of [1, 2, 4] is multiplied to 4
tensor_shape[dim_idx] = 1
elif broadcast_type == BroadcastType.PADDDING:
# if the dim is padded
# we remove its sharding
tensor_shape[dim_idx] = None
tensor_shape_before_broadcast = [dim for dim in tensor_shape if dim is not None]
physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
logical_sharding_spec=sharding_spec,
logical_shape=sharding_spec.entire_shape,
physical_shape=tensor_shape_before_broadcast)
strategy.sharding_specs[op_data] = physical_sharding_spec
# enumerate all sharding strategies
strategies = []
try:
strategy_copy = strategy.clone()
_remove_sharding_on_broadcast_dim('input', strategy_copy)
_remove_sharding_on_broadcast_dim('other', strategy_copy)
strategies.append(strategy_copy)
except ShardingSpecException as e:
pass
return strategies
class Viewer(BmmTransform):
"""
Change the shape of the tensor from N-D to 3D
"""
def __init__(self) -> None:
self.batch_dims_before_view = None
def apply(self, shape_mapping: Dict[str, List[int]]):
mapping_copy = shape_mapping.copy()
self.batch_dims_before_view = list(mapping_copy['input'][:-2])
# get shapes
input_shape = shape_mapping['input']
other_shape = shape_mapping['other']
# view to 3d tensor
assert len(input_shape) >= 3 and len(other_shape) >= 3
input_shape = [reduce(operator.mul, input_shape[:-2])] + input_shape[-2:]
other_shape = [reduce(operator.mul, other_shape[:-2])] + other_shape[-2:]
output_shape = input_shape[:2] + other_shape[2:]
mapping_copy['input'] = input_shape
mapping_copy['other'] = other_shape
mapping_copy['output'] = output_shape
return mapping_copy
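# Illustrative example (not from the original code): apply({'input': [2, 3, 4, 5], 'other': [2, 3, 5, 6]})
# flattens the batch dims, returning {'input': [6, 4, 5], 'other': [6, 5, 6], 'output': [6, 4, 6]}.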
def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
# get operation data
def _update_sharding_spec(key, strategy, physical_batch_dim):
"""
Map the logical batch dim to the physical batch dim
"""
op_data = op_data_mapping[key]
sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
dim_partition_dict = sharding_spec.dim_partition_dict
entire_shape = sharding_spec.entire_shape
# update the dimension index for the matrix dimensions
if 2 in dim_partition_dict:
dim_partition_dict[len(self.batch_dims_before_view) + 1] = dim_partition_dict.pop(2)
if 1 in dim_partition_dict:
dim_partition_dict[len(self.batch_dims_before_view)] = dim_partition_dict.pop(1)
# map the logical batch dim to the physical batch dim
if 0 in dim_partition_dict:
batch_dim_shard = dim_partition_dict.pop(0)
dim_partition_dict[physical_batch_dim] = batch_dim_shard
# the new shape will be the batch dims + the last 2 matrix dims
shape_before_view = self.batch_dims_before_view + list(entire_shape[-2:])
sharding_spec.__init__(sharding_spec.device_mesh, shape_before_view, dim_partition_dict)
num_batch_dim_before_view = len(self.batch_dims_before_view)
# enumerate all sharding strategies
strategies = []
for i in range(num_batch_dim_before_view):
# create a new strategy
strategy_copy = strategy.clone()
try:
_update_sharding_spec('input', strategy_copy, i)
_update_sharding_spec('other', strategy_copy, i)
_update_sharding_spec('output', strategy_copy, i)
strategies.append(strategy_copy)
except ShardingSpecException as e:
continue
return strategies
def _get_bmm_logical_shape(input_shape, other_shape, transforms):
"""
Compute the logical shapes for BMM operation. BMM has a general representation
[b, i, k] = [b, i, j] x [b, j, k]
The dimension b is called non-matrix (batch) dimension and the remaining dimensions are called matrix dimensions
The logical shape for the bmm operands will undergo three stages
1. append/prepend the 1 to the 1D tensor if there is any
2. broadcast the non-matrix dimensions
3. reshape to 3 dimensions
"""
shape_mapping = {'input': input_shape, 'other': other_shape}
for transform in transforms:
shape_mapping = transform.apply(shape_mapping)
input_shape = shape_mapping.get('input', None)
other_shape = shape_mapping.get('other', None)
output_shape = shape_mapping.get('output', None)
return input_shape, other_shape, output_shape
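# Illustrative end-to-end example (assuming standard broadcasting semantics): for a physical input
# of shape [2, 3, 4, 5] and other of shape [5, 6], the Padder leaves both unchanged, the Broadcaster
# expands other's batch dims to give [2, 3, 5, 6], and the Viewer flattens both to 3D, producing
# logical shapes input [6, 4, 5], other [6, 5, 6] and output [6, 4, 6].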
@operator_registry.register(torch.matmul)
@operator_registry.register(torch.Tensor.matmul)
class MatMulHandler(NodeHandler):
"""
The MatMulHandler is a node handler which handles the sharding strategy generation for the matmul operation.
According to https://pytorch.org/docs/stable/generated/torch.matmul.html, the operations will vary depending on
the operands.
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
# check which type of operation this matmul will call
self.input_meta_data = self.node.args[0]._meta_data
self.other_meta_data = self.node.args[1]._meta_data
self.output_meta_data = self.node._meta_data
input_dim = self.input_meta_data.dim()
other_dim = self.other_meta_data.dim()
self.matmul_type = get_matmul_type(input_dim, other_dim)
if self.matmul_type == MatMulType.BMM:
# bmm operation can possibly involve padding, broadcasting and view
# these transforms will be used to create logical shape and
# recover physical sharding spec
self.transforms = [Padder(), Broadcaster(), Viewer()]
else:
self.transforms = None
def get_strategy_generator(self) -> List[StrategyGenerator]:
generators = []
op_data_mapping = self.get_operation_data_mapping()
if self.matmul_type == MatMulType.BMM:
generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh))
elif self.matmul_type == MatMulType.DOT:
generators.append(DotProductStrategyGenerator(op_data_mapping, self.device_mesh))
elif self.matmul_type == MatMulType.MV:
generators.append(MatVecStrategyGenerator(op_data_mapping, self.device_mesh))
elif self.matmul_type == MatMulType.MM:
generators.append(
LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear'))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
logical_shape_func = {
MatMulType.DOT: self._get_logical_shape_for_dot,
MatMulType.MM: self._get_logical_shape_for_mm,
MatMulType.MV: self._get_logical_shape_for_mv,
MatMulType.BMM: self._get_logical_shape_for_bmm
}
logical_shapes = logical_shape_func[self.matmul_type]()
op_data_mapping = self._get_op_data_mapping(*logical_shapes)
return op_data_mapping
def _get_op_data_mapping(self, input_logical_shape, other_logical_shape, output_logical_shape):
# convert list to torch.Size
if input_logical_shape:
input_logical_shape = torch.Size(input_logical_shape)
if other_logical_shape:
other_logical_shape = torch.Size(other_logical_shape)
if output_logical_shape:
output_logical_shape = torch.Size(output_logical_shape)
# create op data
input_op_data = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.input_meta_data,
logical_shape=input_logical_shape)
other_op_data = OperationData(name=str(self.node.args[1]),
type=OperationDataType.ARG,
data=self.other_meta_data,
logical_shape=other_logical_shape)
output_op_data = OperationData(name=str(self.node),
type=OperationDataType.OUTPUT,
data=self.output_meta_data,
logical_shape=output_logical_shape)
mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
return mapping
def _get_logical_shape_for_dot(self):
"""
The operands for the dot operation have the same logical shape as the physical shape
"""
return None, None, None
def _get_logical_shape_for_mm(self):
"""
We need to handle the input tensor for a matrix-matrix multiplication as the input
tensor can be a 1D or 2D tensor. If it is a 1D tensor, 1 will be prepended to its shape
(e.g. [4] -> [1, 4]).
"""
if self.input_meta_data.dim() == 1:
input_logical_shape = [1] + list(self.input_meta_data.shape)
input_logical_shape = torch.Size(input_logical_shape)
else:
input_logical_shape = None
return input_logical_shape, None, None
def _get_logical_shape_for_mv(self):
"""
No broadcasting or dim insertion occurs for matrix-vector operation.
"""
return None, None, None
def _get_logical_shape_for_bmm(self):
input_physical_shape = list(self.input_meta_data.shape)
other_physical_shape = list(self.other_meta_data.shape)
return _get_bmm_logical_shape(input_physical_shape, other_physical_shape, self.transforms)
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
if self.matmul_type in [MatMulType.DOT, MatMulType.MV]:
return strategy
elif self.matmul_type == MatMulType.MM:
if self.input_meta_data.dim() == 1:
# if a 1 is prepended to the input shape (this occurs when input is a 1D tensor)
# we need to remove that dim
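# illustration (assumed sharding, for this comment only): a logical shape [1, 4] whose
# dim_partition_dict is {1: [0]} maps back to the physical 1D shape [4] with {0: [0]}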
input_sharding_spec = strategy.get_sharding_spec_by_name(str(self.node.args[0]))
input_physical_shape = self.node.args[0]._meta_data.shape
dim_partition_dict = input_sharding_spec.dim_partition_dict
# remove the partitioning in the dim 0
if 0 in dim_partition_dict:
dim_partition_dict.pop(0, None)
# move the partitioning in dim 1 to dim 0
if -1 in dim_partition_dict:
shard = dim_partition_dict.pop(-1)
dim_partition_dict[0] = shard
if 1 in dim_partition_dict:
shard = dim_partition_dict.pop(1)
dim_partition_dict[0] = shard
# re-init the sharding spec
input_sharding_spec.__init__(input_sharding_spec.device_mesh,
entire_shape=input_physical_shape,
dim_partition_dict=dim_partition_dict)
return strategy
else:
return strategy
elif self.matmul_type == MatMulType.BMM:
op_data_mapping = self.get_operation_data_mapping()
strategies = [strategy]
# recover the physical sharding spec
for transform in self.transforms[::-1]:
recovered_strategies = []
for strategy_ in strategies:
output = transform.recover(op_data_mapping, strategy_)
if isinstance(output, ShardingStrategy):
recovered_strategies.append(output)
elif isinstance(output, (list, tuple)):
recovered_strategies.extend(output)
else:
raise TypeError(
f"Found unexpected output type {type(output)} from the recover method of BmmTransform")
strategies = recovered_strategies
return strategies
from abc import ABC, abstractmethod
from typing import Dict, List, Union
from typing import Dict, List, Tuple, Union
import torch
from torch.fx.node import Node
from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
OperationDataType,
ShardingSpec,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
......@@ -49,7 +52,16 @@ class NodeHandler(ABC):
for node in self.predecessor_node:
node_name = str(node)
# get the current sharding spec generated by this node handler
# we will not compute the resharding costs for nodes that are not included in this strategy.
# Nodes with tuple or list outputs need the special handling below.
node_in_strategy = [op_data.name for op_data in strategy.sharding_specs.keys()]
if str(node) not in node_in_strategy:
continue
op_data = strategy.get_op_data_by_name(node_name)
current_sharding_spec = strategy.sharding_specs[op_data]
# get the sharding specs for this node generated
# in its own node handler
assert hasattr(node, 'strategies_vector'), \
......@@ -59,27 +71,83 @@ class NodeHandler(ABC):
prev_strategy.get_sharding_spec_by_name(node_name) for prev_strategy in prev_strategy_vector
]
# get the current sharding spec generated by this node handler
op_data = strategy.get_op_data_by_name(node_name)
current_sharding_spec = strategy.sharding_specs[op_data]
# create the data structure to store costs
if op_data not in resharding_costs:
if node not in resharding_costs:
resharding_costs[node] = []
def _compute_resharding_cost(
prev_sharding_spec: Union[ShardingSpec,
List[ShardingSpec]], current_sharding_spec: Union[ShardingSpec,
List[ShardingSpec]],
data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]) -> TrainCycleItem:
"""
This is a helper function to compute the resharding cost for a specific strategy of a node.
"""
if prev_sharding_spec is None:
return TrainCycleItem(fwd=0, bwd=0, total=0)
elif isinstance(prev_sharding_spec, ShardingSpec):
if isinstance(data, torch.Tensor):
dtype = data.dtype
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
_, _, consistency_cost = shape_consistency_manager.shape_consistency(
prev_sharding_spec, current_sharding_spec)
resharding_cost = TrainCycleItem(fwd=consistency_cost["forward"] * size_per_elem_bytes,
bwd=consistency_cost["backward"] * size_per_elem_bytes,
total=consistency_cost["total"] * size_per_elem_bytes)
return resharding_cost
else:
# This raise is used to catch any data type we have missed.
# It could eventually be merged into the parameter branch, which would mean
# non-tensor arguments are simply not handled.
raise ValueError(f'Unsupported data type {type(data)}')
else:
assert isinstance(prev_sharding_spec, (tuple, list)), \
f'prev_sharding_spec should be in type of ShardingSpec, List[ShardingSpec], \
or Tuple[ShardingSpec], but got {type(prev_sharding_spec)}'
fwd_cost = 0
bwd_cost = 0
total_cost = 0
for index, (prev_sharding_spec_item,
current_sharding_spec_item) in enumerate(zip(prev_sharding_spec,
current_sharding_spec)):
item_cost = _compute_resharding_cost(prev_sharding_spec_item, current_sharding_spec_item,
data[index])
fwd_cost += item_cost.fwd
bwd_cost += item_cost.bwd
total_cost += item_cost.total
resharding_cost = TrainCycleItem(fwd=fwd_cost, bwd=bwd_cost, total=total_cost)
return resharding_cost
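# illustrative numbers only (assumed, not computed above): with fp32 data (4 bytes per element)
# and a consistency cost of {"forward": 1024, "backward": 1024, "total": 2048} elements,
# the helper above yields TrainCycleItem(fwd=4096, bwd=4096, total=8192), i.e. costs in bytes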
# for each sharding spec generated by the predecessor's node handler
# compute the resharding cost to switch to the sharding spec generated
# by the current node handler
for prev_sharding_spec in prev_sharding_specs:
_, _, resharding_cost = shape_consistency_manager.shape_consistency(prev_sharding_spec,
current_sharding_spec)
resharding_cost = TrainCycleItem(fwd=resharding_cost["forward"],
bwd=resharding_cost["backward"],
total=resharding_cost["total"])
resharding_cost = _compute_resharding_cost(prev_sharding_spec, current_sharding_spec, op_data.data)
resharding_costs[node].append(resharding_cost)
strategy.resharding_costs = resharding_costs
return strategy
def get_target_function(self) -> callable:
"""
This function is used to get the target function for the node handler.
The target function is used to analyze the costs of strategies.
"""
if self.node.op in ('placeholder', 'get_attr', 'output'):
return None
if self.node.op == 'call_module':
target = self.node.graph.owning_module.get_submodule(self.node.target)
elif self.node.op == 'call_function':
target = self.node.target
elif self.node.op == 'call_method':
target = getattr(self.node.args[0]._meta_data.__class__, self.node.target)
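# e.g. for a call_method node produced by x.view(...), args[0]'s meta data is typically a
# torch.Tensor, so the target resolved here is typically torch.Tensor.view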
else:
raise ValueError(f'Unsupported node type: {self.node.op}')
return target
def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
"""
Register different sharding strategies for the current node.
......@@ -151,6 +219,38 @@ class NodeHandler(ABC):
pass
class MetaInfoNodeHandler(NodeHandler):
"""
This is a base class to handle the nodes patched in the meta profiler.
Note: this class will be integrated into the NodeHandler class in the future, after
all the functions are patched.
"""
def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
"""
This method is inherited from NodeHandler. It will register the strategies first,
and rewrite the memory_cost and compute_cost of the strategy using the MetaInfo class.
"""
super().register_strategy(compute_resharding_cost=compute_resharding_cost)
target = self.get_target_function()
# Currently we haven't patched all the torch functions and modules, so if the target
# is not patched, we will use the default cost model to compute the cost.
# TODO: patch all torch functions and modules to make it clean
if meta_register.has(target.__class__) or meta_register.has(target):
metainfo_vector = []
for strategy in self.strategies_vector:
metainfo = MetaInfo(strategy, target)
strategy.compute_cost = metainfo.compute_cost
strategy.memory_cost = metainfo.memory_cost
metainfo_vector.append(metainfo)
# attach metainfos to the handler
setattr(self, "metainfo_vector", metainfo_vector)
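# effect of the loop above (values purely illustrative): a strategy whose generator-estimated
# compute_cost was, say, TrainCycleItem(fwd=2e6, bwd=4e6, total=6e6) is overwritten in place
# with the profiled MetaInfo costs, so later passes consume profiled rather than estimated numbers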
return self.strategies_vector
class ModuleHandler(NodeHandler):
def __init__(self, *args, **kwargs) -> None:
......@@ -168,3 +268,35 @@ class ModuleHandler(NodeHandler):
self.module = module
self.named_parameters = named_parameters
self.named_buffers = named_buffers
class MetaInfoModuleHandler(ModuleHandler):
"""
This is a base class to handle the modules patched in the meta profiler.
Note: this class will be integrated into the ModuleHandler class in the future, after
all the modules are patched.
"""
def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
"""
This method is inherited from NodeHandler. It will register the strategies first,
and rewrite the memory_cost and compute_cost of the strategy using the MetaInfo class.
"""
super().register_strategy(compute_resharding_cost=compute_resharding_cost)
target = self.get_target_function()
# Currently we haven't patched all the torch functions and modules, so if the target
# is not patched, we will use the default cost model to compute the cost.
# TODO: patch all torch functions and modules to make it clean
if meta_register.has(target.__class__) or meta_register.has(target):
metainfo_vector = []
for strategy in self.strategies_vector:
metainfo = MetaInfo(strategy, target)
strategy.compute_cost = metainfo.compute_cost
strategy.memory_cost = metainfo.memory_cost
metainfo_vector.append(metainfo)
# attach metainfos to the handler
setattr(self, "metainfo_vector", metainfo_vector)
return self.strategies_vector
......@@ -3,7 +3,7 @@ from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import ModuleHandler
from .node_handler import MetaInfoModuleHandler, ModuleHandler
from .registry import operator_registry
from .strategy import NormalPoolStrategyGenerator, StrategyGenerator
......@@ -16,7 +16,7 @@ __all__ = ['NormPoolingHandler']
@operator_registry.register(torch.nn.AvgPool1d)
@operator_registry.register(torch.nn.AvgPool2d)
@operator_registry.register(torch.nn.AvgPool3d)
class NormPoolingHandler(ModuleHandler):
class NormPoolingHandler(MetaInfoModuleHandler):
"""
A NormPoolingHandler which deals with the sharding strategies for the nn.MaxPoolXd and nn.AvgPoolXd modules.
"""
......
......@@ -2,38 +2,51 @@ from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from colossalai.device.device_mesh import DeviceMesh
from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector
from .node_handler import NodeHandler
from .strategy import OutputGenerator, StrategyGenerator
__all__ = ['OuputHandler']
__all__ = ['OutputHandler']
class OuputHandler(NodeHandler):
class OutputHandler(NodeHandler):
"""
A OuputHandler which deals with the sharding strategies for Output Node.
An OutputHandler which deals with the sharding strategies for the Output Node.
"""
def __init__(self, node: torch.fx.node.Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
output_option: str) -> None:
super().__init__(node, device_mesh, strategies_vector)
self.output_option = output_option
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(OutputGenerator(op_data_mapping, self.device_mesh, self.predecessor_node))
generators.append(OutputGenerator(op_data_mapping, self.device_mesh, self.predecessor_node, self.output_option))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# collect the meta data of every predecessor node so that the output node's operation data
# covers each tensor returned by the graph
dummy_output = torch.empty(1,).to("meta")
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=dummy_output)
mapping = {"output": physical_output}
mapping = {}
output_meta_data = []
for index, input_node in enumerate(self.predecessor_node):
if not hasattr(input_node, "_meta_data"):
raise AttributeError(f'Output node input {input_node.name} has no _meta_data attribute.')
physical_inputs = OperationData(name=str(input_node),
type=OperationDataType.ARG,
data=input_node._meta_data)
input_meta_data = input_node._meta_data
physical_inputs = OperationData(name=str(input_node), type=OperationDataType.ARG, data=input_meta_data)
name_key = f'input_{index}'
mapping[name_key] = physical_inputs
output_meta_data.append(input_meta_data)
assert len(output_meta_data) > 0, f'Output node {self.node} has no input node.'
if len(output_meta_data) == 1:
output_meta_data = output_meta_data[0]
else:
output_meta_data = tuple(output_meta_data)
self.node._meta_data = output_meta_data
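# illustration (assumed graph, for this comment only): if the output node has two predecessors
# whose meta tensors have shapes [4, 8] and [4], self.node._meta_data becomes a tuple of those
# two meta tensors and the mapping below contains 'input_0', 'input_1' and 'output'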
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping["output"] = physical_output
return mapping
from typing import Dict, List
from ..sharding_strategy import OperationData, OperationDataType
from torch.fx.node import Node
from colossalai.device.device_mesh import DeviceMesh
from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector
from .node_handler import NodeHandler
from .strategy import PlaceholderGenerator, StrategyGenerator
__all__ = ['PlacehodlerHandler']
__all__ = ['PlaceholderHandler']
class PlacehodlerHandler(NodeHandler):
class PlaceholderHandler(NodeHandler):
"""
A PlacehodlerHandler which deals with the sharding strategies for Placeholder Node.
A PlaceholderHandler which deals with the sharding strategies for Placeholder Node.
"""
def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
placeholder_option: str) -> None:
super().__init__(node, device_mesh, strategies_vector)
self.placeholder_option = placeholder_option
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(PlaceholderGenerator(op_data_mapping, self.device_mesh))
generators.append(
PlaceholderGenerator(op_data_mapping, self.device_mesh, placeholder_option=self.placeholder_option))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
......
......@@ -8,6 +8,11 @@ class Registry:
def register(self, source):
def wrapper(func):
if isinstance(source, (list, tuple)):
# support register a list of items for this func
for element in source:
self.store[element] = func
else:
self.store[source] = func
return func
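# usage sketch (hypothetical handler name, for illustration only):
# @operator_registry.register([torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d])
# class MyPoolingHandler(ModuleHandler):
#     ...
# registers the same handler once for every op in the list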
......
......@@ -3,18 +3,17 @@ from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .node_handler import MetaInfoNodeHandler, NodeHandler
from .registry import operator_registry
from .strategy import ReshapeGenerator, StrategyGenerator
__all__ = ['ReshapeHandler']
@operator_registry.register(torch.reshape)
@operator_registry.register(torch.flatten)
@operator_registry.register(torch.Tensor.permute)
@operator_registry.register(torch.Tensor.unsqueeze)
@operator_registry.register(torch.nn.AdaptiveAvgPool2d)
class ReshapeHandler(NodeHandler):
class ReshapeHandler(MetaInfoNodeHandler):
"""
A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
"""
......@@ -25,13 +24,47 @@ class ReshapeHandler(NodeHandler):
generators.append(ReshapeGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
return generators
def infer_logical_shape(self, data):
"""
This function is used to infer the logical shape of an operand.
Note: an operand's data is not always a single tensor; it can also be a tuple of tensors,
in which case a tuple of shapes is returned.
"""
if isinstance(data, torch.Tensor):
return data.shape
else:
assert isinstance(data, tuple), "input_data should be a tuple of tensor or a tensor."
logical_shape = []
for tensor in data:
assert isinstance(tensor, torch.Tensor), "input_data should be a tuple of tensor or a tensor."
logical_shape.append(tensor.shape)
logical_shape = tuple(logical_shape)
return logical_shape
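# example (assumed shapes): a node whose meta data is a tuple of two tensors with shapes
# [4, 8] and [4, 16] gets the logical shape (torch.Size([4, 8]), torch.Size([4, 16])),
# while a plain tensor simply keeps its own shape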
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# build the operation data mapping for the input and output operands
# check if the input operand is a parameter
if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
input_data = self.node.args[0]._meta_data
input_logical_shape = self.infer_logical_shape(input_data)
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data)
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
type=data_type,
data=input_data,
logical_shape=input_logical_shape)
output_data = self.node._meta_data
output_logical_shape = self.infer_logical_shape(output_data)
physical_output = OperationData(name=str(self.node),
type=OperationDataType.OUTPUT,
data=output_data,
logical_shape=output_logical_shape)
mapping = {"input": physical_input_operand, "output": physical_output}
......
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import SoftmaxGenerator, StrategyGenerator
__all__ = ['SoftmaxHandler']
@operator_registry.register(torch.nn.Softmax)
@operator_registry.register(torch.nn.functional.softmax)
class SoftmaxHandler(NodeHandler):
"""
A SoftmaxHandler which deals with the sharding strategies for
torch.nn.Softmax or torch.nn.functional.softmax.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(SoftmaxGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# check if the input operand is a parameter
if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
input_data = self.node.args[0]._meta_data
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
softmax_dim = self.node.kwargs['dim']
num_dims = self.node.args[0]._meta_data.dim()
# convert a negative dim value to its positive counterpart
if softmax_dim < 0:
softmax_dim += num_dims
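# e.g. for a 4D input, dim=-1 is normalized to dim=3 here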
physical_dim_operand = OperationData(name='softmax_dim', type=OperationDataType.ARG, data=softmax_dim)
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
mapping = {
"input": physical_input_operand,
"softmax_dim": physical_dim_operand,
"output": physical_output_operand
}
return mapping