Commit 9e768b59 authored by zhuwenwen
parents 7bc5a8e3 8aed02b9
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from .strategy_generator import FollowingStrategyGenerator
__all__ = ['UnaryElementwiseGenerator']
__all__ = ["UnaryElementwiseGenerator"]
class UnaryElementwiseGenerator(FollowingStrategyGenerator):
......@@ -21,12 +21,12 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
"""
Compute the memory cost per device with this specific strategy.
'''
"""
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
"input": self._compute_size_in_bytes(strategy, "input"),
"output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
......@@ -44,8 +44,9 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
total_mem_cost = MemoryCost(
activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
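The hunk above folds the forward and backward MemoryCost objects into a single TrainCycleItem. A minimal standalone sketch of that bookkeeping, using stand-in dataclasses and made-up byte counts rather than the real colossalai classes:
from dataclasses import dataclass
# Stand-ins mirroring only the fields used above; the real classes live in
# colossalai.auto_parallel.tensor_shard.sharding_strategy.
@dataclass
class MemoryCost:
    activation: int = 0
    parameter: int = 0
@dataclass
class TrainCycleItem:
    fwd: MemoryCost
    bwd: MemoryCost
    total: MemoryCost
# Hypothetical per-device sizes in bytes for a unary element-wise op.
fwd_mem = MemoryCost(activation=4096, parameter=0)
bwd_mem = MemoryCost(activation=4096, parameter=0)
total_mem = MemoryCost(
    activation=fwd_mem.activation + bwd_mem.activation,
    parameter=fwd_mem.parameter + bwd_mem.parameter,
)
memory_cost = TrainCycleItem(fwd=fwd_mem, bwd=bwd_mem, total=total_mem)
print(memory_cost.total.activation)  # 8192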
......@@ -69,9 +70,11 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator):
# we keep same strategies with different name for node merging, and it will not increase the searching space,
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
strategy_list.append(strategy)
return strategy_list
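The comment above keeps otherwise identical strategies alive by giving each one a distinct name. A small sketch of that naming scheme with hypothetical sharding sequences; keeping the names unique appears to be what lets the name-based de-duplication retain all of them, while the solver still merges this node into its neighbours:
input_seq, output_seq = "[S0, R]", "[S0, R]"
names = [f"{input_seq} -> {output_seq}_{index}" for index in range(3)]
print(names)
# ['[S0, R] -> [S0, R]_0', '[S0, R] -> [S0, R]_1', '[S0, R] -> [S0, R]_2']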
......@@ -10,7 +10,7 @@ from colossalai.auto_parallel.tensor_shard.utils import (
from .strategy_generator import StrategyGenerator
__all__ = ['WhereGenerator']
__all__ = ["WhereGenerator"]
class WhereGenerator(StrategyGenerator):
......@@ -26,14 +26,14 @@ class WhereGenerator(StrategyGenerator):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
"""
Compute the memory cost per device with this specific strategy.
'''
"""
forward_size_mapping = {
'condition': self._compute_size_in_bytes(strategy, "condition"),
'x': self._compute_size_in_bytes(strategy, "x"),
'y': self._compute_size_in_bytes(strategy, "y"),
'output': self._compute_size_in_bytes(strategy, "output")
"condition": self._compute_size_in_bytes(strategy, "condition"),
"x": self._compute_size_in_bytes(strategy, "x"),
"y": self._compute_size_in_bytes(strategy, "y"),
"output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
......@@ -59,7 +59,7 @@ class WhereGenerator(StrategyGenerator):
"condition": dim_partition,
"x": dim_partition,
"y": dim_partition,
"output": dim_partition
"output": dim_partition,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
......@@ -67,9 +67,11 @@ class WhereGenerator(StrategyGenerator):
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["condition"].sharding_sequence} x {sharding_spec_mapping["x"].sharding_sequence} x {sharding_spec_mapping["y"].sharding_sequence}'
communication_action_mapping = {}
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping,
)
return strategy
......@@ -84,9 +86,9 @@ class WhereGenerator(StrategyGenerator):
return dim_partition_list
def collate_strategies(self) -> List[ShardingStrategy]:
'''
"""
Generate every possible strategy for a where node, and record all strategies into the strategies_vector.
'''
"""
strategy_list = []
dimension_length = len(self.op_data["output"].logical_shape)
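The body below the fold enumerates candidate dim partitions; the earlier hunk shows each one being applied identically to condition, x, y and output. A hypothetical sketch of what enumerating 1D shardings over one mesh axis could look like; the helper name and exact enumeration are assumptions, not the library's API:
def enumerate_1d_shardings(mesh_dim, num_dims):
    # {} means fully replicated; {d: [mesh_dim]} shards tensor dim d over one mesh axis.
    partitions = [{}]
    for dim in range(num_dims):
        partitions.append({dim: [mesh_dim]})
    return partitions
for dim_partition in enumerate_1d_shardings(mesh_dim=0, num_dims=3):
    print(dim_partition)   # {}, {0: [0]}, {1: [0]}, {2: [0]}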
......
......@@ -7,7 +7,7 @@ from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator, SumGenerator
__all__ = ['SumHandler']
__all__ = ["SumHandler"]
@operator_registry.register(torch.Tensor.sum)
......@@ -55,7 +55,7 @@ class SumHandler(NodeHandler):
# sum_mapping_dict[1] = 0 means the 0th dim of output is the 1st dim of input
# sum_mapping_dict[3] = 1 means the 1st dim of output is the 3rd dim of input
sum_mapping_dict = {}
if 'keepdim' in self.node.kwargs and self.node.kwargs['keepdim']:
if "keepdim" in self.node.kwargs and self.node.kwargs["keepdim"]:
for i in range(num_dims):
sum_mapping_dict.update({i: i})
else:
......@@ -67,7 +67,7 @@ class SumHandler(NodeHandler):
assert output_index == self.node._meta_data.dim()
sum_info = (sum_dims, sum_mapping_dict)
physical_shape_operand = OperationData(name='sum_info', type=OperationDataType.ARG, data=sum_info)
physical_shape_operand = OperationData(name="sum_info", type=OperationDataType.ARG, data=sum_info)
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
......@@ -75,7 +75,7 @@ class SumHandler(NodeHandler):
mapping = {
"input": physical_input_operand,
"sum_info": physical_shape_operand,
"output": physical_output_operand
"output": physical_output_operand,
}
return mapping
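The comments above describe sum_mapping_dict as a map from surviving input dims to output dims. A plain-Python sketch of that mapping with hypothetical shapes (rank-4 input summed over dims 1 and 3):
num_dims = 4
sum_dims = (1, 3)
# keepdim=True: every output dim keeps its input index.
keepdim_mapping = {i: i for i in range(num_dims)}          # {0: 0, 1: 1, 2: 2, 3: 3}
# keepdim=False: reduced dims disappear and the remaining dims are re-indexed.
no_keepdim_mapping = {}
output_index = 0
for input_dim in range(num_dims):
    if input_dim not in sum_dims:
        no_keepdim_mapping[input_dim] = output_index
        output_index += 1
print(no_keepdim_mapping)   # {0: 0, 2: 1} -> input dims 0 and 2 become output dims 0 and 1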
......@@ -8,7 +8,7 @@ from .registry import operator_registry
from .strategy import StrategyGenerator
from .strategy.tensor_constructor_generator import TensorConstructorGenerator
__all__ = ['TensorConstructorHandler']
__all__ = ["TensorConstructorHandler"]
@operator_registry.register(torch.arange)
......
......@@ -7,7 +7,7 @@ from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator, TransposeGenerator
__all__ = ['TransposeHandler']
__all__ = ["TransposeHandler"]
@operator_registry.register(torch.Tensor.transpose)
......@@ -48,9 +48,9 @@ class TransposeHandler(NodeHandler):
if transpose_dims[i] < 0:
transpose_dims[i] += num_dims
physical_shape_operand = OperationData(name='transpose_dims',
type=OperationDataType.ARG,
data=list(transpose_dims))
physical_shape_operand = OperationData(
name="transpose_dims", type=OperationDataType.ARG, data=list(transpose_dims)
)
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
......@@ -58,7 +58,7 @@ class TransposeHandler(NodeHandler):
mapping = {
"input": physical_input_operand,
"transpose_dims": physical_shape_operand,
"output": physical_output_operand
"output": physical_output_operand,
}
return mapping
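Earlier in this hunk, negative transpose dims are folded into non-negative indices before the transpose_dims operand is built. A one-line sketch with hypothetical values:
num_dims = 4
transpose_dims = [-1, 1]
transpose_dims = [d + num_dims if d < 0 else d for d in transpose_dims]
print(transpose_dims)   # [3, 1]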
......@@ -3,11 +3,11 @@ from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import MetaInfoNodeHandler, NodeHandler
from .node_handler import MetaInfoNodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator, UnaryElementwiseGenerator
__all__ = ['UnaryElementwiseHandler']
__all__ = ["UnaryElementwiseHandler"]
@operator_registry.register(torch.Tensor.to)
......@@ -33,9 +33,9 @@ class UnaryElementwiseHandler(MetaInfoNodeHandler):
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data)
physical_input_operand = OperationData(
name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
)
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {"input": physical_input_operand, "output": physical_output}
......
......@@ -7,7 +7,7 @@ from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator, ViewGenerator
__all__ = ['ViewHandler']
__all__ = ["ViewHandler"]
@operator_registry.register(torch.Tensor.reshape)
......@@ -38,7 +38,7 @@ class ViewHandler(NodeHandler):
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
target_shape = self.node._meta_data.shape
physical_shape_operand = OperationData(name='tgt_shape', type=OperationDataType.ARG, data=target_shape)
physical_shape_operand = OperationData(name="tgt_shape", type=OperationDataType.ARG, data=target_shape)
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
......@@ -46,7 +46,7 @@ class ViewHandler(NodeHandler):
mapping = {
"input": physical_input_operand,
"tgt_shape": physical_shape_operand,
"output": physical_output_operand
"output": physical_output_operand,
}
return mapping
import copy
import operator
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator, WhereGenerator
__all__ = ['WhereHandler']
__all__ = ["WhereHandler"]
@operator_registry.register(torch.where)
......@@ -28,27 +27,28 @@ class WhereHandler(NodeHandler):
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
physical_condition_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data)
physical_x_operand = OperationData(name=str(self.node.args[1]),
type=OperationDataType.ARG,
data=self.node.args[1]._meta_data)
physical_y_operand = OperationData(name=str(self.node.args[2]),
type=OperationDataType.ARG,
data=self.node.args[2]._meta_data)
physical_condition_operand = OperationData(
name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
)
physical_x_operand = OperationData(
name=str(self.node.args[1]), type=OperationDataType.ARG, data=self.node.args[1]._meta_data
)
physical_y_operand = OperationData(
name=str(self.node.args[2]), type=OperationDataType.ARG, data=self.node.args[2]._meta_data
)
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
physical_mapping = {
"condition": physical_condition_operand,
"x": physical_x_operand,
"y": physical_y_operand,
"output": physical_output
"output": physical_output,
}
logical_shape_for_all = self.node._meta_data.shape
logical_mapping = {}
for key, physical_operand in physical_mapping.items():
logical_mapping[key] = self.convert_physical_operand_to_logical_operand(physical_operand,
logical_shape_for_all)
logical_mapping[key] = self.convert_physical_operand_to_logical_operand(
physical_operand, logical_shape_for_all
)
return logical_mapping, physical_mapping
......@@ -64,7 +64,8 @@ class WhereHandler(NodeHandler):
logical_shape = logical_op_data_mapping[key].logical_shape
physical_shape = physical_op_data_mapping[key].logical_shape
physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
logical_sharding_spec, logical_shape, physical_shape)
logical_sharding_spec, logical_shape, physical_shape
)
strategy.sharding_specs.pop(logical_op_data_mapping[key])
strategy.sharding_specs[physical_op_data_mapping[key]] = physical_sharding_spec
strategy.name = f"{strategy.sharding_specs[physical_op_data_mapping['output']].sharding_sequence} = {strategy.sharding_specs[physical_op_data_mapping['condition']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['x']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['y']].sharding_sequence}"
......
from dataclasses import dataclass
from enum import Enum
__all__ = ['SolverOptions', 'SolverPerference', 'DataloaderOption', 'ShardOption']
__all__ = ["SolverOptions", "SolverPerference", "DataloaderOption", "ShardOption"]
class SolverPerference(Enum):
"""
This enum class is to define the solver preference.
"""
STANDARD = 0
DP = 1
TP = 2
......@@ -25,6 +26,7 @@ class ShardOption(Enum):
TP_SHARD: We require the node to be sharded using tensor parallel strategies on the last device mesh axis.
TP_FULL_SHARD: We require the node to be sharded using tensor parallel strategies on all device mesh axes.
"""
STANDARD = 0
SHARD = 1
SHARD_LAST_AXIS = 2
......@@ -35,6 +37,7 @@ class DataloaderOption(Enum):
"""
This enum class is to define the dataloader option.
"""
REPLICATED = 0
DISTRIBUTED = 1
......@@ -44,6 +47,7 @@ class SolverOptions:
"""
SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search.
"""
solver_perference: SolverPerference = SolverPerference.STANDARD
dataloader_option: DataloaderOption = DataloaderOption.REPLICATED
shard_option: ShardOption = ShardOption.STANDARD
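With the defaults above, callers only override the fields they care about. A usage sketch; the import path is inferred from the relative import (from ..options import DataloaderOption, SolverOptions) used by the strategies constructor later in this commit:
from colossalai.auto_parallel.tensor_shard.options import (
    DataloaderOption,
    ShardOption,
    SolverOptions,
    SolverPerference,
)
# Prefer tensor-parallel strategies and a distributed dataloader; the rest keeps
# the defaults defined above.
options = SolverOptions(
    solver_perference=SolverPerference.TP,
    dataloader_option=DataloaderOption.DISTRIBUTED,
    shard_option=ShardOption.SHARD_LAST_AXIS,
)
assert options.shard_option is ShardOption.SHARD_LAST_AXIS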
......@@ -10,7 +10,6 @@ from colossalai.tensor.comm_spec import CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec
from .constants import (
BCAST_FUNC_OP,
ELEMENTWISE_FUNC_OP,
ELEMENTWISE_METHOD_OP,
ELEMENTWISE_MODULE_OP,
......@@ -18,13 +17,14 @@ from .constants import (
RESHAPE_METHOD_OP,
)
__all__ = ['OperationDataType', 'OperationData', 'TrainCycleItem', 'MemoryCost', 'ShardingStrategy', 'StrategiesVector']
__all__ = ["OperationDataType", "OperationData", "TrainCycleItem", "MemoryCost", "ShardingStrategy", "StrategiesVector"]
class OperationDataType(Enum):
"""
An operation can come from the argument list of an operator or the parameter list of a module.
"""
INPUT = 0
ARG = 1
PARAM = 2
......@@ -43,6 +43,7 @@ class OperationData:
data (Any): the value for this data, usually it is a meta tensor.
logical_shape (Tuple[int]): the logical shape of the data, it can be different from the its actual shape in memory.
"""
name: str
type: OperationDataType
data: Any
......@@ -69,13 +70,13 @@ class OperationData:
self.logical_shape = _infer_logical_shape(self.data)
def __repr__(self) -> str:
return f'OperationData(name={self.name}, type={self.type})'
return f"OperationData(name={self.name}, type={self.type})"
def __eq__(self, other) -> bool:
return other.name == self.name
def __hash__(self) -> int:
return hash(f'{self.name}')
return hash(f"{self.name}")
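Since __eq__ and __hash__ above only consider name, two OperationData instances with the same name collapse onto one dictionary key regardless of their type or data. A standalone sketch of that identity semantics (stand-in class, not the real one):
class OpData:
    # Stand-in mirroring only the name-based identity shown above.
    def __init__(self, name, data=None):
        self.name = name
        self.data = data
    def __eq__(self, other):
        return other.name == self.name
    def __hash__(self):
        return hash(self.name)
a = OpData(name="input_1", data=[1, 2])
b = OpData(name="input_1")              # different data, same name
sharding_specs = {a: "S0R"}
print(sharding_specs[b])                # S0R -- b hits a's entry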
@dataclass
......@@ -88,6 +89,7 @@ class TrainCycleItem:
fwd (float): the item for the forward pass
bwd (float): the item for the backward pass
"""
fwd: Any
bwd: Any
total: Any
......@@ -104,6 +106,7 @@ class MemoryCost:
temp (int): the memory cost incurred by the temporary tensors in bytes.
buffer (int): the memory cost incurred by the module buffer in bytes.
"""
activation: int = 0
parameter: int = 0
temp: int = 0
......@@ -120,6 +123,7 @@ class CommType(Enum):
HOOK: the communication action is used to do the grad all reduce.
IMPLICIT: the communication action happens during the kernel execution, such as SyncBatchNorm
"""
BEFORE = 0
AFTER = 1
HOOK = 2
......@@ -137,6 +141,7 @@ class CommAction:
arg_index: record the location of tensor which join the communication, we cannot use name of node or op_data at runtime,
because the args of node may be changed by graph transform passes.
"""
comm_spec: CommSpec = None
comm_type: CommType = None
arg_index: int = -1
......@@ -156,6 +161,7 @@ class ShardingStrategy:
memory_cost (TrainCycleItem): Memory cost of the output node using this strategy. (default to None)
input_sharding_specs (List(ShardingSpec)): The ShardingSpecs of the input nodes.
"""
name: str
sharding_specs: Dict[OperationData, Union[ShardingSpec, Tuple[ShardingSpec]]] = None
compute_cost: TrainCycleItem = None
......@@ -200,7 +206,6 @@ class ShardingStrategy:
raise KeyError(f"Could not find the ShardingSpec for OperationData with name {name}")
def clone(self):
def _deepcopy_dict_vals(data: Dict):
return {k: deepcopy(v) for k, v in data.items()}
......@@ -209,31 +214,34 @@ class ShardingStrategy:
# Consider the examples below:
# If self.communication_actions is an empty dictionary {}, then self.communication_actions is not None, but its __bool__ value is False.
# In this case, if we set None to the new object, program will crash when we try to access the communication_actions.items.
communication_actions = _deepcopy_dict_vals(
self.communication_actions) if self.communication_actions is not None else None
communication_actions = (
_deepcopy_dict_vals(self.communication_actions) if self.communication_actions is not None else None
)
# same reason as communication_actions
resharding_costs = _deepcopy_dict_vals(self.resharding_costs) if self.resharding_costs is not None else None
compute_cost = deepcopy(self.compute_cost)
communication_cost = deepcopy(self.communication_cost)
memory_cost = deepcopy(self.memory_cost)
return ShardingStrategy(name=self.name,
sharding_specs=sharding_specs,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
communication_actions=communication_actions,
resharding_costs=resharding_costs)
return ShardingStrategy(
name=self.name,
sharding_specs=sharding_specs,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
communication_actions=communication_actions,
resharding_costs=resharding_costs,
)
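The comment in clone() above is about the difference between an empty dict {} (falsy, but still a valid mapping) and None. A tiny sketch of the pitfall it guards against:
original_actions = {}                       # node simply has no communication actions yet
# Wrong: a truthiness check treats {} like None and drops it; any later call to
# .items() on the clone would raise AttributeError.
cloned_wrong = original_actions if original_actions else None
# Right: only a literal None propagates as None; {} is copied per value.
cloned_right = {k: v for k, v in original_actions.items()} if original_actions is not None else None
print(cloned_wrong, cloned_right)           # None {}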
class StrategiesVector(list):
'''
"""
Each node in fx graph will have a corresponding StrategiesVector, to store all the possible
strategies of the node.
Argument:
node (Node): node for which the list of sharding strategies are generated.
'''
"""
def __init__(self, node: Node):
super().__init__()
......@@ -245,7 +253,7 @@ class StrategiesVector(list):
def check_merge(self):
merge_label = False
if self.node.op == 'call_module':
if self.node.op == "call_module":
target = self.node.target
root_module = self.node.graph.owning_module
submod = root_module.get_submodule(target)
......@@ -255,7 +263,7 @@ class StrategiesVector(list):
if submod_type in ELEMENTWISE_MODULE_OP:
merge_label = True
if self.node.op == 'call_function':
if self.node.op == "call_function":
# we could merge element-wise op, because the output sharding spec is always same as the input sharding spec.
if self.node.target in ELEMENTWISE_FUNC_OP:
merge_label = True
......@@ -267,7 +275,7 @@ class StrategiesVector(list):
if self.node.target in RESHAPE_FUNC_OP:
merge_label = True
if self.node.op == 'call_method':
if self.node.op == "call_method":
# we could merge reshape op, because their computation costs are negligible.
method = getattr(self.node.args[0]._meta_data.__class__, self.node.target)
if method in RESHAPE_METHOD_OP:
......
......@@ -3,4 +3,4 @@ from .graph_analysis import GraphAnalyser
from .solver import Solver
from .strategies_constructor import StrategiesConstructor
__all__ = ['GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph']
__all__ = ["GraphAnalyser", "Solver", "StrategiesConstructor", "CostGraph"]
......@@ -4,18 +4,18 @@ from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST
class CostGraph:
'''
"""
A graph data structure to simplify the edge cost graph. It has two main functions:
1. To feed the quadratic resharding costs into solver, we need to linearize it. We build edge_cost in
CostGraph, and it stored every combinations of strategies for a src-dst node pair in an 1D list.
2. To reduce the searching space, we merge computationally-trivial operators, such as
element-wise operators, transpose, and reduction, into their following nodes. The merging infomation will
element-wise operators, transpose, and reduction, into their following nodes. The merging information will
be given by the StrategiesVector depending on the type of target node and following nodes.
Argument:
leaf_strategies(List[StrategiesVector]): It stores StrategiesVector of every nodes on the graph.
simplify(bool, optional): The generated cost graph will be simplified if it is true. (default to True)
'''
"""
def __init__(self, leaf_strategies, simplify=True, forward_only=False):
self.leaf_strategies = leaf_strategies
......@@ -39,10 +39,10 @@ class CostGraph:
target_node_list.remove(element)
def _build_cost_graph(self):
'''
"""
This method will generate edge_cost for adjacent node pair. Additionally, 'parents' and 'children' attribute will be
set to node.
'''
"""
self.edge_costs = {}
if self.simplify:
self.merge_pair = []
......@@ -84,13 +84,13 @@ class CostGraph:
if _check_tensor_in_node(node._meta_data):
children_nodes.append(node)
setattr(dst_node, 'parents', parent_nodes)
setattr(dst_node, 'children', children_nodes)
setattr(dst_node, "parents", parent_nodes)
setattr(dst_node, "children", children_nodes)
if self.simplify and strategies_vector.check_merge():
for followed_node in strategies_vector.predecessor_nodes:
# we only merge node pairs which src node has a tensor element inside.
# This is necessay because the node without a tensor element inside will not
# This is necessary because the node without a tensor element inside will not
# be assigned any strategy.
if _check_tensor_in_node(followed_node._meta_data):
self.merge_pair.append((followed_node, dst_node))
......@@ -99,7 +99,7 @@ class CostGraph:
return self.edge_costs[(src_node, dst_node)]
def merge_node(self, src_node, dst_node):
'''
"""
To merge dst_node into src_node, we need to do it in following steps:
1. For each strategy in dst_node, we need to pick an appropriate strategy
......@@ -119,7 +119,7 @@ class CostGraph:
Argument:
src_node(Node): The node will be merged into dst_node.
dst_node(Node): The node to integrate src_node.
'''
"""
# build merge_map
merge_map = {}
for src_index, _ in enumerate(src_node.strategies_vector):
......@@ -196,7 +196,7 @@ class CostGraph:
if not self.simplify:
return
self.merge_pair.reverse()
for (src_node, dst_node) in self.merge_pair:
for src_node, dst_node in self.merge_pair:
self.merge_node(src_node, dst_node)
self.merge_pair.reverse()
reindexing_following_dict = {}
......
......@@ -7,7 +7,7 @@ from torch.fx.node import Node
from colossalai.fx.passes.utils import get_node_module
__all__ = ['LiveVariable', 'LiveVariableVector', 'LiveStage', 'GraphAnalyser']
__all__ = ["LiveVariable", "LiveVariableVector", "LiveStage", "GraphAnalyser"]
@dataclass
......@@ -15,6 +15,7 @@ class LiveVariable:
"""
LiveVariable is a data structure to store the meta information of a variable for liveness analysis.
"""
name: str
node: Node
is_inplace: bool
......@@ -55,6 +56,7 @@ class LiveStage:
"""
LiveStage is a data structure to record the living variables at this current node.
"""
name: str
node: Node
all_live_vars: LiveVariableVector
......@@ -62,7 +64,6 @@ class LiveStage:
class GraphAnalyser:
def __init__(self, gm: GraphModule):
self._gm = gm
self._graph = gm.graph
......@@ -83,7 +84,7 @@ class GraphAnalyser:
def liveness_analysis(self) -> List[LiveStage]:
"""
Analyse the graph to obtain the variable liveness information. This function returns
Analyses the graph to obtain the variable liveness information. This function returns
an ordered dictionary where the key is the compute stage ID and the value is a LivenessStage object.
"""
compute_nodes = self.graph.nodes
......@@ -91,7 +92,7 @@ class GraphAnalyser:
# checked: record all variables created since the first stage
# all: record the live variables only exist until the current stage.
# this can be different from the `checked list`` as some varialbes may be destroyed prior to this stage.
# this can be different from the `checked list`` as some variables may be destroyed prior to this stage.
# unique: record the unique live variables only exist until the current stage.
# this is different from `all list` as some variables are duplicated.
checked_variables = LiveVariableVector()
......@@ -103,20 +104,20 @@ class GraphAnalyser:
# find new living variables #
#############################
# detect whether the current op is an in-place op
# if it is an in-place op, we would deem it as a duplciate var
# if it is an in-place op, we would deem it as a duplicate var
is_inplace = False
if node.op == 'call_function':
if node.op == "call_function":
# check if this is an inplace op such as torch.nn.functional.relu(x, inplace=True)
if node.kwargs.get('inplace', False):
if node.kwargs.get("inplace", False):
is_inplace = True
elif node.op == 'call_module':
elif node.op == "call_module":
# to check if this is an inplace op such as torch.nn.Relu(inplace=True)
module = get_node_module(node)
if getattr(module, 'inplace', False):
if getattr(module, "inplace", False):
is_inplace = True
# add the output var
meta = getattr(node, '_meta_data', None)
getattr(node, "_meta_data", None)
live_var = LiveVariable(name=node.name, node=node, is_inplace=is_inplace)
if not is_inplace:
unique_live_vars.append(live_var)
......@@ -138,10 +139,12 @@ class GraphAnalyser:
# this should be completed if we are able to trace the backward compute graph
# add this stage to liveness dict
stage = LiveStage(name=node.name,
node=node,
all_live_vars=all_live_variables.copy(),
unique_live_vars=unique_live_vars.copy())
stage = LiveStage(
name=node.name,
node=node,
all_live_vars=all_live_variables.copy(),
unique_live_vars=unique_live_vars.copy(),
)
# if a LiveStage is covered by another LiveStage, we just keep the larger one.
replace = False
for index, prev_stage in enumerate(liveness_list):
......
......@@ -21,34 +21,35 @@ try:
import pulp
from pulp import LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum
except:
warnings.warn(f'please install the pulp')
warnings.warn(f"please install the pulp")
__all___ = ['Solver']
__all___ = ["Solver"]
class Solver:
def __init__(self,
graph: Graph,
strategies_constructor: StrategiesConstructor,
cost_graph: CostGraph,
graph_analyser: GraphAnalyser = None,
memory_budget: float = -1.0,
solution_numbers: int = 1,
forward_only: bool = False,
memory_increasing_coefficient: float = 1.3,
verbose=False):
'''
def __init__(
self,
graph: Graph,
strategies_constructor: StrategiesConstructor,
cost_graph: CostGraph,
graph_analyser: GraphAnalyser = None,
memory_budget: float = -1.0,
solution_numbers: int = 1,
forward_only: bool = False,
memory_increasing_coefficient: float = 1.3,
verbose=False,
):
"""
Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph.
Argument:
graph: The computing graph to be optimized.
strategies_constructor: It will provide all the possible strategies for each node in the computing graph.
cost_graph: A graph data structure to simplify the edge cost graph.
graph_analyser: graph_analyser will analyse the graph to obtain the variable liveness information, which will be used to generate memory constraints.
graph_analyser: graph_analyser will analyses the graph to obtain the variable liveness information, which will be used to generate memory constraints.
memory_budget: Memory constraint for the solution.
solution_numbers: If solution_numbers is larger than one, the solver will use a series of solutions based on different memory budgets.
memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate new memory budget.
'''
"""
self.graph = graph
self.strategies_constructor = strategies_constructor
self.cost_graph = cost_graph
......@@ -75,11 +76,11 @@ class Solver:
self.verbose = verbose
def _recover_merged_node_strategy(self):
'''
"""
During cost graph constructing, some nodes, such as unary element-wise node or ReshapeOp, were merged into the previous node.
Therefore, the index of those strategies are copied from the previous node. This method is used to recover the strategy index of those merged
node.
'''
"""
for node_index, node in enumerate(self.nodes):
if node.strategies_vector.check_merge():
# the merged node has only one input, and its strategies follow the input sharding strategy
......@@ -98,9 +99,9 @@ class Solver:
return node_index_dict
def _prepare_data_for_solver(self):
'''
"""
Extract information from components for solver.
'''
"""
node_nums = len(self.leaf_strategies)
memory_budget = self.memory_budget
......@@ -190,23 +191,40 @@ class Solver:
# omit initial value for nodes
s_init_np = None
return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np, self.verbose
def _call_solver_serialized_args(self,
node_nums,
memory_budget,
strategies_len,
following_nodes,
edge_pairs,
alias_set,
liveness_set,
compute_costs,
communication_costs,
memory_costs,
resharding_costs,
alias_convert_costs,
s_init_np=None,
verbose=True):
return (
node_nums,
memory_budget,
strategies_len,
following_nodes,
edge_pairs,
alias_set,
liveness_set,
compute_costs,
communication_costs,
memory_costs,
resharding_costs,
alias_convert_costs,
s_init_np,
self.verbose,
)
def _call_solver_serialized_args(
self,
node_nums,
memory_budget,
strategies_len,
following_nodes,
edge_pairs,
alias_set,
liveness_set,
compute_costs,
communication_costs,
memory_costs,
resharding_costs,
alias_convert_costs,
s_init_np=None,
verbose=True,
):
"""
Call the solver with serialized arguments.
"""
......@@ -235,18 +253,18 @@ class Solver:
s_follow = following_nodes
s_alias = alias_set
E = edge_pairs.reshape((-1, 2)) # noqa
E = edge_pairs.reshape((-1, 2)) # noqa
r = []
pt = 0
edge_set = set()
for (i, j) in E:
for i, j in E:
prod_length = strategies_len[i] * strategies_len[j]
if (i, j) in edge_set:
raise ValueError(f"Duplicated edges: {(i, j)}")
edge_set.add((i, j))
r.append(resharding_costs[pt:pt + prod_length])
r.append(resharding_costs[pt : pt + prod_length])
pt += prod_length
assert pt == len(resharding_costs)
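The loop above walks edge_pairs two entries at a time and slices strategies_len[i] * strategies_len[j] resharding costs per edge out of one flat array. A small numeric sketch of that layout (hypothetical sizes):
import numpy as np
strategies_len = [2, 3, 2]                    # strategies per node
edge_pairs = np.array([0, 1, 1, 2])           # serialized (src, dst) pairs
E = edge_pairs.reshape((-1, 2))               # [[0, 1], [1, 2]]
# 2*3 = 6 costs for edge (0, 1) followed by 3*2 = 6 costs for edge (1, 2).
resharding_costs = np.arange(12, dtype=float)
r, pt = [], 0
for i, j in E:
    prod_length = strategies_len[i] * strategies_len[j]
    r.append(resharding_costs[pt : pt + prod_length])
    pt += prod_length
assert pt == len(resharding_costs)
print([len(block) for block in r])            # [6, 6]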
......@@ -268,7 +286,6 @@ class Solver:
# L.append(liveness_set[pt:pt + length])
# pt += length
# assert pt == len(liveness_set)
v = []
pt = 0
c = []
......@@ -277,9 +294,9 @@ class Solver:
pt = 0
for i in range(node_nums):
length = strategies_len[i]
c.append(compute_costs[pt:pt + length])
d.append(communication_costs[pt:pt + length])
m.append(memory_costs[pt:pt + length])
c.append(compute_costs[pt : pt + length])
d.append(communication_costs[pt : pt + length])
m.append(memory_costs[pt : pt + length])
pt += length
assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}"
assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}"
......@@ -319,7 +336,7 @@ class Solver:
e = []
num_edges = 0
map_edge_to_idx = {}
for (idx, (i, j)) in enumerate(E):
for idx, (i, j) in enumerate(E):
if len(s[i]) == 1:
e.append(s[j])
elif len(s[j]) == 1:
......@@ -340,7 +357,7 @@ class Solver:
######################################
if s_init_np is not None:
s_init = s_init_np.reshape((-1, 3))
for (idx, value, fix) in s_init:
for idx, value, fix in s_init:
for i in range(len(s[idx])):
s[idx][i].setInitialValue(i == value)
if fix:
......@@ -393,7 +410,7 @@ class Solver:
# (d). specified by `cat="Binary"`
for (idx, (i, j)) in enumerate(E):
for idx, (i, j) in enumerate(E):
if strategies_len[i] == 1 or strategies_len[j] == 1:
continue
......@@ -402,13 +419,13 @@ class Solver:
# (f)
for row in range(len(s[i])):
C = len(s[j]) # noqa
C = len(s[j]) # noqa
prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row]
# (g)
for col in range(len(s[j])):
R = len(s[i]) # noqa
C = len(s[j]) # noqa
R = len(s[i]) # noqa
C = len(s[j]) # noqa
prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col]
# (h)
......@@ -434,7 +451,8 @@ class Solver:
msg = verbose
time_limit = 600
assert "COIN_CMD" in pulp.listSolvers(
onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'")
onlyAvailable=True
), "Please install ILP solvers by 'sudo apt install coinor-cbc'"
solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count())
# solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit)
......@@ -444,13 +462,13 @@ class Solver:
objective = pulp.value(prob.objective)
objective = float(objective) if objective is not None else -1.0
if verbose:
print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t"
f"Time: {time.time() - tic}")
print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t" f"Time: {time.time() - tic}")
print(f"#nodes: {num_nodes}, #edges: {num_edges}")
if prob.status in [pulp.LpStatusInfeasible]:
raise RuntimeError("Cannot run the function under the given memory budget. "
"Please increase the memory budget.")
raise RuntimeError(
"Cannot run the function under the given memory budget. " "Please increase the memory budget."
)
# Get and check results
s_val = np.full((node_nums,), -1, dtype=np.int32)
......@@ -458,7 +476,7 @@ class Solver:
s_val[i] = get_non_zero_index(s[i])
e_val = np.full((len(E),), -1, dtype=np.int32)
for (idx, (i, j)) in enumerate(E):
for idx, (i, j) in enumerate(E):
e_val[idx] = get_non_zero_index(e[idx])
i_spec_index = e_val[idx] // len(s[j])
j_spec_index = e_val[idx] % len(s[j])
......
import builtins
import math
import operator
from copy import deepcopy
from typing import Dict, List
import torch
from torch.fx import Graph, Node
from torch.fx import Graph
from colossalai.auto_parallel.tensor_shard.node_handler import (
GetattrHandler,
......@@ -14,13 +8,12 @@ from colossalai.auto_parallel.tensor_shard.node_handler import (
operator_registry,
)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks
from colossalai.device.device_mesh import DeviceMesh
from ..options import DataloaderOption, SolverOptions
__all__ = ['StrategiesConstructor']
__all__ = ["StrategiesConstructor"]
class StrategiesConstructor:
......@@ -35,7 +28,7 @@ class StrategiesConstructor:
def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions):
self.graph = graph
assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
assert graph.owning_module is not None, "The given graph is not associated with a owning_module"
self.root_module = self.graph.owning_module
self.nodes = list(graph.nodes)
self.device_mesh = device_mesh
......@@ -46,11 +39,11 @@ class StrategiesConstructor:
self.alias_set = None
def remove_duplicated_strategy(self, strategies_vector):
'''
"""
In build_strategies_and_cost method, we may produce some duplicated strategies.
In this method, we will remove the duplicated strategies depending on the strategies name.
Note that this operation is in-place.
'''
"""
name_checklist = []
remove_list = []
for strategy in strategies_vector:
......@@ -62,7 +55,6 @@ class StrategiesConstructor:
strategies_vector.remove(strategy)
def generate_alias_set(self):
node_list = [strategy_vector.node for strategy_vector in self.leaf_strategies]
common_blocks = find_repeat_blocks(node_list, self.root_module, common_length_threshold=10)
......@@ -83,7 +75,7 @@ class StrategiesConstructor:
"""
def _check_no_strategy_for_node(node):
if node.op in ('placeholder', 'get_attr', 'output'):
if node.op in ("placeholder", "get_attr", "output"):
return False
def _check_no_strategy_for_data(data):
......@@ -102,83 +94,93 @@ class StrategiesConstructor:
if _check_no_strategy_for_node(node):
self.no_strategy_nodes.append(node)
pass
# placeholder node
elif node.op == 'placeholder':
elif node.op == "placeholder":
if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED:
placeholder_option = 'distributed'
placeholder_option = "distributed"
else:
assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported'
placeholder_option = 'replicated'
placeholder_handler = PlaceholderHandler(node,
self.device_mesh,
strategies_vector,
placeholder_option=placeholder_option)
assert (
self.solver_options.dataloader_option == DataloaderOption.REPLICATED
), f"placeholder_option {self.solver_options.dataloader_option} is not supported"
placeholder_option = "replicated"
placeholder_handler = PlaceholderHandler(
node, self.device_mesh, strategies_vector, placeholder_option=placeholder_option
)
placeholder_handler.register_strategy()
# get_attr node
elif node.op == 'get_attr':
getattr_handler = GetattrHandler(node,
self.device_mesh,
strategies_vector,
shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference)
elif node.op == "get_attr":
getattr_handler = GetattrHandler(
node,
self.device_mesh,
strategies_vector,
shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference,
)
getattr_handler.register_strategy()
# call_module node
elif node.op == 'call_module':
elif node.op == "call_module":
target = node.target
submod = self.root_module.get_submodule(target)
submod_type = type(submod)
handler = operator_registry.get(submod_type)(node,
self.device_mesh,
strategies_vector,
shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference)
handler = operator_registry.get(submod_type)(
node,
self.device_mesh,
strategies_vector,
shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference,
)
handler.register_strategy()
# attach strategies_info to node
if hasattr(handler, 'strategies_info'):
setattr(node, 'strategies_info', handler.strategies_info)
if hasattr(handler, "strategies_info"):
setattr(node, "strategies_info", handler.strategies_info)
# call_function node
elif node.op == 'call_function':
elif node.op == "call_function":
target = node.target
handler = operator_registry.get(target)(node,
self.device_mesh,
strategies_vector,
shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference)
handler = operator_registry.get(target)(
node,
self.device_mesh,
strategies_vector,
shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference,
)
handler.register_strategy()
# attach strategies_info to node
if hasattr(handler, 'strategies_info'):
setattr(node, 'strategies_info', handler.strategies_info)
if hasattr(handler, "strategies_info"):
setattr(node, "strategies_info", handler.strategies_info)
# call_method node
elif node.op == 'call_method':
elif node.op == "call_method":
method = getattr(node.args[0]._meta_data.__class__, node.target)
handler = operator_registry.get(method)(node,
self.device_mesh,
strategies_vector,
shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference)
handler = operator_registry.get(method)(
node,
self.device_mesh,
strategies_vector,
shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference,
)
handler.register_strategy()
# attach strategies_info to node
if hasattr(handler, 'strategies_info'):
setattr(node, 'strategies_info', handler.strategies_info)
if hasattr(handler, "strategies_info"):
setattr(node, "strategies_info", handler.strategies_info)
# output node
elif node.op == 'output':
elif node.op == "output":
if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED:
output_option = 'distributed'
output_option = "distributed"
else:
assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported'
output_option = 'replicated'
assert (
self.solver_options.dataloader_option == DataloaderOption.REPLICATED
), f"placeholder_option {self.solver_options.dataloader_option} is not supported"
output_option = "replicated"
output_handler = OutputHandler(node, self.device_mesh, strategies_vector, output_option=output_option)
output_handler.register_strategy()
self.remove_duplicated_strategy(strategies_vector)
setattr(node, 'strategies_vector', strategies_vector)
setattr(node, "strategies_vector", strategies_vector)
self.leaf_strategies.append(strategies_vector)
self.strategy_map[node] = strategies_vector
......
......@@ -17,9 +17,21 @@ from .sharding import (
)
__all__ = [
'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity'
'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands', 'pytree_map',
'detect_reshape_mapping', 'check_keep_sharding_status', 'infer_output_dim_partition_dict'
"BroadcastType",
"get_broadcast_shape",
"is_broadcastable",
"recover_sharding_spec_for_broadcast_shape",
"generate_resharding_costs",
"generate_sharding_spec",
"ignore_sharding_exception",
"check_sharding_spec_validity" "transpose_partition_dim",
"update_partition_dim",
"enumerate_all_possible_1d_sharding",
"enumerate_all_possible_2d_sharding",
"generate_sharding_size",
"comm_actions_for_oprands",
"pytree_map",
"detect_reshape_mapping",
"check_keep_sharding_status",
"infer_output_dim_partition_dict",
]
......@@ -14,14 +14,17 @@ from colossalai.tensor.comm_spec import CollectiveCommPattern, CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = [
'BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape',
'comm_actions_for_oprands'
"BroadcastType",
"is_broadcastable",
"get_broadcast_shape",
"recover_sharding_spec_for_broadcast_shape",
"comm_actions_for_oprands",
]
class BroadcastType(Enum):
EQUAL = auto()
PADDDING = auto()
PADDING = auto()
MULTIPLE = auto()
......@@ -41,7 +44,7 @@ def get_broadcast_shape(shape1: torch.Size, shape2: torch.Size) -> List[int]:
"""
Compute the broadcast shape given two shapes.
"""
assert is_broadcastable(shape1, shape2), f'{shape1} and {shape2} are not broadcastable'
assert is_broadcastable(shape1, shape2), f"{shape1} and {shape2} are not broadcastable"
shape1_reverse = shape1[::-1]
shape2_reverse = shape2[::-1]
min_common_dim = min(len(shape1), len(shape2))
......@@ -60,8 +63,9 @@ def get_broadcast_dim_info(logical_shape, physical_shape):
logical_num_dims = len(logical_shape)
physical_num_dims = len(physical_shape)
assert logical_num_dims >= physical_num_dims, \
'The number of dimensions in the logical shape is smaller than that of the physical shape, this tensor is not broadcast!'
assert (
logical_num_dims >= physical_num_dims
), "The number of dimensions in the logical shape is smaller than that of the physical shape, this tensor is not broadcast!"
# track the dim and its broadcasting type
logical_dim_broadcast_info = {}
......@@ -69,24 +73,25 @@ def get_broadcast_dim_info(logical_shape, physical_shape):
for i in range(logical_num_dims):
# get the trailing dim size
logical_dim_idx = logical_num_dims - i - 1
phyiscal_dim_idx = physical_num_dims - i - 1
physical_dim_idx = physical_num_dims - i - 1
logical_dim_size = logical_shape[logical_dim_idx]
if phyiscal_dim_idx >= 0:
physical_dim_size = physical_shape[phyiscal_dim_idx]
if physical_dim_idx >= 0:
physical_dim_size = physical_shape[physical_dim_idx]
if physical_dim_size == logical_dim_size:
logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.EQUAL
elif physical_dim_size == 1 and physical_dim_size != logical_dim_size:
logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.MULTIPLE
else:
logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.PADDDING
logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.PADDING
return logical_dim_broadcast_info
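A worked example of the classification above with hypothetical shapes: trailing dims that match are EQUAL, physical dims of size 1 that were expanded are MULTIPLE, and logical dims with no physical counterpart are PADDING. The helper below is a plain-Python restatement of that loop, not the library function itself:
def broadcast_dim_info(logical_shape, physical_shape):
    info = {}
    for i in range(len(logical_shape)):
        logical_idx = len(logical_shape) - i - 1
        physical_idx = len(physical_shape) - i - 1
        if physical_idx < 0:
            info[logical_idx] = "PADDING"       # dim only exists after broadcasting
        elif physical_shape[physical_idx] == logical_shape[logical_idx]:
            info[logical_idx] = "EQUAL"
        else:
            info[logical_idx] = "MULTIPLE"      # size-1 physical dim expanded
    return info
# e.g. a (1, 8) operand broadcast against a (4, 3, 8) output
print(broadcast_dim_info((4, 3, 8), (1, 8)))    # {2: 'EQUAL', 1: 'MULTIPLE', 0: 'PADDING'}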
def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size,
physical_shape: torch.Size) -> ShardingSpec:
def recover_sharding_spec_for_broadcast_shape(
logical_sharding_spec: ShardingSpec, logical_shape: torch.Size, physical_shape: torch.Size
) -> ShardingSpec:
"""
This function computes the sharding spec for the physical shape of a broadcast tensor.
......@@ -117,22 +122,25 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe
for shape_dim, mesh_dim in logical_dim_partition.items():
logical_broadcast_type = logical_dim_broadcast_info[shape_dim]
if logical_broadcast_type == BroadcastType.PADDDING or logical_broadcast_type == BroadcastType.MULTIPLE:
if logical_broadcast_type == BroadcastType.PADDING or logical_broadcast_type == BroadcastType.MULTIPLE:
removed_dims.extend(mesh_dim)
else:
# get the corresponding physical dim
physical_dim = physical_num_dims - (logical_num_dims - shape_dim)
physical_dim_partition[physical_dim] = mesh_dim
physical_sharding_spec = ShardingSpec(device_mesh=logical_sharding_spec.device_mesh,
entire_shape=physical_shape,
dim_partition_dict=physical_dim_partition)
physical_sharding_spec = ShardingSpec(
device_mesh=logical_sharding_spec.device_mesh,
entire_shape=physical_shape,
dim_partition_dict=physical_dim_partition,
)
return physical_sharding_spec, removed_dims
def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: OperationData,
sharding_spec: ShardingSpec) -> CommAction:
def comm_actions_for_oprands(
node: Node, removed_dims: List[int], op_data: OperationData, sharding_spec: ShardingSpec
) -> CommAction:
"""
This method is used to generate communication actions for operands which lose information
when their logical shape is converted to the physical shape.
......@@ -140,9 +148,11 @@ def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: Opera
if len(removed_dims) == 1:
# if list length is 1, extract element from list to avoid using flatten device mesh
removed_dims = removed_dims[0]
comm_spec = CommSpec(comm_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
sharding_spec=sharding_spec,
logical_process_axis=removed_dims)
comm_spec = CommSpec(
comm_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
sharding_spec=sharding_spec,
logical_process_axis=removed_dims,
)
if op_data.type == OperationDataType.PARAM:
comm_type = CommType.HOOK
else:
......@@ -151,7 +161,7 @@ def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: Opera
for index, arg in enumerate(node.args):
if op_data.name == str(arg):
arg_index = index
assert arg_index >= 0, f'op_data should be an argument of node.'
assert arg_index >= 0, f"op_data should be an argument of node."
comm_action = CommAction(
comm_spec=comm_spec,
comm_type=comm_type,
......
......@@ -14,11 +14,12 @@ from colossalai.tensor.sharding_spec import ShardingSpec
from ..constants import INFINITY_COST
__all__ = ['generate_sharding_spec', 'generate_resharding_costs']
__all__ = ["generate_sharding_spec", "generate_resharding_costs"]
def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,
dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
def generate_sharding_spec(
input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh, dim_partition_dict: Dict[int, List[int]]
) -> ShardingSpec:
"""
Generate the sharding spec of the tensor based on the given dim_partition_dict.
......@@ -30,7 +31,7 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic
"""
if isinstance(input_, Node):
assert hasattr(input_, '_meta_data'), f'The given node has no attribte _meta_data'
assert hasattr(input_, "_meta_data"), f"The given node has no attribute _meta_data"
meta_tensor = input_._meta_data
assert meta_tensor is not None, "The given node's _meta_data attribute is None"
shape = meta_tensor.shape
......@@ -38,24 +39,27 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic
shape = input_.shape
else:
raise TypeError(
f'We cannot generate sharding spec for {type(input_)} type, only torch.fx.Node or torch.Tensor is expected.'
f"We cannot generate sharding spec for {type(input_)} type, only torch.fx.Node or torch.Tensor is expected."
)
for dim_index, sharding_index_list in dim_partition_dict.items():
sharding_list = [device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list]
sharding_size = reduce(operator.mul, sharding_list, 1)
assert shape[
dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.'
assert (
shape[dim_index] % sharding_size == 0
), f"we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions."
sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict)
return sharding_spec
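The divisibility assertion above requires every sharded dim to be a multiple of the product of the mesh axis sizes it is split over. A plain-Python sketch of that check with a hypothetical 2 x 4 mesh:
import operator
from functools import reduce
mesh_shape = (2, 4)                      # hypothetical device mesh
tensor_shape = (16, 10)
dim_partition_dict = {0: [0, 1]}         # shard dim 0 over both mesh axes
for dim_index, mesh_axes in dim_partition_dict.items():
    sharding_size = reduce(operator.mul, (mesh_shape[axis] for axis in mesh_axes), 1)
    assert tensor_shape[dim_index] % sharding_size == 0, (
        f"cannot shard dim {dim_index} of size {tensor_shape[dim_index]} "
        f"into {sharding_size} partitions"
    )
print(16 // 8)                           # each of the 8 shards holds 2 rows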
def generate_resharding_costs(nodes: List[Node],
sharding_specs: List[ShardingSpec],
count_backward: Optional[bool] = True,
dtype: Optional[torch.dtype] = None,
index=None):
'''
def generate_resharding_costs(
nodes: List[Node],
sharding_specs: List[ShardingSpec],
count_backward: Optional[bool] = True,
dtype: Optional[torch.dtype] = None,
index=None,
):
"""
Compute the resharding costs with this specific strategy.
Argument:
......@@ -63,7 +67,7 @@ def generate_resharding_costs(nodes: List[Node],
sharding_spec_for_input(ShardingSpec): a list of ShardingSpec for the nodes.
count_backward (Optional[bool]): whether to include the cost of resharding in the backward pass, default is True. False can be used for inference.
dtype (Optional[torch.dtype]): the data type for cost calculation, default is None.
'''
"""
# The resharding_cost of weight is counted due to sharing weight cases.
resharding_costs = {}
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
......@@ -76,38 +80,39 @@ def generate_resharding_costs(nodes: List[Node],
for strategy in input_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec
if not isinstance(input_sharding_spec, ShardingSpec):
assert isinstance(input_sharding_spec, list), 'only ShardingSpec or List[ShardingSpec] is expected.'
assert isinstance(input_sharding_spec, list), "only ShardingSpec or List[ShardingSpec] is expected."
input_sharding_spec = input_sharding_spec[index]
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
assert isinstance(input_sharding_spec, ShardingSpec), f"The input node should NOT be a tuple of tensor."
try:
# compute the resharding cost
_, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
input_sharding_spec, input_spec)
input_sharding_spec, input_spec
)
# we need multiply the size of elem dtype to get correct communication cost
resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes
except AssertionError as e:
warnings.warn(f'{e}')
warnings.warn(f"{e}")
resharding_cost = INFINITY_COST
resharding_costs[input_node].append(resharding_cost)
return resharding_costs
def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_length_threshold: int = 20):
'''
"""
Find the largest repeat blocks in the graph, whose length is larger than the threshold.
Args:
gm (GraphModule): the graph module to be analyzed.
common_length_threshold (int): the threshold of the repeat block length.
'''
"""
# graph = gm.graph
def _process_args(args):
new_args = []
for arg in args:
if hasattr(arg, '_meta_data'):
if hasattr(arg, "_meta_data"):
meta_data = arg._meta_data
else:
meta_data = arg
......@@ -145,7 +150,7 @@ def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_lengt
return False
for index, node in enumerate(node_list):
if node.op == 'call_module':
if node.op == "call_module":
target = node.target
submod = root_module.get_submodule(target)
submod_type = type(submod)
......@@ -155,12 +160,12 @@ def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_lengt
new_args = _process_args(node.args)
if node.op != 'get_attr':
if node.op != "get_attr":
hash_key = (node.op, target, *new_args)
else:
hash_key = (node.op,)
setattr(node, 'hash_key', hash_key)
setattr(node, "hash_key", hash_key)
hash_value_to_node_dict = {}
......@@ -179,7 +184,7 @@ def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_lengt
# the comparison will be triggered if a common node appears
if len(hash_value_to_node_dict[hash(node.hash_key)]) >= 2:
start_index_list = hash_value_to_node_dict[hash(node.hash_key)]
check_block_list = [node_list[start:start + max_common_length] for start in start_index_list]
check_block_list = [node_list[start : start + max_common_length] for start in start_index_list]
common_label = True
if not _all_equal(check_block_list, _check_node_list_equal):
......@@ -201,6 +206,6 @@ def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_lengt
# recover common subgraph from the index
common_blocks = []
for start in common_blocks_index:
common_blocks.append(node_list[start:start + max_common_length])
common_blocks.append(node_list[start : start + max_common_length])
return common_blocks
import functools
from typing import Any, Callable, Dict, List, Tuple, Type, Union
from typing import Any, Callable, Tuple, Type, Union
import torch
from colossalai.logging import get_dist_logger
from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException
__all__ = ['ignore_sharding_exception', 'pytree_map']
__all__ = ["ignore_sharding_exception", "pytree_map"]
def ignore_sharding_exception(func):
......@@ -46,31 +46,34 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens
# make sure all dims are covered in sharding spec
sharding_len = len(sharding_spec.sharding_sequence)
tensor_num_dim = tensor.dim()
num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0]
num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1]
assert sharding_len == tensor_num_dim, \
f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).'
num_devices_in_col = sharding_spec.device_mesh.shape[0]
num_devices_in_row = sharding_spec.device_mesh.shape[1]
assert (
sharding_len == tensor_num_dim
), f"The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape})."
# make sure the sharding is valid for each dim
for i in range(tensor_num_dim):
dim_size = tensor.shape[i]
dim_spec = sharding_spec.sharding_sequence[i]
if str(dim_spec).startswith('S'):
devices_str = str(dim_spec).lstrip('S')
if str(dim_spec).startswith("S"):
devices_str = str(dim_spec).lstrip("S")
num_devices = 1
if '0' in devices_str:
if "0" in devices_str:
num_devices *= num_devices_in_col
if '1' in devices_str:
if "1" in devices_str:
num_devices *= num_devices_in_row
assert dim_size >= num_devices and dim_size % num_devices == 0, \
f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.'
assert (
dim_size >= num_devices and dim_size % num_devices == 0
), f"The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices."
# make sure the entire shape matches the physical tensor shape
assert sharding_spec.entire_shape == tensor.shape, \
f'The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}'
assert (
sharding_spec.entire_shape == tensor.shape
), f"The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}"
def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
......
......@@ -6,12 +6,13 @@ import torch
class PreviousStatus(Enum):
"""
This class shows the status of previous comparision.
This class shows the status of previous comparison.
"""
RESET = 0
# ORIGIN means the dimension size of original tensor is larger in the previous comparision.
# ORIGIN means the dimension size of original tensor is larger in the previous comparison.
ORIGIN = 1
# TGT means the dimension size of target tensor is larger in the previous comparision.
# TGT means the dimension size of target tensor is larger in the previous comparison.
TGT = 2
......@@ -91,7 +92,7 @@ def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> D
tgt_index += 1
if previous_label == PreviousStatus.TGT:
# if the target dimension size is larger in the previous comparision, which means
# if the target dimension size is larger in the previous comparison, which means
# the origin dimension size has already accumulated larger than target dimension size, so
# we need to offload the origin dims and tgt dims into the reshape_mapping_dict.
reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims)
......@@ -111,7 +112,7 @@ def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> D
origin_index += 1
if previous_label == PreviousStatus.ORIGIN:
# if the origin element is larger in the previous comparision, which means
# if the origin element is larger in the previous comparison, which means
# the target element has already accumulated larger than origin element, so
# we need to offload the origin dims and tgt dims into the reshape_mapping_dict.
reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims)
......@@ -130,8 +131,9 @@ def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> D
return reshape_mapping_dict
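Conceptually, reshaping (2, 3, 4) into (6, 4) groups origin dims 0 and 1 with target dim 0 and origin dim 2 with target dim 1, i.e. {(0, 1): (0,), (2,): (1,)}. The helper below is a simplified, hypothetical restatement of that dim-grouping idea (walk both shapes and close a group whenever the accumulated sizes match); it is not the library routine and assumes the two shapes have equal element counts:
def group_reshape_dims(origin_shape, tgt_shape):
    mapping, o_dims, t_dims = {}, [], []
    o_acc = t_acc = 1
    o = t = 0
    while o < len(origin_shape) or t < len(tgt_shape):
        if o_acc <= t_acc and o < len(origin_shape):
            o_acc *= origin_shape[o]
            o_dims.append(o)
            o += 1
        elif t < len(tgt_shape):
            t_acc *= tgt_shape[t]
            t_dims.append(t)
            t += 1
        if o_acc == t_acc and o_dims and t_dims:
            mapping[tuple(o_dims)] = tuple(t_dims)
            o_dims, t_dims = [], []
            o_acc = t_acc = 1
    return mapping
print(group_reshape_dims((2, 3, 4), (6, 4)))   # {(0, 1): (0,), (2,): (1,)}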
def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]],
reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> bool:
def check_keep_sharding_status(
input_dim_partition_dict: Dict[int, List[int]], reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]
) -> bool:
"""
This method is used to check whether the reshape operation could implement without converting
the input to fully replicated status.
......@@ -139,7 +141,7 @@ def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]],
Rule:
For a sharded dimension of input tensor, if it is not the minimum element of the input tuple,
the function will return false.
To illustrate this issue, there are two cases to analyse:
To illustrate this issue, there are two cases to analyze:
1. no sharded dims in the input tuple: we could do the reshape operation safely just as the normal
operation without distributed tensor.
2. sharded dims in the input tuple: the sharded dim must be the minimum element, then during shape
......@@ -172,14 +174,16 @@ def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]],
return True
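A concrete check of the rule above, using the (2, 3, 4) -> (6, 4) grouping from the previous example: sharding input dim 0 (the minimum of its tuple (0, 1)) keeps the sharding, while sharding input dim 1 forces a fall back to fully replicated. A small sketch of the rule, not the library function:
def keep_sharding_possible(input_dim_partition_dict, reshape_mapping_dict):
    # A sharded input dim must be the smallest dim index of the input tuple it belongs to.
    for input_dims in reshape_mapping_dict:
        for sharded_dim in input_dim_partition_dict:
            if sharded_dim in input_dims and sharded_dim != min(input_dims):
                return False
    return True
reshape_mapping_dict = {(0, 1): (0,), (2,): (1,)}
print(keep_sharding_possible({0: [0]}, reshape_mapping_dict))   # True
print(keep_sharding_possible({1: [0]}, reshape_mapping_dict))   # False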
def infer_output_dim_partition_dict(input_dim_partition_dict: Dict[int, List[int]],
reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> Dict[Tuple[int], Tuple[int]]:
def infer_output_dim_partition_dict(
input_dim_partition_dict: Dict[int, List[int]], reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]
) -> Dict[Tuple[int], Tuple[int]]:
"""
This method is used to infer the output dim partition dict for a reshape operation,
given the input dim partition dict and reshape mapping dict.
"""
assert check_keep_sharding_status(input_dim_partition_dict, reshape_mapping_dict), \
'we only infer output dim partition dict for the reshape operation could keep sharding spec.'
assert check_keep_sharding_status(
input_dim_partition_dict, reshape_mapping_dict
), "we only infer output dim partition dict for the reshape operation could keep sharding spec."
sharded_dims = list(input_dim_partition_dict.keys())
output_dim_partition_dict = {}
for input_dims, output_dims in reshape_mapping_dict.items():
......