Commit 9e768b59 authored by zhuwenwen's avatar zhuwenwen
Browse files
parents 7bc5a8e3 8aed02b9
import copy import copy
from typing import List from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from .strategy_generator import FollowingStrategyGenerator from .strategy_generator import FollowingStrategyGenerator
__all__ = ['UnaryElementwiseGenerator'] __all__ = ["UnaryElementwiseGenerator"]
class UnaryElementwiseGenerator(FollowingStrategyGenerator): class UnaryElementwiseGenerator(FollowingStrategyGenerator):
...@@ -21,12 +21,12 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator): ...@@ -21,12 +21,12 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator):
strategy.compute_cost = compute_cost strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy): def update_memory_cost(self, strategy: ShardingStrategy):
''' """
Compute the memory cost per device with this specific strategy. Compute the memory cost per device with this specific strategy.
''' """
forward_size_mapping = { forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"), "input": self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output") "output": self._compute_size_in_bytes(strategy, "output"),
} }
backward_size_mapping = copy.deepcopy(forward_size_mapping) backward_size_mapping = copy.deepcopy(forward_size_mapping)
...@@ -44,8 +44,9 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator): ...@@ -44,8 +44,9 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost # compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, total_mem_cost = MemoryCost(
parameter=fwd_parameter_cost + bwd_parameter_cost) activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost strategy.memory_cost = memory_cost
...@@ -69,9 +70,11 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator): ...@@ -69,9 +70,11 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator):
# we keep same strategies with different name for node merging, and it will not increase the searching space, # we keep same strategies with different name for node merging, and it will not increase the searching space,
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name, strategy = self.get_sharding_strategy(
sharding_spec_mapping=sharding_spec_mapping, name=name,
communication_action_mapping=communication_action_mapping) sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
strategy_list.append(strategy) strategy_list.append(strategy)
return strategy_list return strategy_list
...@@ -10,7 +10,7 @@ from colossalai.auto_parallel.tensor_shard.utils import ( ...@@ -10,7 +10,7 @@ from colossalai.auto_parallel.tensor_shard.utils import (
from .strategy_generator import StrategyGenerator from .strategy_generator import StrategyGenerator
__all__ = ['WhereGenerator'] __all__ = ["WhereGenerator"]
class WhereGenerator(StrategyGenerator): class WhereGenerator(StrategyGenerator):
...@@ -26,14 +26,14 @@ class WhereGenerator(StrategyGenerator): ...@@ -26,14 +26,14 @@ class WhereGenerator(StrategyGenerator):
strategy.compute_cost = compute_cost strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy): def update_memory_cost(self, strategy: ShardingStrategy):
''' """
Compute the memory cost per device with this specific strategy. Compute the memory cost per device with this specific strategy.
''' """
forward_size_mapping = { forward_size_mapping = {
'condition': self._compute_size_in_bytes(strategy, "condition"), "condition": self._compute_size_in_bytes(strategy, "condition"),
'x': self._compute_size_in_bytes(strategy, "x"), "x": self._compute_size_in_bytes(strategy, "x"),
'y': self._compute_size_in_bytes(strategy, "y"), "y": self._compute_size_in_bytes(strategy, "y"),
'output': self._compute_size_in_bytes(strategy, "output") "output": self._compute_size_in_bytes(strategy, "output"),
} }
backward_size_mapping = copy.deepcopy(forward_size_mapping) backward_size_mapping = copy.deepcopy(forward_size_mapping)
...@@ -59,7 +59,7 @@ class WhereGenerator(StrategyGenerator): ...@@ -59,7 +59,7 @@ class WhereGenerator(StrategyGenerator):
"condition": dim_partition, "condition": dim_partition,
"x": dim_partition, "x": dim_partition,
"y": dim_partition, "y": dim_partition,
"output": dim_partition "output": dim_partition,
} }
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
...@@ -67,9 +67,11 @@ class WhereGenerator(StrategyGenerator): ...@@ -67,9 +67,11 @@ class WhereGenerator(StrategyGenerator):
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["condition"].sharding_sequence} x {sharding_spec_mapping["x"].sharding_sequence} x {sharding_spec_mapping["y"].sharding_sequence}' name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["condition"].sharding_sequence} x {sharding_spec_mapping["x"].sharding_sequence} x {sharding_spec_mapping["y"].sharding_sequence}'
communication_action_mapping = {} communication_action_mapping = {}
strategy = self.get_sharding_strategy(name=name, strategy = self.get_sharding_strategy(
sharding_spec_mapping=sharding_spec_mapping, name=name,
communication_action_mapping=communication_action_mapping) sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
return strategy return strategy
...@@ -84,9 +86,9 @@ class WhereGenerator(StrategyGenerator): ...@@ -84,9 +86,9 @@ class WhereGenerator(StrategyGenerator):
return dim_partition_list return dim_partition_list
def collate_strategies(self) -> List[ShardingStrategy]: def collate_strategies(self) -> List[ShardingStrategy]:
''' """
Generate every possible strategies for a where node, and record all strategies into the strategies_vector. Generate every possible strategies for a where node, and record all strategies into the strategies_vector.
''' """
strategy_list = [] strategy_list = []
dimension_length = len(self.op_data["output"].logical_shape) dimension_length = len(self.op_data["output"].logical_shape)
......
...@@ -7,7 +7,7 @@ from .node_handler import NodeHandler ...@@ -7,7 +7,7 @@ from .node_handler import NodeHandler
from .registry import operator_registry from .registry import operator_registry
from .strategy import StrategyGenerator, SumGenerator from .strategy import StrategyGenerator, SumGenerator
__all__ = ['SumHandler'] __all__ = ["SumHandler"]
@operator_registry.register(torch.Tensor.sum) @operator_registry.register(torch.Tensor.sum)
...@@ -55,7 +55,7 @@ class SumHandler(NodeHandler): ...@@ -55,7 +55,7 @@ class SumHandler(NodeHandler):
# sum_mapping_dict[1] = 0 means the 0th dim of output is the 1st dim of input # sum_mapping_dict[1] = 0 means the 0th dim of output is the 1st dim of input
# sum_mapping_dict[3] = 1 means the 1st dim of output is the 3rd dim of input # sum_mapping_dict[3] = 1 means the 1st dim of output is the 3rd dim of input
sum_mapping_dict = {} sum_mapping_dict = {}
if 'keepdim' in self.node.kwargs and self.node.kwargs['keepdim']: if "keepdim" in self.node.kwargs and self.node.kwargs["keepdim"]:
for i in range(num_dims): for i in range(num_dims):
sum_mapping_dict.update({i: i}) sum_mapping_dict.update({i: i})
else: else:
...@@ -67,7 +67,7 @@ class SumHandler(NodeHandler): ...@@ -67,7 +67,7 @@ class SumHandler(NodeHandler):
assert output_index == self.node._meta_data.dim() assert output_index == self.node._meta_data.dim()
sum_info = (sum_dims, sum_mapping_dict) sum_info = (sum_dims, sum_mapping_dict)
physical_shape_operand = OperationData(name='sum_info', type=OperationDataType.ARG, data=sum_info) physical_shape_operand = OperationData(name="sum_info", type=OperationDataType.ARG, data=sum_info)
output_data = self.node._meta_data output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
...@@ -75,7 +75,7 @@ class SumHandler(NodeHandler): ...@@ -75,7 +75,7 @@ class SumHandler(NodeHandler):
mapping = { mapping = {
"input": physical_input_operand, "input": physical_input_operand,
"sum_info": physical_shape_operand, "sum_info": physical_shape_operand,
"output": physical_output_operand "output": physical_output_operand,
} }
return mapping return mapping
...@@ -8,7 +8,7 @@ from .registry import operator_registry ...@@ -8,7 +8,7 @@ from .registry import operator_registry
from .strategy import StrategyGenerator from .strategy import StrategyGenerator
from .strategy.tensor_constructor_generator import TensorConstructorGenerator from .strategy.tensor_constructor_generator import TensorConstructorGenerator
__all__ = ['TensorConstructorHandler'] __all__ = ["TensorConstructorHandler"]
@operator_registry.register(torch.arange) @operator_registry.register(torch.arange)
......
...@@ -7,7 +7,7 @@ from .node_handler import NodeHandler ...@@ -7,7 +7,7 @@ from .node_handler import NodeHandler
from .registry import operator_registry from .registry import operator_registry
from .strategy import StrategyGenerator, TransposeGenerator from .strategy import StrategyGenerator, TransposeGenerator
__all__ = ['TransposeHandler'] __all__ = ["TransposeHandler"]
@operator_registry.register(torch.Tensor.transpose) @operator_registry.register(torch.Tensor.transpose)
...@@ -48,9 +48,9 @@ class TransposeHandler(NodeHandler): ...@@ -48,9 +48,9 @@ class TransposeHandler(NodeHandler):
if transpose_dims[i] < 0: if transpose_dims[i] < 0:
transpose_dims[i] += num_dims transpose_dims[i] += num_dims
physical_shape_operand = OperationData(name='transpose_dims', physical_shape_operand = OperationData(
type=OperationDataType.ARG, name="transpose_dims", type=OperationDataType.ARG, data=list(transpose_dims)
data=list(transpose_dims)) )
output_data = self.node._meta_data output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
...@@ -58,7 +58,7 @@ class TransposeHandler(NodeHandler): ...@@ -58,7 +58,7 @@ class TransposeHandler(NodeHandler):
mapping = { mapping = {
"input": physical_input_operand, "input": physical_input_operand,
"transpose_dims": physical_shape_operand, "transpose_dims": physical_shape_operand,
"output": physical_output_operand "output": physical_output_operand,
} }
return mapping return mapping
...@@ -3,11 +3,11 @@ from typing import Dict, List ...@@ -3,11 +3,11 @@ from typing import Dict, List
import torch import torch
from ..sharding_strategy import OperationData, OperationDataType from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import MetaInfoNodeHandler, NodeHandler from .node_handler import MetaInfoNodeHandler
from .registry import operator_registry from .registry import operator_registry
from .strategy import StrategyGenerator, UnaryElementwiseGenerator from .strategy import StrategyGenerator, UnaryElementwiseGenerator
__all__ = ['UnaryElementwiseHandler'] __all__ = ["UnaryElementwiseHandler"]
@operator_registry.register(torch.Tensor.to) @operator_registry.register(torch.Tensor.to)
...@@ -33,9 +33,9 @@ class UnaryElementwiseHandler(MetaInfoNodeHandler): ...@@ -33,9 +33,9 @@ class UnaryElementwiseHandler(MetaInfoNodeHandler):
def get_operation_data_mapping(self) -> Dict[str, OperationData]: def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies # use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process # the strategies will be transformed back to its original shape in self.post_process
physical_input_operand = OperationData(name=str(self.node.args[0]), physical_input_operand = OperationData(
type=OperationDataType.ARG, name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
data=self.node.args[0]._meta_data) )
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {"input": physical_input_operand, "output": physical_output} mapping = {"input": physical_input_operand, "output": physical_output}
......
...@@ -7,7 +7,7 @@ from .node_handler import NodeHandler ...@@ -7,7 +7,7 @@ from .node_handler import NodeHandler
from .registry import operator_registry from .registry import operator_registry
from .strategy import StrategyGenerator, ViewGenerator from .strategy import StrategyGenerator, ViewGenerator
__all__ = ['ViewHandler'] __all__ = ["ViewHandler"]
@operator_registry.register(torch.Tensor.reshape) @operator_registry.register(torch.Tensor.reshape)
...@@ -38,7 +38,7 @@ class ViewHandler(NodeHandler): ...@@ -38,7 +38,7 @@ class ViewHandler(NodeHandler):
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
target_shape = self.node._meta_data.shape target_shape = self.node._meta_data.shape
physical_shape_operand = OperationData(name='tgt_shape', type=OperationDataType.ARG, data=target_shape) physical_shape_operand = OperationData(name="tgt_shape", type=OperationDataType.ARG, data=target_shape)
output_data = self.node._meta_data output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
...@@ -46,7 +46,7 @@ class ViewHandler(NodeHandler): ...@@ -46,7 +46,7 @@ class ViewHandler(NodeHandler):
mapping = { mapping = {
"input": physical_input_operand, "input": physical_input_operand,
"tgt_shape": physical_shape_operand, "tgt_shape": physical_shape_operand,
"output": physical_output_operand "output": physical_output_operand,
} }
return mapping return mapping
import copy import copy
import operator
from typing import Dict, List from typing import Dict, List
import torch import torch
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import recover_sharding_spec_for_broadcast_shape from ..utils import recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler from .node_handler import NodeHandler
from .registry import operator_registry from .registry import operator_registry
from .strategy import StrategyGenerator, WhereGenerator from .strategy import StrategyGenerator, WhereGenerator
__all__ = ['WhereHandler'] __all__ = ["WhereHandler"]
@operator_registry.register(torch.where) @operator_registry.register(torch.where)
...@@ -28,27 +27,28 @@ class WhereHandler(NodeHandler): ...@@ -28,27 +27,28 @@ class WhereHandler(NodeHandler):
def get_operation_data_mapping(self) -> Dict[str, OperationData]: def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies # use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process # the strategies will be transformed back to its original shape in self.post_process
physical_condition_operand = OperationData(name=str(self.node.args[0]), physical_condition_operand = OperationData(
type=OperationDataType.ARG, name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
data=self.node.args[0]._meta_data) )
physical_x_operand = OperationData(name=str(self.node.args[1]), physical_x_operand = OperationData(
type=OperationDataType.ARG, name=str(self.node.args[1]), type=OperationDataType.ARG, data=self.node.args[1]._meta_data
data=self.node.args[1]._meta_data) )
physical_y_operand = OperationData(name=str(self.node.args[2]), physical_y_operand = OperationData(
type=OperationDataType.ARG, name=str(self.node.args[2]), type=OperationDataType.ARG, data=self.node.args[2]._meta_data
data=self.node.args[2]._meta_data) )
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
physical_mapping = { physical_mapping = {
"condition": physical_condition_operand, "condition": physical_condition_operand,
"x": physical_x_operand, "x": physical_x_operand,
"y": physical_y_operand, "y": physical_y_operand,
"output": physical_output "output": physical_output,
} }
logical_shape_for_all = self.node._meta_data.shape logical_shape_for_all = self.node._meta_data.shape
logical_mapping = {} logical_mapping = {}
for key, physical_operand in physical_mapping.items(): for key, physical_operand in physical_mapping.items():
logical_mapping[key] = self.convert_physical_operand_to_logical_operand(physical_operand, logical_mapping[key] = self.convert_physical_operand_to_logical_operand(
logical_shape_for_all) physical_operand, logical_shape_for_all
)
return logical_mapping, physical_mapping return logical_mapping, physical_mapping
...@@ -64,7 +64,8 @@ class WhereHandler(NodeHandler): ...@@ -64,7 +64,8 @@ class WhereHandler(NodeHandler):
logical_shape = logical_op_data_mapping[key].logical_shape logical_shape = logical_op_data_mapping[key].logical_shape
physical_shape = physical_op_data_mapping[key].logical_shape physical_shape = physical_op_data_mapping[key].logical_shape
physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
logical_sharding_spec, logical_shape, physical_shape) logical_sharding_spec, logical_shape, physical_shape
)
strategy.sharding_specs.pop(logical_op_data_mapping[key]) strategy.sharding_specs.pop(logical_op_data_mapping[key])
strategy.sharding_specs[physical_op_data_mapping[key]] = physical_sharding_spec strategy.sharding_specs[physical_op_data_mapping[key]] = physical_sharding_spec
strategy.name = f"{strategy.sharding_specs[physical_op_data_mapping['output']].sharding_sequence} = {strategy.sharding_specs[physical_op_data_mapping['condition']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['x']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['y']].sharding_sequence}" strategy.name = f"{strategy.sharding_specs[physical_op_data_mapping['output']].sharding_sequence} = {strategy.sharding_specs[physical_op_data_mapping['condition']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['x']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['y']].sharding_sequence}"
......
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
__all__ = ['SolverOptions', 'SolverPerference', 'DataloaderOption', 'ShardOption'] __all__ = ["SolverOptions", "SolverPerference", "DataloaderOption", "ShardOption"]
class SolverPerference(Enum): class SolverPerference(Enum):
""" """
This enum class is to define the solver preference. This enum class is to define the solver preference.
""" """
STANDARD = 0 STANDARD = 0
DP = 1 DP = 1
TP = 2 TP = 2
...@@ -25,6 +26,7 @@ class ShardOption(Enum): ...@@ -25,6 +26,7 @@ class ShardOption(Enum):
TP_SHARD: We require the node to be shard using tensor parallel strategies on last device mesh axis. TP_SHARD: We require the node to be shard using tensor parallel strategies on last device mesh axis.
TP_FULL_SHARD: We require the node to be shard using tensor parallel strategies on all device mesh axes. TP_FULL_SHARD: We require the node to be shard using tensor parallel strategies on all device mesh axes.
""" """
STANDARD = 0 STANDARD = 0
SHARD = 1 SHARD = 1
SHARD_LAST_AXIS = 2 SHARD_LAST_AXIS = 2
...@@ -35,6 +37,7 @@ class DataloaderOption(Enum): ...@@ -35,6 +37,7 @@ class DataloaderOption(Enum):
""" """
This enum class is to define the dataloader option. This enum class is to define the dataloader option.
""" """
REPLICATED = 0 REPLICATED = 0
DISTRIBUTED = 1 DISTRIBUTED = 1
...@@ -44,6 +47,7 @@ class SolverOptions: ...@@ -44,6 +47,7 @@ class SolverOptions:
""" """
SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search. SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search.
""" """
solver_perference: SolverPerference = SolverPerference.STANDARD solver_perference: SolverPerference = SolverPerference.STANDARD
dataloader_option: DataloaderOption = DataloaderOption.REPLICATED dataloader_option: DataloaderOption = DataloaderOption.REPLICATED
shard_option: ShardOption = ShardOption.STANDARD shard_option: ShardOption = ShardOption.STANDARD
...@@ -10,7 +10,6 @@ from colossalai.tensor.comm_spec import CommSpec ...@@ -10,7 +10,6 @@ from colossalai.tensor.comm_spec import CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.tensor.sharding_spec import ShardingSpec
from .constants import ( from .constants import (
BCAST_FUNC_OP,
ELEMENTWISE_FUNC_OP, ELEMENTWISE_FUNC_OP,
ELEMENTWISE_METHOD_OP, ELEMENTWISE_METHOD_OP,
ELEMENTWISE_MODULE_OP, ELEMENTWISE_MODULE_OP,
...@@ -18,13 +17,14 @@ from .constants import ( ...@@ -18,13 +17,14 @@ from .constants import (
RESHAPE_METHOD_OP, RESHAPE_METHOD_OP,
) )
__all__ = ['OperationDataType', 'OperationData', 'TrainCycleItem', 'MemoryCost', 'ShardingStrategy', 'StrategiesVector'] __all__ = ["OperationDataType", "OperationData", "TrainCycleItem", "MemoryCost", "ShardingStrategy", "StrategiesVector"]
class OperationDataType(Enum): class OperationDataType(Enum):
""" """
An operation can come from the argument list of an operator or the parameter list of a module. An operation can come from the argument list of an operator or the parameter list of a module.
""" """
INPUT = 0 INPUT = 0
ARG = 1 ARG = 1
PARAM = 2 PARAM = 2
...@@ -43,6 +43,7 @@ class OperationData: ...@@ -43,6 +43,7 @@ class OperationData:
data (Any): the value for this data, usually it is a meta tensor. data (Any): the value for this data, usually it is a meta tensor.
logical_shape (Tuple[int]): the logical shape of the data, it can be different from the its actual shape in memory. logical_shape (Tuple[int]): the logical shape of the data, it can be different from the its actual shape in memory.
""" """
name: str name: str
type: OperationDataType type: OperationDataType
data: Any data: Any
...@@ -69,13 +70,13 @@ class OperationData: ...@@ -69,13 +70,13 @@ class OperationData:
self.logical_shape = _infer_logical_shape(self.data) self.logical_shape = _infer_logical_shape(self.data)
def __repr__(self) -> str: def __repr__(self) -> str:
return f'OperationData(name={self.name}, type={self.type})' return f"OperationData(name={self.name}, type={self.type})"
def __eq__(self, other) -> bool: def __eq__(self, other) -> bool:
return other.name == self.name return other.name == self.name
def __hash__(self) -> int: def __hash__(self) -> int:
return hash(f'{self.name}') return hash(f"{self.name}")
@dataclass @dataclass
...@@ -88,6 +89,7 @@ class TrainCycleItem: ...@@ -88,6 +89,7 @@ class TrainCycleItem:
fwd (float): the item for the forward pass fwd (float): the item for the forward pass
bwd (float): the item for the backward pass bwd (float): the item for the backward pass
""" """
fwd: Any fwd: Any
bwd: Any bwd: Any
total: Any total: Any
...@@ -104,6 +106,7 @@ class MemoryCost: ...@@ -104,6 +106,7 @@ class MemoryCost:
temp (int): the memory cost incurred by the temporary tensors in bytes. temp (int): the memory cost incurred by the temporary tensors in bytes.
buffer (int): the memory cost incurred by the module buffer in bytes. buffer (int): the memory cost incurred by the module buffer in bytes.
""" """
activation: int = 0 activation: int = 0
parameter: int = 0 parameter: int = 0
temp: int = 0 temp: int = 0
...@@ -120,6 +123,7 @@ class CommType(Enum): ...@@ -120,6 +123,7 @@ class CommType(Enum):
HOOK: the communication action is used to do the grad all reduce. HOOK: the communication action is used to do the grad all reduce.
IMPLICIT: the communication action happens during the kernel execution, such as SyncBatchNorm IMPLICIT: the communication action happens during the kernel execution, such as SyncBatchNorm
""" """
BEFORE = 0 BEFORE = 0
AFTER = 1 AFTER = 1
HOOK = 2 HOOK = 2
...@@ -137,6 +141,7 @@ class CommAction: ...@@ -137,6 +141,7 @@ class CommAction:
arg_index: record the location of tensor which join the communication, we cannot use name of node or op_data at runtime, arg_index: record the location of tensor which join the communication, we cannot use name of node or op_data at runtime,
because the args of node may be changed by graph transform passes. because the args of node may be changed by graph transform passes.
""" """
comm_spec: CommSpec = None comm_spec: CommSpec = None
comm_type: CommType = None comm_type: CommType = None
arg_index: int = -1 arg_index: int = -1
...@@ -156,6 +161,7 @@ class ShardingStrategy: ...@@ -156,6 +161,7 @@ class ShardingStrategy:
memory_cost (TrainCycleItem): Memory cost of the output node using this strategy. (default to None) memory_cost (TrainCycleItem): Memory cost of the output node using this strategy. (default to None)
input_sharding_specs (List(ShardingSpec)): The ShardingSpecs of the input nodes. input_sharding_specs (List(ShardingSpec)): The ShardingSpecs of the input nodes.
""" """
name: str name: str
sharding_specs: Dict[OperationData, Union[ShardingSpec, Tuple[ShardingSpec]]] = None sharding_specs: Dict[OperationData, Union[ShardingSpec, Tuple[ShardingSpec]]] = None
compute_cost: TrainCycleItem = None compute_cost: TrainCycleItem = None
...@@ -200,7 +206,6 @@ class ShardingStrategy: ...@@ -200,7 +206,6 @@ class ShardingStrategy:
raise KeyError(f"Could not find the ShardingSpec for OperationData with name {name}") raise KeyError(f"Could not find the ShardingSpec for OperationData with name {name}")
def clone(self): def clone(self):
def _deepcopy_dict_vals(data: Dict): def _deepcopy_dict_vals(data: Dict):
return {k: deepcopy(v) for k, v in data.items()} return {k: deepcopy(v) for k, v in data.items()}
...@@ -209,31 +214,34 @@ class ShardingStrategy: ...@@ -209,31 +214,34 @@ class ShardingStrategy:
# Consider the examples below: # Consider the examples below:
# If self.communication_actions is an empty dictionary {}, then self.communication_actions is not None, but its __bool__ value is False. # If self.communication_actions is an empty dictionary {}, then self.communication_actions is not None, but its __bool__ value is False.
# In this case, if we set None to the new object, program will crash when we try to access the communication_actions.items. # In this case, if we set None to the new object, program will crash when we try to access the communication_actions.items.
communication_actions = _deepcopy_dict_vals( communication_actions = (
self.communication_actions) if self.communication_actions is not None else None _deepcopy_dict_vals(self.communication_actions) if self.communication_actions is not None else None
)
# same reason as communication_actions # same reason as communication_actions
resharding_costs = _deepcopy_dict_vals(self.resharding_costs) if self.resharding_costs is not None else None resharding_costs = _deepcopy_dict_vals(self.resharding_costs) if self.resharding_costs is not None else None
compute_cost = deepcopy(self.compute_cost) compute_cost = deepcopy(self.compute_cost)
communication_cost = deepcopy(self.communication_cost) communication_cost = deepcopy(self.communication_cost)
memory_cost = deepcopy(self.memory_cost) memory_cost = deepcopy(self.memory_cost)
return ShardingStrategy(name=self.name, return ShardingStrategy(
sharding_specs=sharding_specs, name=self.name,
compute_cost=compute_cost, sharding_specs=sharding_specs,
communication_cost=communication_cost, compute_cost=compute_cost,
memory_cost=memory_cost, communication_cost=communication_cost,
communication_actions=communication_actions, memory_cost=memory_cost,
resharding_costs=resharding_costs) communication_actions=communication_actions,
resharding_costs=resharding_costs,
)
class StrategiesVector(list): class StrategiesVector(list):
''' """
Each node in fx graph will have a corresponding StrategiesVector, to store all the possible Each node in fx graph will have a corresponding StrategiesVector, to store all the possible
strategies of the node. strategies of the node.
Argument: Argument:
node (Node): node for which the list of sharding strategies are generated. node (Node): node for which the list of sharding strategies are generated.
''' """
def __init__(self, node: Node): def __init__(self, node: Node):
super().__init__() super().__init__()
...@@ -245,7 +253,7 @@ class StrategiesVector(list): ...@@ -245,7 +253,7 @@ class StrategiesVector(list):
def check_merge(self): def check_merge(self):
merge_label = False merge_label = False
if self.node.op == 'call_module': if self.node.op == "call_module":
target = self.node.target target = self.node.target
root_module = self.node.graph.owning_module root_module = self.node.graph.owning_module
submod = root_module.get_submodule(target) submod = root_module.get_submodule(target)
...@@ -255,7 +263,7 @@ class StrategiesVector(list): ...@@ -255,7 +263,7 @@ class StrategiesVector(list):
if submod_type in ELEMENTWISE_MODULE_OP: if submod_type in ELEMENTWISE_MODULE_OP:
merge_label = True merge_label = True
if self.node.op == 'call_function': if self.node.op == "call_function":
# we could merge element-wise op, because the output sharding spec is always same as the input sharding spec. # we could merge element-wise op, because the output sharding spec is always same as the input sharding spec.
if self.node.target in ELEMENTWISE_FUNC_OP: if self.node.target in ELEMENTWISE_FUNC_OP:
merge_label = True merge_label = True
...@@ -267,7 +275,7 @@ class StrategiesVector(list): ...@@ -267,7 +275,7 @@ class StrategiesVector(list):
if self.node.target in RESHAPE_FUNC_OP: if self.node.target in RESHAPE_FUNC_OP:
merge_label = True merge_label = True
if self.node.op == 'call_method': if self.node.op == "call_method":
# we could merge reshape op, because their computation costs are negligible. # we could merge reshape op, because their computation costs are negligible.
method = getattr(self.node.args[0]._meta_data.__class__, self.node.target) method = getattr(self.node.args[0]._meta_data.__class__, self.node.target)
if method in RESHAPE_METHOD_OP: if method in RESHAPE_METHOD_OP:
......
...@@ -3,4 +3,4 @@ from .graph_analysis import GraphAnalyser ...@@ -3,4 +3,4 @@ from .graph_analysis import GraphAnalyser
from .solver import Solver from .solver import Solver
from .strategies_constructor import StrategiesConstructor from .strategies_constructor import StrategiesConstructor
__all__ = ['GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph'] __all__ = ["GraphAnalyser", "Solver", "StrategiesConstructor", "CostGraph"]
...@@ -4,18 +4,18 @@ from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST ...@@ -4,18 +4,18 @@ from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST
class CostGraph: class CostGraph:
''' """
A graph data structure to simplify the edge cost graph. It has two main functions: A graph data structure to simplify the edge cost graph. It has two main functions:
1. To feed the quadratic resharding costs into solver, we need to linearize it. We build edge_cost in 1. To feed the quadratic resharding costs into solver, we need to linearize it. We build edge_cost in
CostGraph, and it stores every combination of strategies for a src-dst node pair in a 1D list. CostGraph, and it stores every combination of strategies for a src-dst node pair in a 1D list.
2. To reduce the searching space, we merge computationally-trivial operators, such as 2. To reduce the searching space, we merge computationally-trivial operators, such as
element-wise operators, transpose, and reduction, into their following nodes. The merging infomation will element-wise operators, transpose, and reduction, into their following nodes. The merging information will
be given by the StrategiesVector depending on the type of target node and following nodes. be given by the StrategiesVector depending on the type of target node and following nodes.
Argument: Argument:
leaf_strategies(List[StrategiesVector]): It stores the StrategiesVector of every node on the graph. leaf_strategies(List[StrategiesVector]): It stores the StrategiesVector of every node on the graph.
simplify(bool, optional): The generated cost graph will be simplified if it is true. (default to True) simplify(bool, optional): The generated cost graph will be simplified if it is true. (default to True)
''' """
def __init__(self, leaf_strategies, simplify=True, forward_only=False): def __init__(self, leaf_strategies, simplify=True, forward_only=False):
self.leaf_strategies = leaf_strategies self.leaf_strategies = leaf_strategies
...@@ -39,10 +39,10 @@ class CostGraph: ...@@ -39,10 +39,10 @@ class CostGraph:
target_node_list.remove(element) target_node_list.remove(element)
def _build_cost_graph(self): def _build_cost_graph(self):
''' """
This method will generate edge_cost for adjacent node pair. Additionally, 'parents' and 'children' attribute will be This method will generate edge_cost for adjacent node pair. Additionally, 'parents' and 'children' attribute will be
set to node. set to node.
''' """
self.edge_costs = {} self.edge_costs = {}
if self.simplify: if self.simplify:
self.merge_pair = [] self.merge_pair = []
...@@ -84,13 +84,13 @@ class CostGraph: ...@@ -84,13 +84,13 @@ class CostGraph:
if _check_tensor_in_node(node._meta_data): if _check_tensor_in_node(node._meta_data):
children_nodes.append(node) children_nodes.append(node)
setattr(dst_node, 'parents', parent_nodes) setattr(dst_node, "parents", parent_nodes)
setattr(dst_node, 'children', children_nodes) setattr(dst_node, "children", children_nodes)
if self.simplify and strategies_vector.check_merge(): if self.simplify and strategies_vector.check_merge():
for followed_node in strategies_vector.predecessor_nodes: for followed_node in strategies_vector.predecessor_nodes:
# we only merge node pairs which src node has a tensor element inside. # we only merge node pairs which src node has a tensor element inside.
# This is necessay because the node without a tensor element inside will not # This is necessary because the node without a tensor element inside will not
# be assigned any strategy. # be assigned any strategy.
if _check_tensor_in_node(followed_node._meta_data): if _check_tensor_in_node(followed_node._meta_data):
self.merge_pair.append((followed_node, dst_node)) self.merge_pair.append((followed_node, dst_node))
...@@ -99,7 +99,7 @@ class CostGraph: ...@@ -99,7 +99,7 @@ class CostGraph:
return self.edge_costs[(src_node, dst_node)] return self.edge_costs[(src_node, dst_node)]
def merge_node(self, src_node, dst_node): def merge_node(self, src_node, dst_node):
''' """
To merge dst_node into src_node, we need to do it in following steps: To merge dst_node into src_node, we need to do it in following steps:
1. For each strategy in dst_node, we need to pick an appropriate strategy 1. For each strategy in dst_node, we need to pick an appropriate strategy
...@@ -119,7 +119,7 @@ class CostGraph: ...@@ -119,7 +119,7 @@ class CostGraph:
Argument: Argument:
src_node(Node): The node will be merged into dst_node. src_node(Node): The node will be merged into dst_node.
dst_node(Node): The node to integrate src_node. dst_node(Node): The node to integrate src_node.
''' """
# build merge_map # build merge_map
merge_map = {} merge_map = {}
for src_index, _ in enumerate(src_node.strategies_vector): for src_index, _ in enumerate(src_node.strategies_vector):
...@@ -196,7 +196,7 @@ class CostGraph: ...@@ -196,7 +196,7 @@ class CostGraph:
if not self.simplify: if not self.simplify:
return return
self.merge_pair.reverse() self.merge_pair.reverse()
for (src_node, dst_node) in self.merge_pair: for src_node, dst_node in self.merge_pair:
self.merge_node(src_node, dst_node) self.merge_node(src_node, dst_node)
self.merge_pair.reverse() self.merge_pair.reverse()
reindexing_following_dict = {} reindexing_following_dict = {}
......
...@@ -7,7 +7,7 @@ from torch.fx.node import Node ...@@ -7,7 +7,7 @@ from torch.fx.node import Node
from colossalai.fx.passes.utils import get_node_module from colossalai.fx.passes.utils import get_node_module
__all__ = ['LiveVariable', 'LiveVariableVector', 'LiveStage', 'GraphAnalyser'] __all__ = ["LiveVariable", "LiveVariableVector", "LiveStage", "GraphAnalyser"]
@dataclass @dataclass
...@@ -15,6 +15,7 @@ class LiveVariable: ...@@ -15,6 +15,7 @@ class LiveVariable:
""" """
LiveVariable is a data structure to store the meta information of a variable for liveness analysis. LiveVariable is a data structure to store the meta information of a variable for liveness analysis.
""" """
name: str name: str
node: Node node: Node
is_inplace: bool is_inplace: bool
...@@ -55,6 +56,7 @@ class LiveStage: ...@@ -55,6 +56,7 @@ class LiveStage:
""" """
LiveStage is a data structure to record the living variables at this current node. LiveStage is a data structure to record the living variables at this current node.
""" """
name: str name: str
node: Node node: Node
all_live_vars: LiveVariableVector all_live_vars: LiveVariableVector
...@@ -62,7 +64,6 @@ class LiveStage: ...@@ -62,7 +64,6 @@ class LiveStage:
class GraphAnalyser: class GraphAnalyser:
def __init__(self, gm: GraphModule): def __init__(self, gm: GraphModule):
self._gm = gm self._gm = gm
self._graph = gm.graph self._graph = gm.graph
...@@ -83,7 +84,7 @@ class GraphAnalyser: ...@@ -83,7 +84,7 @@ class GraphAnalyser:
def liveness_analysis(self) -> List[LiveStage]: def liveness_analysis(self) -> List[LiveStage]:
""" """
Analyse the graph to obtain the variable liveness information. This function returns Analyses the graph to obtain the variable liveness information. This function returns
an ordered dictionary where the key is the compute stage ID and the value is a LivenessStage object. an ordered dictionary where the key is the compute stage ID and the value is a LivenessStage object.
""" """
compute_nodes = self.graph.nodes compute_nodes = self.graph.nodes
...@@ -91,7 +92,7 @@ class GraphAnalyser: ...@@ -91,7 +92,7 @@ class GraphAnalyser:
# checked: record all variables created since the first stage # checked: record all variables created since the first stage
# all: record the live variables only exist until the current stage. # all: record the live variables only exist until the current stage.
# this can be different from the `checked list`` as some varialbes may be destroyed prior to this stage. # this can be different from the `checked list`` as some variables may be destroyed prior to this stage.
# unique: record the unique live variables only exist until the current stage. # unique: record the unique live variables only exist until the current stage.
# this is different from `all list` as some variables are duplicated. # this is different from `all list` as some variables are duplicated.
checked_variables = LiveVariableVector() checked_variables = LiveVariableVector()
...@@ -103,20 +104,20 @@ class GraphAnalyser: ...@@ -103,20 +104,20 @@ class GraphAnalyser:
# find new living variables # # find new living variables #
############################# #############################
# detect whether the current op is an in-place op # detect whether the current op is an in-place op
# if it is an in-place op, we would deem it as a duplciate var # if it is an in-place op, we would deem it as a duplicate var
is_inplace = False is_inplace = False
if node.op == 'call_function': if node.op == "call_function":
# check if this is an inplace op such as torch.nn.functional.relu(x, inplace=True) # check if this is an inplace op such as torch.nn.functional.relu(x, inplace=True)
if node.kwargs.get('inplace', False): if node.kwargs.get("inplace", False):
is_inplace = True is_inplace = True
elif node.op == 'call_module': elif node.op == "call_module":
# to check if this is an inplace op such as torch.nn.Relu(inplace=True) # to check if this is an inplace op such as torch.nn.Relu(inplace=True)
module = get_node_module(node) module = get_node_module(node)
if getattr(module, 'inplace', False): if getattr(module, "inplace", False):
is_inplace = True is_inplace = True
# add the output var # add the output var
meta = getattr(node, '_meta_data', None) getattr(node, "_meta_data", None)
live_var = LiveVariable(name=node.name, node=node, is_inplace=is_inplace) live_var = LiveVariable(name=node.name, node=node, is_inplace=is_inplace)
if not is_inplace: if not is_inplace:
unique_live_vars.append(live_var) unique_live_vars.append(live_var)
...@@ -138,10 +139,12 @@ class GraphAnalyser: ...@@ -138,10 +139,12 @@ class GraphAnalyser:
# this should be completed if we are able to trace the backward compute graph # this should be completed if we are able to trace the backward compute graph
# add this stage to liveness dict # add this stage to liveness dict
stage = LiveStage(name=node.name, stage = LiveStage(
node=node, name=node.name,
all_live_vars=all_live_variables.copy(), node=node,
unique_live_vars=unique_live_vars.copy()) all_live_vars=all_live_variables.copy(),
unique_live_vars=unique_live_vars.copy(),
)
# if a LiveStage is covered by another LiveStage, we just keep the larger one. # if a LiveStage is covered by another LiveStage, we just keep the larger one.
replace = False replace = False
for index, prev_stage in enumerate(liveness_list): for index, prev_stage in enumerate(liveness_list):
......
...@@ -21,34 +21,35 @@ try: ...@@ -21,34 +21,35 @@ try:
import pulp import pulp
from pulp import LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum from pulp import LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum
except: except:
warnings.warn(f'please install the pulp') warnings.warn(f"please install the pulp")
__all___ = ['Solver'] __all___ = ["Solver"]
class Solver: class Solver:
def __init__(
def __init__(self, self,
graph: Graph, graph: Graph,
strategies_constructor: StrategiesConstructor, strategies_constructor: StrategiesConstructor,
cost_graph: CostGraph, cost_graph: CostGraph,
graph_analyser: GraphAnalyser = None, graph_analyser: GraphAnalyser = None,
memory_budget: float = -1.0, memory_budget: float = -1.0,
solution_numbers: int = 1, solution_numbers: int = 1,
forward_only: bool = False, forward_only: bool = False,
memory_increasing_coefficient: float = 1.3, memory_increasing_coefficient: float = 1.3,
verbose=False): verbose=False,
''' ):
"""
Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph. Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph.
Argument: Argument:
graph: The computing graph to be optimized. graph: The computing graph to be optimized.
strategies_constructor: It will provide all the possible strategies for each node in the computing graph. strategies_constructor: It will provide all the possible strategies for each node in the computing graph.
cost_graph: A graph data structure to simplify the edge cost graph. cost_graph: A graph data structure to simplify the edge cost graph.
graph_analyser: graph_analyser will analyse the graph to obtain the variable liveness information, which will be used to generate memory constraints. graph_analyser: graph_analyser will analyse the graph to obtain the variable liveness information, which will be used to generate memory constraints.
memory_budget: Memory constraint for the solution. memory_budget: Memory constraint for the solution.
solution_numbers: If solution_numbers is larger than one, solver will use a series of solutions based on different memory budgets. solution_numbers: If solution_numbers is larger than one, solver will use a series of solutions based on different memory budgets.
memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate new memory budget. memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate new memory budget.
''' """
self.graph = graph self.graph = graph
self.strategies_constructor = strategies_constructor self.strategies_constructor = strategies_constructor
self.cost_graph = cost_graph self.cost_graph = cost_graph
...@@ -75,11 +76,11 @@ class Solver: ...@@ -75,11 +76,11 @@ class Solver:
self.verbose = verbose self.verbose = verbose
def _recover_merged_node_strategy(self): def _recover_merged_node_strategy(self):
''' """
During cost graph constructing, some nodes, such as unary element-wise node or ReshapeOp, were merged into the previous node. During cost graph constructing, some nodes, such as unary element-wise node or ReshapeOp, were merged into the previous node.
Therefore, the index of those strategies are copied from the previous node. This method is used to recover the strategy index of those merged Therefore, the index of those strategies are copied from the previous node. This method is used to recover the strategy index of those merged
node. node.
''' """
for node_index, node in enumerate(self.nodes): for node_index, node in enumerate(self.nodes):
if node.strategies_vector.check_merge(): if node.strategies_vector.check_merge():
# the merged node has only one input, and its strategies follow the input sharding strategy # the merged node has only one input, and its strategies follow the input sharding strategy
...@@ -98,9 +99,9 @@ class Solver: ...@@ -98,9 +99,9 @@ class Solver:
return node_index_dict return node_index_dict
def _prepare_data_for_solver(self): def _prepare_data_for_solver(self):
''' """
Extract information from components for solver. Extract information from components for solver.
''' """
node_nums = len(self.leaf_strategies) node_nums = len(self.leaf_strategies)
memory_budget = self.memory_budget memory_budget = self.memory_budget
...@@ -190,23 +191,40 @@ class Solver: ...@@ -190,23 +191,40 @@ class Solver:
# omit initial value for nodes # omit initial value for nodes
s_init_np = None s_init_np = None
return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np, self.verbose return (
node_nums,
def _call_solver_serialized_args(self, memory_budget,
node_nums, strategies_len,
memory_budget, following_nodes,
strategies_len, edge_pairs,
following_nodes, alias_set,
edge_pairs, liveness_set,
alias_set, compute_costs,
liveness_set, communication_costs,
compute_costs, memory_costs,
communication_costs, resharding_costs,
memory_costs, alias_convert_costs,
resharding_costs, s_init_np,
alias_convert_costs, self.verbose,
s_init_np=None, )
verbose=True):
def _call_solver_serialized_args(
self,
node_nums,
memory_budget,
strategies_len,
following_nodes,
edge_pairs,
alias_set,
liveness_set,
compute_costs,
communication_costs,
memory_costs,
resharding_costs,
alias_convert_costs,
s_init_np=None,
verbose=True,
):
""" """
Call the solver with serialized arguments. Call the solver with serialized arguments.
""" """
...@@ -235,18 +253,18 @@ class Solver: ...@@ -235,18 +253,18 @@ class Solver:
s_follow = following_nodes s_follow = following_nodes
s_alias = alias_set s_alias = alias_set
E = edge_pairs.reshape((-1, 2)) # noqa E = edge_pairs.reshape((-1, 2)) # noqa
r = [] r = []
pt = 0 pt = 0
edge_set = set() edge_set = set()
for (i, j) in E: for i, j in E:
prod_length = strategies_len[i] * strategies_len[j] prod_length = strategies_len[i] * strategies_len[j]
if (i, j) in edge_set: if (i, j) in edge_set:
raise ValueError(f"Duplicated edges: {(i, j)}") raise ValueError(f"Duplicated edges: {(i, j)}")
edge_set.add((i, j)) edge_set.add((i, j))
r.append(resharding_costs[pt:pt + prod_length]) r.append(resharding_costs[pt : pt + prod_length])
pt += prod_length pt += prod_length
assert pt == len(resharding_costs) assert pt == len(resharding_costs)
...@@ -268,7 +286,6 @@ class Solver: ...@@ -268,7 +286,6 @@ class Solver:
# L.append(liveness_set[pt:pt + length]) # L.append(liveness_set[pt:pt + length])
# pt += length # pt += length
# assert pt == len(liveness_set) # assert pt == len(liveness_set)
v = []
pt = 0 pt = 0
c = [] c = []
...@@ -277,9 +294,9 @@ class Solver: ...@@ -277,9 +294,9 @@ class Solver:
pt = 0 pt = 0
for i in range(node_nums): for i in range(node_nums):
length = strategies_len[i] length = strategies_len[i]
c.append(compute_costs[pt:pt + length]) c.append(compute_costs[pt : pt + length])
d.append(communication_costs[pt:pt + length]) d.append(communication_costs[pt : pt + length])
m.append(memory_costs[pt:pt + length]) m.append(memory_costs[pt : pt + length])
pt += length pt += length
assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}" assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}"
assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}" assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}"
...@@ -319,7 +336,7 @@ class Solver: ...@@ -319,7 +336,7 @@ class Solver:
e = [] e = []
num_edges = 0 num_edges = 0
map_edge_to_idx = {} map_edge_to_idx = {}
for (idx, (i, j)) in enumerate(E): for idx, (i, j) in enumerate(E):
if len(s[i]) == 1: if len(s[i]) == 1:
e.append(s[j]) e.append(s[j])
elif len(s[j]) == 1: elif len(s[j]) == 1:
...@@ -340,7 +357,7 @@ class Solver: ...@@ -340,7 +357,7 @@ class Solver:
###################################### ######################################
if s_init_np is not None: if s_init_np is not None:
s_init = s_init_np.reshape((-1, 3)) s_init = s_init_np.reshape((-1, 3))
for (idx, value, fix) in s_init: for idx, value, fix in s_init:
for i in range(len(s[idx])): for i in range(len(s[idx])):
s[idx][i].setInitialValue(i == value) s[idx][i].setInitialValue(i == value)
if fix: if fix:
...@@ -393,7 +410,7 @@ class Solver: ...@@ -393,7 +410,7 @@ class Solver:
# (d). specified by `cat="Binary"` # (d). specified by `cat="Binary"`
for (idx, (i, j)) in enumerate(E): for idx, (i, j) in enumerate(E):
if strategies_len[i] == 1 or strategies_len[j] == 1: if strategies_len[i] == 1 or strategies_len[j] == 1:
continue continue
...@@ -402,13 +419,13 @@ class Solver: ...@@ -402,13 +419,13 @@ class Solver:
# (f) # (f)
for row in range(len(s[i])): for row in range(len(s[i])):
C = len(s[j]) # noqa C = len(s[j]) # noqa
prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row] prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row]
# (g) # (g)
for col in range(len(s[j])): for col in range(len(s[j])):
R = len(s[i]) # noqa R = len(s[i]) # noqa
C = len(s[j]) # noqa C = len(s[j]) # noqa
prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col] prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col]
# (h) # (h)
...@@ -434,7 +451,8 @@ class Solver: ...@@ -434,7 +451,8 @@ class Solver:
msg = verbose msg = verbose
time_limit = 600 time_limit = 600
assert "COIN_CMD" in pulp.listSolvers( assert "COIN_CMD" in pulp.listSolvers(
onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'") onlyAvailable=True
), "Please install ILP solvers by 'sudo apt install coinor-cbc'"
solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count()) solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count())
# solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit) # solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit)
...@@ -444,13 +462,13 @@ class Solver: ...@@ -444,13 +462,13 @@ class Solver:
objective = pulp.value(prob.objective) objective = pulp.value(prob.objective)
objective = float(objective) if objective is not None else -1.0 objective = float(objective) if objective is not None else -1.0
if verbose: if verbose:
print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t" print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t" f"Time: {time.time() - tic}")
f"Time: {time.time() - tic}")
print(f"#nodes: {num_nodes}, #edges: {num_edges}") print(f"#nodes: {num_nodes}, #edges: {num_edges}")
if prob.status in [pulp.LpStatusInfeasible]: if prob.status in [pulp.LpStatusInfeasible]:
raise RuntimeError("Cannot run the function under the given memory budget. " raise RuntimeError(
"Please increase the memory budget.") "Cannot run the function under the given memory budget. " "Please increase the memory budget."
)
# Get and check results # Get and check results
s_val = np.full((node_nums,), -1, dtype=np.int32) s_val = np.full((node_nums,), -1, dtype=np.int32)
...@@ -458,7 +476,7 @@ class Solver: ...@@ -458,7 +476,7 @@ class Solver:
s_val[i] = get_non_zero_index(s[i]) s_val[i] = get_non_zero_index(s[i])
e_val = np.full((len(E),), -1, dtype=np.int32) e_val = np.full((len(E),), -1, dtype=np.int32)
for (idx, (i, j)) in enumerate(E): for idx, (i, j) in enumerate(E):
e_val[idx] = get_non_zero_index(e[idx]) e_val[idx] = get_non_zero_index(e[idx])
i_spec_index = e_val[idx] // len(s[j]) i_spec_index = e_val[idx] // len(s[j])
j_spec_index = e_val[idx] % len(s[j]) j_spec_index = e_val[idx] % len(s[j])
......
import builtins
import math
import operator
from copy import deepcopy
from typing import Dict, List
import torch import torch
from torch.fx import Graph, Node from torch.fx import Graph
from colossalai.auto_parallel.tensor_shard.node_handler import ( from colossalai.auto_parallel.tensor_shard.node_handler import (
GetattrHandler, GetattrHandler,
...@@ -14,13 +8,12 @@ from colossalai.auto_parallel.tensor_shard.node_handler import ( ...@@ -14,13 +8,12 @@ from colossalai.auto_parallel.tensor_shard.node_handler import (
operator_registry, operator_registry,
) )
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from ..options import DataloaderOption, SolverOptions from ..options import DataloaderOption, SolverOptions
__all__ = ['StrategiesConstructor'] __all__ = ["StrategiesConstructor"]
class StrategiesConstructor: class StrategiesConstructor:
...@@ -35,7 +28,7 @@ class StrategiesConstructor: ...@@ -35,7 +28,7 @@ class StrategiesConstructor:
def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions): def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions):
self.graph = graph self.graph = graph
assert graph.owning_module is not None, 'The given graph is not associated with a owning_module' assert graph.owning_module is not None, "The given graph is not associated with a owning_module"
self.root_module = self.graph.owning_module self.root_module = self.graph.owning_module
self.nodes = list(graph.nodes) self.nodes = list(graph.nodes)
self.device_mesh = device_mesh self.device_mesh = device_mesh
...@@ -46,11 +39,11 @@ class StrategiesConstructor: ...@@ -46,11 +39,11 @@ class StrategiesConstructor:
self.alias_set = None self.alias_set = None
def remove_duplicated_strategy(self, strategies_vector): def remove_duplicated_strategy(self, strategies_vector):
''' """
In build_strategies_and_cost method, we may produce some duplicated strategies. In build_strategies_and_cost method, we may produce some duplicated strategies.
In this method, we will remove the duplicated strategies depending on the strategies name. In this method, we will remove the duplicated strategies depending on the strategies name.
Note that this operation is in-place. Note that this operation is in-place.
''' """
name_checklist = [] name_checklist = []
remove_list = [] remove_list = []
for strategy in strategies_vector: for strategy in strategies_vector:
...@@ -62,7 +55,6 @@ class StrategiesConstructor: ...@@ -62,7 +55,6 @@ class StrategiesConstructor:
strategies_vector.remove(strategy) strategies_vector.remove(strategy)
def generate_alias_set(self): def generate_alias_set(self):
node_list = [strategy_vector.node for strategy_vector in self.leaf_strategies] node_list = [strategy_vector.node for strategy_vector in self.leaf_strategies]
common_blocks = find_repeat_blocks(node_list, self.root_module, common_length_threshold=10) common_blocks = find_repeat_blocks(node_list, self.root_module, common_length_threshold=10)
...@@ -83,7 +75,7 @@ class StrategiesConstructor: ...@@ -83,7 +75,7 @@ class StrategiesConstructor:
""" """
def _check_no_strategy_for_node(node): def _check_no_strategy_for_node(node):
if node.op in ('placeholder', 'get_attr', 'output'): if node.op in ("placeholder", "get_attr", "output"):
return False return False
def _check_no_strategy_for_data(data): def _check_no_strategy_for_data(data):
...@@ -102,83 +94,93 @@ class StrategiesConstructor: ...@@ -102,83 +94,93 @@ class StrategiesConstructor:
if _check_no_strategy_for_node(node): if _check_no_strategy_for_node(node):
self.no_strategy_nodes.append(node) self.no_strategy_nodes.append(node)
pass
# placeholder node # placeholder node
elif node.op == 'placeholder': elif node.op == "placeholder":
if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED: if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED:
placeholder_option = 'distributed' placeholder_option = "distributed"
else: else:
assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported' assert (
placeholder_option = 'replicated' self.solver_options.dataloader_option == DataloaderOption.REPLICATED
placeholder_handler = PlaceholderHandler(node, ), f"placeholder_option {self.solver_options.dataloader_option} is not supported"
self.device_mesh, placeholder_option = "replicated"
strategies_vector, placeholder_handler = PlaceholderHandler(
placeholder_option=placeholder_option) node, self.device_mesh, strategies_vector, placeholder_option=placeholder_option
)
placeholder_handler.register_strategy() placeholder_handler.register_strategy()
# get_attr node # get_attr node
elif node.op == 'get_attr': elif node.op == "get_attr":
getattr_handler = GetattrHandler(node, getattr_handler = GetattrHandler(
self.device_mesh, node,
strategies_vector, self.device_mesh,
shard_option=self.solver_options.shard_option, strategies_vector,
solver_perference=self.solver_options.solver_perference) shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference,
)
getattr_handler.register_strategy() getattr_handler.register_strategy()
# call_module node # call_module node
elif node.op == 'call_module': elif node.op == "call_module":
target = node.target target = node.target
submod = self.root_module.get_submodule(target) submod = self.root_module.get_submodule(target)
submod_type = type(submod) submod_type = type(submod)
handler = operator_registry.get(submod_type)(node, handler = operator_registry.get(submod_type)(
self.device_mesh, node,
strategies_vector, self.device_mesh,
shard_option=self.solver_options.shard_option, strategies_vector,
solver_perference=self.solver_options.solver_perference) shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference,
)
handler.register_strategy() handler.register_strategy()
# attach strategies_info to node # attach strategies_info to node
if hasattr(handler, 'strategies_info'): if hasattr(handler, "strategies_info"):
setattr(node, 'strategies_info', handler.strategies_info) setattr(node, "strategies_info", handler.strategies_info)
# call_function node # call_function node
elif node.op == 'call_function': elif node.op == "call_function":
target = node.target target = node.target
handler = operator_registry.get(target)(node, handler = operator_registry.get(target)(
self.device_mesh, node,
strategies_vector, self.device_mesh,
shard_option=self.solver_options.shard_option, strategies_vector,
solver_perference=self.solver_options.solver_perference) shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference,
)
handler.register_strategy() handler.register_strategy()
# attach strategies_info to node # attach strategies_info to node
if hasattr(handler, 'strategies_info'): if hasattr(handler, "strategies_info"):
setattr(node, 'strategies_info', handler.strategies_info) setattr(node, "strategies_info", handler.strategies_info)
# call_method node # call_method node
elif node.op == 'call_method': elif node.op == "call_method":
method = getattr(node.args[0]._meta_data.__class__, node.target) method = getattr(node.args[0]._meta_data.__class__, node.target)
handler = operator_registry.get(method)(node, handler = operator_registry.get(method)(
self.device_mesh, node,
strategies_vector, self.device_mesh,
shard_option=self.solver_options.shard_option, strategies_vector,
solver_perference=self.solver_options.solver_perference) shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference,
)
handler.register_strategy() handler.register_strategy()
# attach strategies_info to node # attach strategies_info to node
if hasattr(handler, 'strategies_info'): if hasattr(handler, "strategies_info"):
setattr(node, 'strategies_info', handler.strategies_info) setattr(node, "strategies_info", handler.strategies_info)
# output node # output node
elif node.op == 'output': elif node.op == "output":
if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED: if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED:
output_option = 'distributed' output_option = "distributed"
else: else:
assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported' assert (
output_option = 'replicated' self.solver_options.dataloader_option == DataloaderOption.REPLICATED
), f"placeholder_option {self.solver_options.dataloader_option} is not supported"
output_option = "replicated"
output_handler = OutputHandler(node, self.device_mesh, strategies_vector, output_option=output_option) output_handler = OutputHandler(node, self.device_mesh, strategies_vector, output_option=output_option)
output_handler.register_strategy() output_handler.register_strategy()
self.remove_duplicated_strategy(strategies_vector) self.remove_duplicated_strategy(strategies_vector)
setattr(node, 'strategies_vector', strategies_vector) setattr(node, "strategies_vector", strategies_vector)
self.leaf_strategies.append(strategies_vector) self.leaf_strategies.append(strategies_vector)
self.strategy_map[node] = strategies_vector self.strategy_map[node] = strategies_vector
......
...@@ -17,9 +17,21 @@ from .sharding import ( ...@@ -17,9 +17,21 @@ from .sharding import (
) )
__all__ = [ __all__ = [
'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape', "BroadcastType",
'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity' "get_broadcast_shape",
'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding', "is_broadcastable",
'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands', 'pytree_map', "recover_sharding_spec_for_broadcast_shape",
'detect_reshape_mapping', 'check_keep_sharding_status', 'infer_output_dim_partition_dict' "generate_resharding_costs",
"generate_sharding_spec",
"ignore_sharding_exception",
"check_sharding_spec_validity" "transpose_partition_dim",
"update_partition_dim",
"enumerate_all_possible_1d_sharding",
"enumerate_all_possible_2d_sharding",
"generate_sharding_size",
"comm_actions_for_oprands",
"pytree_map",
"detect_reshape_mapping",
"check_keep_sharding_status",
"infer_output_dim_partition_dict",
] ]
...@@ -14,14 +14,17 @@ from colossalai.tensor.comm_spec import CollectiveCommPattern, CommSpec ...@@ -14,14 +14,17 @@ from colossalai.tensor.comm_spec import CollectiveCommPattern, CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = [ __all__ = [
'BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape', "BroadcastType",
'comm_actions_for_oprands' "is_broadcastable",
"get_broadcast_shape",
"recover_sharding_spec_for_broadcast_shape",
"comm_actions_for_oprands",
] ]
class BroadcastType(Enum): class BroadcastType(Enum):
EQUAL = auto() EQUAL = auto()
PADDDING = auto() PADDING = auto()
MULTIPLE = auto() MULTIPLE = auto()
...@@ -41,7 +44,7 @@ def get_broadcast_shape(shape1: torch.Size, shape2: torch.Size) -> List[int]: ...@@ -41,7 +44,7 @@ def get_broadcast_shape(shape1: torch.Size, shape2: torch.Size) -> List[int]:
""" """
Compute the broadcast shape given two shapes. Compute the broadcast shape given two shapes.
""" """
assert is_broadcastable(shape1, shape2), f'{shape1} and {shape2} are not broadcastable' assert is_broadcastable(shape1, shape2), f"{shape1} and {shape2} are not broadcastable"
shape1_reverse = shape1[::-1] shape1_reverse = shape1[::-1]
shape2_reverse = shape2[::-1] shape2_reverse = shape2[::-1]
min_common_dim = min(len(shape1), len(shape2)) min_common_dim = min(len(shape1), len(shape2))
...@@ -60,8 +63,9 @@ def get_broadcast_dim_info(logical_shape, physical_shape): ...@@ -60,8 +63,9 @@ def get_broadcast_dim_info(logical_shape, physical_shape):
logical_num_dims = len(logical_shape) logical_num_dims = len(logical_shape)
physical_num_dims = len(physical_shape) physical_num_dims = len(physical_shape)
assert logical_num_dims >= physical_num_dims, \ assert (
'The number of dimensions in the logical shape is smaller than that of the physical shape, this tensor is not broadcast!' logical_num_dims >= physical_num_dims
), "The number of dimensions in the logical shape is smaller than that of the physical shape, this tensor is not broadcast!"
# track the dim and its broadcasting type # track the dim and its broadcasting type
logical_dim_broadcast_info = {} logical_dim_broadcast_info = {}
...@@ -69,24 +73,25 @@ def get_broadcast_dim_info(logical_shape, physical_shape): ...@@ -69,24 +73,25 @@ def get_broadcast_dim_info(logical_shape, physical_shape):
for i in range(logical_num_dims): for i in range(logical_num_dims):
# get the trailing dim size # get the trailing dim size
logical_dim_idx = logical_num_dims - i - 1 logical_dim_idx = logical_num_dims - i - 1
phyiscal_dim_idx = physical_num_dims - i - 1 physical_dim_idx = physical_num_dims - i - 1
logical_dim_size = logical_shape[logical_dim_idx] logical_dim_size = logical_shape[logical_dim_idx]
if phyiscal_dim_idx >= 0: if physical_dim_idx >= 0:
physical_dim_size = physical_shape[phyiscal_dim_idx] physical_dim_size = physical_shape[physical_dim_idx]
if physical_dim_size == logical_dim_size: if physical_dim_size == logical_dim_size:
logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.EQUAL logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.EQUAL
elif physical_dim_size == 1 and physical_dim_size != logical_dim_size: elif physical_dim_size == 1 and physical_dim_size != logical_dim_size:
logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.MULTIPLE logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.MULTIPLE
else: else:
logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.PADDDING logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.PADDING
return logical_dim_broadcast_info return logical_dim_broadcast_info
def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size, def recover_sharding_spec_for_broadcast_shape(
physical_shape: torch.Size) -> ShardingSpec: logical_sharding_spec: ShardingSpec, logical_shape: torch.Size, physical_shape: torch.Size
) -> ShardingSpec:
""" """
This function computes the sharding spec for the physical shape of a broadcast tensor. This function computes the sharding spec for the physical shape of a broadcast tensor.
...@@ -117,22 +122,25 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe ...@@ -117,22 +122,25 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe
for shape_dim, mesh_dim in logical_dim_partition.items(): for shape_dim, mesh_dim in logical_dim_partition.items():
logical_broadcast_type = logical_dim_broadcast_info[shape_dim] logical_broadcast_type = logical_dim_broadcast_info[shape_dim]
if logical_broadcast_type == BroadcastType.PADDDING or logical_broadcast_type == BroadcastType.MULTIPLE: if logical_broadcast_type == BroadcastType.PADDING or logical_broadcast_type == BroadcastType.MULTIPLE:
removed_dims.extend(mesh_dim) removed_dims.extend(mesh_dim)
else: else:
# get the corresponding physical dim # get the corresponding physical dim
physical_dim = physical_num_dims - (logical_num_dims - shape_dim) physical_dim = physical_num_dims - (logical_num_dims - shape_dim)
physical_dim_partition[physical_dim] = mesh_dim physical_dim_partition[physical_dim] = mesh_dim
physical_sharding_spec = ShardingSpec(device_mesh=logical_sharding_spec.device_mesh, physical_sharding_spec = ShardingSpec(
entire_shape=physical_shape, device_mesh=logical_sharding_spec.device_mesh,
dim_partition_dict=physical_dim_partition) entire_shape=physical_shape,
dim_partition_dict=physical_dim_partition,
)
return physical_sharding_spec, removed_dims return physical_sharding_spec, removed_dims
def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: OperationData, def comm_actions_for_oprands(
sharding_spec: ShardingSpec) -> CommAction: node: Node, removed_dims: List[int], op_data: OperationData, sharding_spec: ShardingSpec
) -> CommAction:
""" """
This method is used to generate communication actions for operands which lose information This method is used to generate communication actions for operands which lose information
during the conversion from logical shape to physical shape. during the conversion from logical shape to physical shape.
...@@ -140,9 +148,11 @@ def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: Opera ...@@ -140,9 +148,11 @@ def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: Opera
if len(removed_dims) == 1: if len(removed_dims) == 1:
# if list length is 1, extract element from list to avoid using flatten device mesh # if list length is 1, extract element from list to avoid using flatten device mesh
removed_dims = removed_dims[0] removed_dims = removed_dims[0]
comm_spec = CommSpec(comm_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, comm_spec = CommSpec(
sharding_spec=sharding_spec, comm_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=removed_dims) sharding_spec=sharding_spec,
logical_process_axis=removed_dims,
)
if op_data.type == OperationDataType.PARAM: if op_data.type == OperationDataType.PARAM:
comm_type = CommType.HOOK comm_type = CommType.HOOK
else: else:
...@@ -151,7 +161,7 @@ def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: Opera ...@@ -151,7 +161,7 @@ def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: Opera
for index, arg in enumerate(node.args): for index, arg in enumerate(node.args):
if op_data.name == str(arg): if op_data.name == str(arg):
arg_index = index arg_index = index
assert arg_index >= 0, f'op_data should be an argument of node.' assert arg_index >= 0, f"op_data should be an argument of node."
comm_action = CommAction( comm_action = CommAction(
comm_spec=comm_spec, comm_spec=comm_spec,
comm_type=comm_type, comm_type=comm_type,
......
...@@ -14,11 +14,12 @@ from colossalai.tensor.sharding_spec import ShardingSpec ...@@ -14,11 +14,12 @@ from colossalai.tensor.sharding_spec import ShardingSpec
from ..constants import INFINITY_COST from ..constants import INFINITY_COST
__all__ = ['generate_sharding_spec', 'generate_resharding_costs'] __all__ = ["generate_sharding_spec", "generate_resharding_costs"]
def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh, def generate_sharding_spec(
dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec: input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh, dim_partition_dict: Dict[int, List[int]]
) -> ShardingSpec:
""" """
Generate the sharding spec of the tensor based on the given dim_partition_dict. Generate the sharding spec of the tensor based on the given dim_partition_dict.
...@@ -30,7 +31,7 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic ...@@ -30,7 +31,7 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic
""" """
if isinstance(input_, Node): if isinstance(input_, Node):
assert hasattr(input_, '_meta_data'), f'The given node has no attribte _meta_data' assert hasattr(input_, "_meta_data"), f"The given node has no attribute _meta_data"
meta_tensor = input_._meta_data meta_tensor = input_._meta_data
assert meta_tensor is not None, "The given node's _meta_data attribute is None" assert meta_tensor is not None, "The given node's _meta_data attribute is None"
shape = meta_tensor.shape shape = meta_tensor.shape
...@@ -38,24 +39,27 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic ...@@ -38,24 +39,27 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic
shape = input_.shape shape = input_.shape
else: else:
raise TypeError( raise TypeError(
f'We cannot generate sharding spec for {type(input_)} type, only torch.fx.Node or torch.Tensor is expected.' f"We cannot generate sharding spec for {type(input_)} type, only torch.fx.Node or torch.Tensor is expected."
) )
for dim_index, sharding_index_list in dim_partition_dict.items(): for dim_index, sharding_index_list in dim_partition_dict.items():
sharding_list = [device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list] sharding_list = [device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list]
sharding_size = reduce(operator.mul, sharding_list, 1) sharding_size = reduce(operator.mul, sharding_list, 1)
assert shape[ assert (
dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.' shape[dim_index] % sharding_size == 0
), f"we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions."
sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict) sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict)
return sharding_spec return sharding_spec
def generate_resharding_costs(nodes: List[Node], def generate_resharding_costs(
sharding_specs: List[ShardingSpec], nodes: List[Node],
count_backward: Optional[bool] = True, sharding_specs: List[ShardingSpec],
dtype: Optional[torch.dtype] = None, count_backward: Optional[bool] = True,
index=None): dtype: Optional[torch.dtype] = None,
''' index=None,
):
"""
Compute the resharding costs with this specific strategy. Compute the resharding costs with this specific strategy.
Argument: Argument:
...@@ -63,7 +67,7 @@ def generate_resharding_costs(nodes: List[Node], ...@@ -63,7 +67,7 @@ def generate_resharding_costs(nodes: List[Node],
sharding_spec_for_input(ShardingSpec): a list of ShardingSpec for the nodes. sharding_spec_for_input(ShardingSpec): a list of ShardingSpec for the nodes.
count_backward (Optional[bool]): whether to include the cost of resharding in the backward pass, default is True. False can be used for inference. count_backward (Optional[bool]): whether to include the cost of resharding in the backward pass, default is True. False can be used for inference.
dtype (Optional[torch.dtype]): the data type for cost calculation, default is None. dtype (Optional[torch.dtype]): the data type for cost calculation, default is None.
''' """
# The resharding_cost of weight is counted due to sharing weight cases. # The resharding_cost of weight is counted due to sharing weight cases.
resharding_costs = {} resharding_costs = {}
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
...@@ -76,38 +80,39 @@ def generate_resharding_costs(nodes: List[Node], ...@@ -76,38 +80,39 @@ def generate_resharding_costs(nodes: List[Node],
for strategy in input_node.strategies_vector: for strategy in input_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec input_sharding_spec = strategy.output_sharding_spec
if not isinstance(input_sharding_spec, ShardingSpec): if not isinstance(input_sharding_spec, ShardingSpec):
assert isinstance(input_sharding_spec, list), 'only ShardingSpec or List[ShardingSpec] is expected.' assert isinstance(input_sharding_spec, list), "only ShardingSpec or List[ShardingSpec] is expected."
input_sharding_spec = input_sharding_spec[index] input_sharding_spec = input_sharding_spec[index]
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.' assert isinstance(input_sharding_spec, ShardingSpec), f"The input node should NOT be a tuple of tensor."
try: try:
# compute the resharding cost # compute the resharding cost
_, _, total_resharding_cost = shape_consistency_manager.shape_consistency( _, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
input_sharding_spec, input_spec) input_sharding_spec, input_spec
)
# we need multiply the size of elem dtype to get correct communication cost # we need multiply the size of elem dtype to get correct communication cost
resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes
except AssertionError as e: except AssertionError as e:
warnings.warn(f'{e}') warnings.warn(f"{e}")
resharding_cost = INFINITY_COST resharding_cost = INFINITY_COST
resharding_costs[input_node].append(resharding_cost) resharding_costs[input_node].append(resharding_cost)
return resharding_costs return resharding_costs
def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_length_threshold: int = 20): def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_length_threshold: int = 20):
''' """
Find the largest repeat blocks in the graph, whose length is larger than the threshold. Find the largest repeat blocks in the graph, whose length is larger than the threshold.
Args: Args:
gm (GraphModule): the graph module to be analyzed. gm (GraphModule): the graph module to be analyzed.
common_length_threshold (int): the threshold of the repeat block length. common_length_threshold (int): the threshold of the repeat block length.
''' """
# graph = gm.graph # graph = gm.graph
def _process_args(args): def _process_args(args):
new_args = [] new_args = []
for arg in args: for arg in args:
if hasattr(arg, '_meta_data'): if hasattr(arg, "_meta_data"):
meta_data = arg._meta_data meta_data = arg._meta_data
else: else:
meta_data = arg meta_data = arg
...@@ -145,7 +150,7 @@ def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_lengt ...@@ -145,7 +150,7 @@ def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_lengt
return False return False
for index, node in enumerate(node_list): for index, node in enumerate(node_list):
if node.op == 'call_module': if node.op == "call_module":
target = node.target target = node.target
submod = root_module.get_submodule(target) submod = root_module.get_submodule(target)
submod_type = type(submod) submod_type = type(submod)
...@@ -155,12 +160,12 @@ def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_lengt ...@@ -155,12 +160,12 @@ def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_lengt
new_args = _process_args(node.args) new_args = _process_args(node.args)
if node.op != 'get_attr': if node.op != "get_attr":
hash_key = (node.op, target, *new_args) hash_key = (node.op, target, *new_args)
else: else:
hash_key = (node.op,) hash_key = (node.op,)
setattr(node, 'hash_key', hash_key) setattr(node, "hash_key", hash_key)
hash_value_to_node_dict = {} hash_value_to_node_dict = {}
...@@ -179,7 +184,7 @@ def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_lengt ...@@ -179,7 +184,7 @@ def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_lengt
# the comparison will be triggered if a common node appears # the comparison will be triggered if a common node appears
if len(hash_value_to_node_dict[hash(node.hash_key)]) >= 2: if len(hash_value_to_node_dict[hash(node.hash_key)]) >= 2:
start_index_list = hash_value_to_node_dict[hash(node.hash_key)] start_index_list = hash_value_to_node_dict[hash(node.hash_key)]
check_block_list = [node_list[start:start + max_common_length] for start in start_index_list] check_block_list = [node_list[start : start + max_common_length] for start in start_index_list]
common_label = True common_label = True
if not _all_equal(check_block_list, _check_node_list_equal): if not _all_equal(check_block_list, _check_node_list_equal):
...@@ -201,6 +206,6 @@ def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_lengt ...@@ -201,6 +206,6 @@ def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_lengt
# recover common subgraph from the index # recover common subgraph from the index
common_blocks = [] common_blocks = []
for start in common_blocks_index: for start in common_blocks_index:
common_blocks.append(node_list[start:start + max_common_length]) common_blocks.append(node_list[start : start + max_common_length])
return common_blocks return common_blocks
import functools import functools
from typing import Any, Callable, Dict, List, Tuple, Type, Union from typing import Any, Callable, Tuple, Type, Union
import torch import torch
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException
__all__ = ['ignore_sharding_exception', 'pytree_map'] __all__ = ["ignore_sharding_exception", "pytree_map"]
def ignore_sharding_exception(func): def ignore_sharding_exception(func):
...@@ -46,31 +46,34 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens ...@@ -46,31 +46,34 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens
# make sure all dims are covered in sharding spec # make sure all dims are covered in sharding spec
sharding_len = len(sharding_spec.sharding_sequence) sharding_len = len(sharding_spec.sharding_sequence)
tensor_num_dim = tensor.dim() tensor_num_dim = tensor.dim()
num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0] num_devices_in_col = sharding_spec.device_mesh.shape[0]
num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1] num_devices_in_row = sharding_spec.device_mesh.shape[1]
assert sharding_len == tensor_num_dim, \ assert (
f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).' sharding_len == tensor_num_dim
), f"The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape})."
# make sure the sharding is valid for each dim # make sure the sharding is valid for each dim
for i in range(tensor_num_dim): for i in range(tensor_num_dim):
dim_size = tensor.shape[i] dim_size = tensor.shape[i]
dim_spec = sharding_spec.sharding_sequence[i] dim_spec = sharding_spec.sharding_sequence[i]
if str(dim_spec).startswith('S'): if str(dim_spec).startswith("S"):
devices_str = str(dim_spec).lstrip('S') devices_str = str(dim_spec).lstrip("S")
num_devices = 1 num_devices = 1
if '0' in devices_str: if "0" in devices_str:
num_devices *= num_devices_in_col num_devices *= num_devices_in_col
if '1' in devices_str: if "1" in devices_str:
num_devices *= num_devices_in_row num_devices *= num_devices_in_row
assert dim_size >= num_devices and dim_size % num_devices == 0, \ assert (
f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.' dim_size >= num_devices and dim_size % num_devices == 0
), f"The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices."
# make sure the entire shape matches the physical tensor shape # make sure the entire shape matches the physical tensor shape
assert sharding_spec.entire_shape == tensor.shape, \ assert (
f'The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}' sharding_spec.entire_shape == tensor.shape
), f"The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}"
def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any: def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
......
...@@ -6,12 +6,13 @@ import torch ...@@ -6,12 +6,13 @@ import torch
class PreviousStatus(Enum): class PreviousStatus(Enum):
""" """
This class shows the status of previous comparision. This class shows the status of previous comparison.
""" """
RESET = 0 RESET = 0
# ORIGIN means the dimension size of original tensor is larger in the previous comparision. # ORIGIN means the dimension size of original tensor is larger in the previous comparison.
ORIGIN = 1 ORIGIN = 1
# TGT means the dimension size of target tensor is larger in the previous comparision. # TGT means the dimension size of target tensor is larger in the previous comparison.
TGT = 2 TGT = 2
...@@ -91,7 +92,7 @@ def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> D ...@@ -91,7 +92,7 @@ def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> D
tgt_index += 1 tgt_index += 1
if previous_label == PreviousStatus.TGT: if previous_label == PreviousStatus.TGT:
# if the target dimension size is larger in the previous comparision, which means # if the target dimension size is larger in the previous comparison, which means
# the origin dimension size has already accumulated larger than target dimension size, so # the origin dimension size has already accumulated larger than target dimension size, so
# we need to offload the origin dims and tgt dims into the reshape_mapping_dict. # we need to offload the origin dims and tgt dims into the reshape_mapping_dict.
reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims) reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims)
...@@ -111,7 +112,7 @@ def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> D ...@@ -111,7 +112,7 @@ def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> D
origin_index += 1 origin_index += 1
if previous_label == PreviousStatus.ORIGIN: if previous_label == PreviousStatus.ORIGIN:
# if the origin element is larger in the previous comparision, which means # if the origin element is larger in the previous comparison, which means
# the target element has already accumulated larger than origin element, so # the target element has already accumulated larger than origin element, so
# we need to offload the origin dims and tgt dims into the reshape_mapping_dict. # we need to offload the origin dims and tgt dims into the reshape_mapping_dict.
reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims) reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims)
...@@ -130,8 +131,9 @@ def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> D ...@@ -130,8 +131,9 @@ def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> D
return reshape_mapping_dict return reshape_mapping_dict
def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]], def check_keep_sharding_status(
reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> bool: input_dim_partition_dict: Dict[int, List[int]], reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]
) -> bool:
""" """
This method is used to check whether the reshape operation can be performed without converting This method is used to check whether the reshape operation can be performed without converting
the input to fully replicated status. the input to fully replicated status.
...@@ -139,7 +141,7 @@ def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]], ...@@ -139,7 +141,7 @@ def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]],
Rule: Rule:
For a sharded dimension of input tensor, if it is not the minimum element of the input tuple, For a sharded dimension of input tensor, if it is not the minimum element of the input tuple,
the function will return false. the function will return false.
To illustrate this issue, there are two cases to analyse: To illustrate this issue, there are two cases to analyze:
1. no sharded dims in the input tuple: we could do the reshape operation safely just as the normal 1. no sharded dims in the input tuple: we could do the reshape operation safely just as the normal
operation without distributed tensor. operation without distributed tensor.
2. sharded dims in the input tuple: the sharded dim must be the minimum element, then during shape 2. sharded dims in the input tuple: the sharded dim must be the minimum element, then during shape
...@@ -172,14 +174,16 @@ def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]], ...@@ -172,14 +174,16 @@ def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]],
return True return True
def infer_output_dim_partition_dict(input_dim_partition_dict: Dict[int, List[int]], def infer_output_dim_partition_dict(
reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> Dict[Tuple[int], Tuple[int]]: input_dim_partition_dict: Dict[int, List[int]], reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]
) -> Dict[Tuple[int], Tuple[int]]:
""" """
This method is used to infer the output dim partition dict for a reshape operation, This method is used to infer the output dim partition dict for a reshape operation,
given the input dim partition dict and reshape mapping dict. given the input dim partition dict and reshape mapping dict.
""" """
assert check_keep_sharding_status(input_dim_partition_dict, reshape_mapping_dict), \ assert check_keep_sharding_status(
'we only infer output dim partition dict for the reshape operation could keep sharding spec.' input_dim_partition_dict, reshape_mapping_dict
), "we only infer output dim partition dict for the reshape operation could keep sharding spec."
sharded_dims = list(input_dim_partition_dict.keys()) sharded_dims = list(input_dim_partition_dict.keys())
output_dim_partition_dict = {} output_dim_partition_dict = {}
for input_dims, output_dims in reshape_mapping_dict.items(): for input_dims, output_dims in reshape_mapping_dict.items():
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment