Unverified Commit 8283e95d authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

[autoparallel] collated all deprecated files (#1700)

* [autoparallel] collated all deprecated files

* polish code
parent e2355d01
...@@ -18,174 +18,6 @@ class CostGraph: ...@@ -18,174 +18,6 @@ class CostGraph:
simplify(bool, optional): The generated cost graph will be simplified if it is true. (default to True) simplify(bool, optional): The generated cost graph will be simplified if it is true. (default to True)
''' '''
def __init__(self, leaf_strategies, simplify=True):
self.leaf_strategies = leaf_strategies
self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
# stores number of strategies in each node
self.node_lens = {strategies_vector.node: len(strategies_vector) for strategies_vector in self.leaf_strategies}
# extra_node_costs will store the extra costs introduced by merging nodes
self.extra_node_costs = {}
self.following_dict = {}
self.simplify = simplify
self._build_cost_graph()
def _remove_invalid_node(self, node, attr_name):
remove_list = []
target_node_list = getattr(node, attr_name, [])
for target_node in target_node_list:
if target_node not in self.nodes:
remove_list.append(target_node)
for element in remove_list:
target_node_list.remove(element)
def _build_cost_graph(self):
'''
This method will generate edge_cost for adjacent node pair. Additionally, 'parents' and 'children' attribute will be
set to node.
'''
self.edge_costs = {}
if self.simplify:
self.merge_pair = []
for strategies_vector in self.leaf_strategies:
# build edge_cost
dst_node = strategies_vector.node
for src_node in strategies_vector.predecessor_nodes:
if src_node not in self.nodes:
continue
node_pair = (src_node, dst_node)
# src_index = strategies_vector.predecessor_nodes.index(src_node)
edge_cost = {}
for i in range(len(strategies_vector)):
for j in range(len(src_node.strategies_vector)):
edge_cost[(j, i)] = strategies_vector[i].resharding_costs[src_node][j]
self.edge_costs[node_pair] = edge_cost
# add parents and children attribute to node
setattr(dst_node, 'parents', strategies_vector.predecessor_nodes)
setattr(dst_node, 'children', strategies_vector.successor_nodes)
self._remove_invalid_node(dst_node, 'parents')
self._remove_invalid_node(dst_node, 'children')
if self.simplify and strategies_vector.check_merge():
for followed_node in strategies_vector.predecessor_nodes:
self.merge_pair.append((followed_node, dst_node))
def get_edge_cost(self, src_node, dst_node):
    '''Return the linearized edge-cost table recorded for the (src, dst) pair.'''
    node_pair = (src_node, dst_node)
    return self.edge_costs[node_pair]
def merge_node(self, src_node, dst_node):
    '''
    Merge dst_node into src_node (dst_node is removed and afterwards
    represented by src_node via ``self.following_dict``). Steps:
    1. For each strategy in dst_node, pick an appropriate strategy of
       src_node to merge; the logical resharding costs between the parents
       of src_node and the merged node depend on this dispatch. E.g. for
       the graph 0->1->2, after merging node 1 into node 2,
       edge_costs[(node 0, node 2)][(0, 0)] = edge_costs[(node 0, node 1)][(0, x)]
       where x is the strategy of node 1 picked for node 2's strategy 0.
    2. Accumulate the extra costs introduced by merging: the resharding cost
       between the src_node strategy and the picked dst_node strategy, plus
       any extra cost already attached to that dst_node strategy.
    3. Build connections between the new node pairs and detach dst_node once
       all its consumers are rewired.

    Argument:
        src_node(Node): the node that remains and absorbs dst_node.
        dst_node(Node): the node merged into src_node and removed.
    '''
    # doubles as a sanity check: raises ValueError if src_node is not a parent
    src_node_index = dst_node.parents.index(src_node)
    # merge_map[src strategy index] -> dst strategy index with lowest resharding cost
    merge_map = {}
    for src_index, _strategy in enumerate(src_node.strategies_vector):
        min_cost = INFINITY_COST
        lowest_cost_index = -1
        for dst_index, dst_strategy in enumerate(dst_node.strategies_vector):
            resharding_cost = dst_strategy.resharding_costs[src_node][src_index]
            if resharding_cost <= min_cost:
                min_cost = resharding_cost
                lowest_cost_index = dst_index
        merge_map[src_index] = lowest_cost_index
    # extra_node_cost for src node: resharding into the picked dst strategy plus
    # any extra cost the dst strategy already carried from earlier merges
    self.extra_node_costs[src_node] = [0.0] * self.node_lens[src_node]
    for src_index, _strategy in enumerate(src_node.strategies_vector):
        target_strategy_index = merge_map[src_index]
        target_strategy = dst_node.strategies_vector[target_strategy_index]
        self.extra_node_costs[src_node][src_index] += target_strategy.resharding_costs[src_node][src_index]
        if dst_node in self.extra_node_costs:
            self.extra_node_costs[src_node][src_index] += self.extra_node_costs[dst_node][target_strategy_index]
    # add new (src_node -> child) edges replacing the (dst_node -> child) edges
    for child_node in dst_node.children:
        new_node_pair = (src_node, child_node)
        old_node_pair = (dst_node, child_node)
        # NOTE(review): this early skip makes the accumulation branch below
        # unreachable (new_node_pair can never be in self.edge_costs there).
        # Preserved as-is to keep behaviour unchanged — confirm intent upstream.
        if new_node_pair in self.edge_costs:
            continue
        edge_cost = {}
        for i in range(self.node_lens[src_node]):
            for j in range(self.node_lens[child_node]):
                picked_dst_index = merge_map[i]
                edge_cost[(i, j)] = self.edge_costs[old_node_pair][(picked_dst_index, j)]
        if new_node_pair not in self.edge_costs:
            self.edge_costs[new_node_pair] = edge_cost
        else:
            # accumulate the resharding costs if the child consumes both
            # src_node and dst_node.
            # BUGFIX: iterate .items() — iterating the dict directly yields only
            # the (i, j) key tuples, which were wrongly unpacked into two ints.
            for index_pair, extra_cost in edge_cost.items():
                self.edge_costs[new_node_pair][index_pair] += extra_cost
    # detach dst_node from src_node and rewire src_node to dst_node's children
    dst_node.parents.remove(src_node)
    src_node.children.remove(dst_node)
    self.edge_costs.pop((src_node, dst_node))
    for child_node in dst_node.children:
        if child_node not in src_node.children:
            src_node.children.append(child_node)
        if src_node not in child_node.parents:
            child_node.parents.append(src_node)
        # remove dst_node from the cost graph once it has no producer left
        if len(dst_node.parents) == 0:
            child_node.parents.remove(dst_node)
            node_pair = (dst_node, child_node)
            self.edge_costs.pop(node_pair)
    if len(dst_node.parents) == 0:
        # dst_node is fully absorbed: record who represents it from now on
        self.following_dict[dst_node] = src_node
        dst_node.children = []
def _reindexing_src(self, src):
if src not in self.following_dict:
return src
return self._reindexing_src(self.following_dict[src])
def simplify_graph(self):
    '''
    Collapse all recorded merge pairs into the graph (no-op unless this graph
    was built with ``simplify=True``), then flatten ``following_dict`` so every
    merged node points straight to its terminal representative.
    '''
    if not self.simplify:
        return
    # pairs were recorded in traversal order; merge them back-to-front,
    # restoring the original order afterwards
    self.merge_pair.reverse()
    for followed, following in self.merge_pair:
        self.merge_node(followed, following)
    self.merge_pair.reverse()
    self.following_dict = {
        merged: self._reindexing_src(anchor)
        for merged, anchor in self.following_dict.items()
    }
class CostGraph_V2:
'''
A graph data structure to simplify the edge cost graph. It has two main functions:
1. To feed the quadratic resharding costs into solver, we need to linearize it. We build edge_cost in
CostGraph, and it stored every combinations of strategies for a src-dst node pair in an 1D list.
2. To reduce the searching space, we merge computationally-trivial operators, such as
element-wise operators, transpose, and reduction, into their following nodes. The merging information will element-wise operators, transpose, and reduction, into their following nodes. The merging information will
be given by the StrategiesVector depending on the type of target node and following nodes.
Argument:
leaf_strategies(List[StrategiesVector]): It stores StrategiesVector of every nodes on the graph.
simplify(bool, optional): The generated cost graph will be simplified if it is true. (default to True)
'''
def __init__(self, leaf_strategies, simplify=True, forward_only=False): def __init__(self, leaf_strategies, simplify=True, forward_only=False):
self.leaf_strategies = leaf_strategies self.leaf_strategies = leaf_strategies
self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies] self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
......
from .dot_handler import LinearFunctionHandler, LinearModuleHandler
from .layer_norm_handler import LayerNormModuleHandler
from .batch_norm_handler import BatchNormModuleHandler
from .conv_handler import ConvModuleHandler, ConvFunctionHandler
from .where_handler import WhereHandler
from .unary_elementwise_handler import UnaryElementwiseHandler
from .reshape_handler import ReshapeHandler
from .placeholder_handler import PlacehodlerHandler
from .output_handler import OuputHandler
from .normal_pooling_handler import NormPoolingHandler
# Public API of the node-handler package.
# NOTE(review): 'PlacehodlerHandler' and 'OuputHandler' are misspelled, but the
# spellings here match the actual class names imported above — keep them in
# sync if the classes are ever renamed upstream.
__all__ = [
    'LinearFunctionHandler', 'LinearModuleHandler', 'LayerNormModuleHandler', 'BatchNormModuleHandler',
    'ConvModuleHandler', 'ConvFunctionHandler', 'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler',
    'OuputHandler', 'WhereHandler', 'NormPoolingHandler'
]
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from .node_handler import ModuleHandler, NodeHandler from .node_handler import ModuleHandler, NodeHandler
from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData
from ..strategy import BatchNormStrategyGenerator, StrategyGenerator_V2 from ..strategy import BatchNormStrategyGenerator, StrategyGenerator
from typing import List, Dict from typing import List, Dict
from .registry import operator_registry from .registry import operator_registry
...@@ -17,7 +17,7 @@ class BatchNormModuleHandler(ModuleHandler): ...@@ -17,7 +17,7 @@ class BatchNormModuleHandler(ModuleHandler):
A BatchNormModuleHandler which deals with the sharding strategies for nn.BatchNormXd module. A BatchNormModuleHandler which deals with the sharding strategies for nn.BatchNormXd module.
""" """
def get_strategy_generator(self) -> List[StrategyGenerator_V2]: def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping() op_data_mapping = self.get_operation_data_mapping()
generators = [] generators = []
generators.append(BatchNormStrategyGenerator(op_data_mapping, self.device_mesh)) generators.append(BatchNormStrategyGenerator(op_data_mapping, self.device_mesh))
......
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from .node_handler import ModuleHandler, NodeHandler from .node_handler import ModuleHandler, NodeHandler
from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData
from ..strategy import ConvStrategyGenerator, StrategyGenerator_V2 from ..strategy import ConvStrategyGenerator, StrategyGenerator
from typing import List, Dict from typing import List, Dict
from .registry import operator_registry from .registry import operator_registry
...@@ -17,7 +17,7 @@ class ConvModuleHandler(ModuleHandler): ...@@ -17,7 +17,7 @@ class ConvModuleHandler(ModuleHandler):
A ConvModuleHandler which deals with the sharding strategies for nn.Convxd module. A ConvModuleHandler which deals with the sharding strategies for nn.Convxd module.
""" """
def get_strategy_generator(self) -> List[StrategyGenerator_V2]: def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping() op_data_mapping = self.get_operation_data_mapping()
generators = [] generators = []
generators.append(ConvStrategyGenerator(op_data_mapping, self.device_mesh)) generators.append(ConvStrategyGenerator(op_data_mapping, self.device_mesh))
...@@ -47,7 +47,7 @@ class ConvModuleHandler(ModuleHandler): ...@@ -47,7 +47,7 @@ class ConvModuleHandler(ModuleHandler):
mapping['bias'] = physical_bias_operand mapping['bias'] = physical_bias_operand
return mapping return mapping
def post_process(self, strategy: ShardingStrategy_V2): def post_process(self, strategy: ShardingStrategy):
""" """
Convert the sharding spec of the weight parameter back to its original shape. Convert the sharding spec of the weight parameter back to its original shape.
""" """
...@@ -78,7 +78,7 @@ class ConvFunctionHandler(NodeHandler): ...@@ -78,7 +78,7 @@ class ConvFunctionHandler(NodeHandler):
A ConvFunctionHandler which deals with the sharding strategies for nn.functional.ConvXd functions. A ConvFunctionHandler which deals with the sharding strategies for nn.functional.ConvXd functions.
""" """
def get_strategy_generator(self) -> List[StrategyGenerator_V2]: def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping() op_data_mapping = self.get_operation_data_mapping()
generators = [] generators = []
generators.append(ConvStrategyGenerator(op_data_mapping, self.device_mesh)) generators.append(ConvStrategyGenerator(op_data_mapping, self.device_mesh))
...@@ -120,7 +120,7 @@ class ConvFunctionHandler(NodeHandler): ...@@ -120,7 +120,7 @@ class ConvFunctionHandler(NodeHandler):
mapping['bias'] = physical_bias_operand mapping['bias'] = physical_bias_operand
return mapping return mapping
def post_process(self, strategy: ShardingStrategy_V2): def post_process(self, strategy: ShardingStrategy):
""" """
Convert the sharding spec of the weight parameter back to its original shape. Convert the sharding spec of the weight parameter back to its original shape.
""" """
......
...@@ -2,8 +2,8 @@ import torch ...@@ -2,8 +2,8 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from colossalai.tensor.sharding_spec import ShardingException from colossalai.tensor.sharding_spec import ShardingException
from .node_handler import ModuleHandler, NodeHandler from .node_handler import ModuleHandler, NodeHandler
from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData
from ..strategy import LinearProjectionStrategyGenerator, StrategyGenerator_V2, BatchedMatMulStrategyGenerator from ..strategy import LinearProjectionStrategyGenerator, StrategyGenerator, BatchedMatMulStrategyGenerator
from typing import List, Dict, Union from typing import List, Dict, Union
from .registry import operator_registry from .registry import operator_registry
from copy import deepcopy from copy import deepcopy
...@@ -18,7 +18,7 @@ class LinearModuleHandler(ModuleHandler): ...@@ -18,7 +18,7 @@ class LinearModuleHandler(ModuleHandler):
A LinearModuleHandler which deals with the sharding strategies for nn.Linear module. A LinearModuleHandler which deals with the sharding strategies for nn.Linear module.
""" """
def get_strategy_generator(self) -> List[StrategyGenerator_V2]: def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping() op_data_mapping = self.get_operation_data_mapping()
generators = [] generators = []
generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh)) generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh))
...@@ -53,7 +53,7 @@ class LinearModuleHandler(ModuleHandler): ...@@ -53,7 +53,7 @@ class LinearModuleHandler(ModuleHandler):
mapping['bias'] = physical_bias_operand mapping['bias'] = physical_bias_operand
return mapping return mapping
def post_process(self, strategy: ShardingStrategy_V2) -> Union[ShardingStrategy_V2, List[ShardingStrategy_V2]]: def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
""" """
Convert the sharding spec from the logical shape to the physical shape. Convert the sharding spec from the logical shape to the physical shape.
""" """
...@@ -101,7 +101,7 @@ class LinearFunctionHandler(NodeHandler): ...@@ -101,7 +101,7 @@ class LinearFunctionHandler(NodeHandler):
A LinearModuleHandler which deals with the sharding strategies for nn.Linear module. A LinearModuleHandler which deals with the sharding strategies for nn.Linear module.
""" """
def get_strategy_generator(self) -> List[StrategyGenerator_V2]: def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping() op_data_mapping = self.get_operation_data_mapping()
generators = [] generators = []
generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh)) generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh))
...@@ -140,7 +140,7 @@ class LinearFunctionHandler(NodeHandler): ...@@ -140,7 +140,7 @@ class LinearFunctionHandler(NodeHandler):
mapping['bias'] = physical_bias_operand mapping['bias'] = physical_bias_operand
return mapping return mapping
def post_process(self, strategy: ShardingStrategy_V2): def post_process(self, strategy: ShardingStrategy):
""" """
Convert the sharding spec of the weight parameter back to its original shape. Convert the sharding spec of the weight parameter back to its original shape.
""" """
...@@ -200,7 +200,7 @@ class BMMFunctionHandler(NodeHandler): ...@@ -200,7 +200,7 @@ class BMMFunctionHandler(NodeHandler):
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
return mapping return mapping
def get_strategy_generator(self) -> List[StrategyGenerator_V2]: def get_strategy_generator(self) -> List[StrategyGenerator]:
generators = [] generators = []
op_data_mapping = self.get_operation_data_mapping() op_data_mapping = self.get_operation_data_mapping()
generators = [] generators = []
......
import torch import torch
from .node_handler import NodeHandler from .node_handler import NodeHandler
from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData, StrategiesVector from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData, StrategiesVector
from ..strategy import TensorStrategyGenerator, TensorTupleStrategyGenerator, StrategyGenerator_V2 from ..strategy import TensorStrategyGenerator, TensorTupleStrategyGenerator, StrategyGenerator
from typing import List, Dict from typing import List, Dict
from .registry import operator_registry from .registry import operator_registry
import operator import operator
...@@ -15,7 +15,7 @@ class GetItemHandler(NodeHandler): ...@@ -15,7 +15,7 @@ class GetItemHandler(NodeHandler):
A GetItemHandler which deals with the sharding strategies for operator.getitem. A GetItemHandler which deals with the sharding strategies for operator.getitem.
""" """
def get_strategy_generator(self) -> List[StrategyGenerator_V2]: def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping() op_data_mapping = self.get_operation_data_mapping()
generators = [] generators = []
if isinstance(op_data_mapping["input"].data, torch.Tensor): if isinstance(op_data_mapping["input"].data, torch.Tensor):
......
import torch import torch
from .node_handler import ModuleHandler from .node_handler import ModuleHandler
from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData
from ..strategy import LayerNormGenerator, StrategyGenerator_V2 from ..strategy import LayerNormGenerator, StrategyGenerator
from typing import List, Dict from typing import List, Dict
from .registry import operator_registry from .registry import operator_registry
...@@ -14,7 +14,7 @@ class LayerNormModuleHandler(ModuleHandler): ...@@ -14,7 +14,7 @@ class LayerNormModuleHandler(ModuleHandler):
A LayerNormModuleHandler which deals with the sharding strategies for nn.LayerNorm module. A LayerNormModuleHandler which deals with the sharding strategies for nn.LayerNorm module.
""" """
def get_strategy_generator(self) -> List[StrategyGenerator_V2]: def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping() op_data_mapping = self.get_operation_data_mapping()
generators = [] generators = []
generators.append(LayerNormGenerator(op_data_mapping, self.device_mesh)) generators.append(LayerNormGenerator(op_data_mapping, self.device_mesh))
......
...@@ -3,8 +3,8 @@ from torch.fx.node import Node ...@@ -3,8 +3,8 @@ from torch.fx.node import Node
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from typing import Dict, List, Union from typing import Dict, List, Union
from ..sharding_strategy import ShardingStrategy_V2, StrategiesVector, OperationData, TrainCycleItem from ..sharding_strategy import ShardingStrategy, StrategiesVector, OperationData, TrainCycleItem
from ..strategy import StrategyGenerator_V2 from ..strategy import StrategyGenerator
from .._utils import generate_resharding_costs from .._utils import generate_resharding_costs
...@@ -30,7 +30,7 @@ class NodeHandler(ABC): ...@@ -30,7 +30,7 @@ class NodeHandler(ABC):
self.device_mesh = device_mesh self.device_mesh = device_mesh
self.strategies_vector = strategies_vector self.strategies_vector = strategies_vector
def update_resharding_cost(self, strategy: ShardingStrategy_V2) -> None: def update_resharding_cost(self, strategy: ShardingStrategy) -> None:
""" """
Compute the resharding costs and save the costs in the ShardingStrategy object. Compute the resharding costs and save the costs in the ShardingStrategy object.
""" """
...@@ -97,13 +97,13 @@ class NodeHandler(ABC): ...@@ -97,13 +97,13 @@ class NodeHandler(ABC):
return self.strategies_vector return self.strategies_vector
def post_process(self, strategy: ShardingStrategy_V2) -> Union[ShardingStrategy_V2, List[ShardingStrategy_V2]]: def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
# tranform the strategy generated # tranform the strategy generated
# e.g. to process the sharding strategy for the transposed weights # e.g. to process the sharding strategy for the transposed weights
return strategy return strategy
@abstractmethod @abstractmethod
def get_strategy_generator(self) -> List[StrategyGenerator_V2]: def get_strategy_generator(self) -> List[StrategyGenerator]:
""" """
Define which generators should be used by this NodeHandler object. Define which generators should be used by this NodeHandler object.
""" """
......
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from .node_handler import ModuleHandler, NodeHandler from .node_handler import ModuleHandler, NodeHandler
from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData
from ..strategy import NormalPoolStrategyGenerator, StrategyGenerator_V2 from ..strategy import NormalPoolStrategyGenerator, StrategyGenerator
from typing import List, Dict from typing import List, Dict
from .registry import operator_registry from .registry import operator_registry
__all__ = ['LinearModuleHandler', 'LinearFunctionHandler'] __all__ = ['NormPoolingHandler']
@operator_registry.register(torch.nn.MaxPool1d) @operator_registry.register(torch.nn.MaxPool1d)
...@@ -20,7 +20,7 @@ class NormPoolingHandler(ModuleHandler): ...@@ -20,7 +20,7 @@ class NormPoolingHandler(ModuleHandler):
A NormPoolingHandler which deals with the sharding strategies for nn.MaxPoolxd module. A NormPoolingHandler which deals with the sharding strategies for nn.MaxPoolxd module.
""" """
def get_strategy_generator(self) -> List[StrategyGenerator_V2]: def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping() op_data_mapping = self.get_operation_data_mapping()
generators = [] generators = []
generators.append(NormalPoolStrategyGenerator(op_data_mapping, self.device_mesh)) generators.append(NormalPoolStrategyGenerator(op_data_mapping, self.device_mesh))
......
import torch import torch
from .node_handler import NodeHandler from .node_handler import NodeHandler
from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData, StrategiesVector from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData, StrategiesVector
from colossalai.auto_parallel.solver.strategy import StrategyGenerator_V2 from colossalai.auto_parallel.solver.strategy import StrategyGenerator
from colossalai.auto_parallel.solver.strategy.output_generator import OutputGenerator from colossalai.auto_parallel.solver.strategy.output_generator import OutputGenerator
from typing import List, Dict from typing import List, Dict
from .registry import operator_registry from .registry import operator_registry
...@@ -14,7 +14,7 @@ class OuputHandler(NodeHandler): ...@@ -14,7 +14,7 @@ class OuputHandler(NodeHandler):
A OuputHandler which deals with the sharding strategies for Output Node. A OuputHandler which deals with the sharding strategies for Output Node.
""" """
def get_strategy_generator(self) -> List[StrategyGenerator_V2]: def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping() op_data_mapping = self.get_operation_data_mapping()
generators = [] generators = []
generators.append(OutputGenerator(op_data_mapping, self.device_mesh, self.predecessor_node)) generators.append(OutputGenerator(op_data_mapping, self.device_mesh, self.predecessor_node))
......
import torch import torch
from .node_handler import NodeHandler from .node_handler import NodeHandler
from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData
from colossalai.auto_parallel.solver.strategy import StrategyGenerator_V2 from colossalai.auto_parallel.solver.strategy import StrategyGenerator
from colossalai.auto_parallel.solver.strategy.placeholder_generator import PlaceholderGenerator from colossalai.auto_parallel.solver.strategy.placeholder_generator import PlaceholderGenerator
from typing import List, Dict from typing import List, Dict
from .registry import operator_registry from .registry import operator_registry
...@@ -14,7 +14,7 @@ class PlacehodlerHandler(NodeHandler): ...@@ -14,7 +14,7 @@ class PlacehodlerHandler(NodeHandler):
A PlacehodlerHandler which deals with the sharding strategies for Placeholder Node. A PlacehodlerHandler which deals with the sharding strategies for Placeholder Node.
""" """
def get_strategy_generator(self) -> List[StrategyGenerator_V2]: def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping() op_data_mapping = self.get_operation_data_mapping()
generators = [] generators = []
generators.append(PlaceholderGenerator(op_data_mapping, self.device_mesh)) generators.append(PlaceholderGenerator(op_data_mapping, self.device_mesh))
......
import torch import torch
from .node_handler import NodeHandler from .node_handler import NodeHandler
from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData, StrategiesVector from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData, StrategiesVector
from ..strategy import ReshapeGenerator, StrategyGenerator_V2 from ..strategy import ReshapeGenerator, StrategyGenerator
from typing import List, Dict from typing import List, Dict
from .registry import operator_registry from .registry import operator_registry
import operator import operator
__all__ = ['ReshapeHandler_V2'] __all__ = ['ReshapeHandler']
@operator_registry.register(torch.reshape) @operator_registry.register(torch.reshape)
@operator_registry.register(torch.flatten) @operator_registry.register(torch.flatten)
@operator_registry.register(torch.Tensor.permute) @operator_registry.register(torch.Tensor.permute)
class ReshapeHandler_V2(NodeHandler): class ReshapeHandler(NodeHandler):
""" """
A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape. A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
""" """
def get_strategy_generator(self) -> List[StrategyGenerator_V2]: def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping() op_data_mapping = self.get_operation_data_mapping()
generators = [] generators = []
generators.append(ReshapeGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) generators.append(ReshapeGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
......
import torch import torch
from .node_handler import NodeHandler from .node_handler import NodeHandler
from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData, StrategiesVector from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData, StrategiesVector
from ..strategy import UnaryElementwiseGenerator, StrategyGenerator_V2 from ..strategy import UnaryElementwiseGenerator, StrategyGenerator
from typing import List, Dict from typing import List, Dict
from .registry import operator_registry from .registry import operator_registry
import operator import operator
__all__ = ['UnaryElementwiseHandler_V2'] __all__ = ['UnaryElementwiseHandler']
@operator_registry.register(torch.abs) @operator_registry.register(torch.abs)
@operator_registry.register(torch.nn.ReLU) @operator_registry.register(torch.nn.ReLU)
class UnaryElementwiseHandler_V2(NodeHandler): class UnaryElementwiseHandler(NodeHandler):
""" """
A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op. A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op.
""" """
def get_strategy_generator(self) -> List[StrategyGenerator_V2]: def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping() op_data_mapping = self.get_operation_data_mapping()
generators = [] generators = []
generators.append(UnaryElementwiseGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) generators.append(UnaryElementwiseGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
......
import torch import torch
from .node_handler import NodeHandler from .node_handler import NodeHandler
from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData, StrategiesVector from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData, StrategiesVector
from ..strategy import WhereGenerator, StrategyGenerator_V2 from ..strategy import WhereGenerator, StrategyGenerator
from .broadcast import recover_sharding_spec_for_broadcast_shape from .broadcast import recover_sharding_spec_for_broadcast_shape
from typing import List, Dict from typing import List, Dict
from .registry import operator_registry from .registry import operator_registry
...@@ -17,7 +17,7 @@ class WhereHandler(NodeHandler): ...@@ -17,7 +17,7 @@ class WhereHandler(NodeHandler):
A WhereHandler which deals with the sharding strategies for torch.where. A WhereHandler which deals with the sharding strategies for torch.where.
""" """
def get_strategy_generator(self) -> List[StrategyGenerator_V2]: def get_strategy_generator(self) -> List[StrategyGenerator]:
logical_op_data_mapping, _ = self.get_operation_data_mapping() logical_op_data_mapping, _ = self.get_operation_data_mapping()
generators = [] generators = []
generators.append(WhereGenerator(logical_op_data_mapping, self.device_mesh)) generators.append(WhereGenerator(logical_op_data_mapping, self.device_mesh))
...@@ -73,7 +73,7 @@ class WhereHandler(NodeHandler): ...@@ -73,7 +73,7 @@ class WhereHandler(NodeHandler):
self.strategies_vector = list(strategies_vector) self.strategies_vector = list(strategies_vector)
return self.strategies_vector return self.strategies_vector
def post_process(self, strategy: ShardingStrategy_V2): def post_process(self, strategy: ShardingStrategy):
logical_op_data_mapping, physical_op_data_mapping = self.get_operation_data_mapping() logical_op_data_mapping, physical_op_data_mapping = self.get_operation_data_mapping()
for key in logical_op_data_mapping.keys(): for key in logical_op_data_mapping.keys():
logical_sharding_spec = strategy.sharding_specs[logical_op_data_mapping[key]] logical_sharding_spec = strategy.sharding_specs[logical_op_data_mapping[key]]
......
...@@ -13,37 +13,7 @@ from typing import Dict, List, Union, Tuple, Any ...@@ -13,37 +13,7 @@ from typing import Dict, List, Union, Tuple, Any
from torch.fx.node import Node from torch.fx.node import Node
from .constants import * from .constants import *
__all__ = ['ShardingStrategy', 'StrategiesVector'] __all__ = ['OperationDataType', 'OperationData', 'TrainCycleItem', 'MemoryCost', 'ShardingStrategy', 'StrategiesVector']
@dataclass
class ShardingStrategy:
'''
ShardingStrategy is a structure containing sharding strategies of inputs and output of this node
and costs information using in solver.
Argument:
name(str): express the sharding strategies in string, such as 'S0S1 = S0R x RS1'.
output_sharding_spec(ShardingSpec): ShardingSpec of the output node.
compute_cost(float): Computation cost to complete this strategy.(default to 0)
communication_cost(float): Communication cost to complete this strategy.(default to 0)
memory_cost(float): Memory cost of the output node using this strategy.(default to 0)
resharding_costs(Dict[int, List[float]]): resharding_cost[i][j] means the cost of i-th argument in the output node argument list
with j-th strategy in its strategies_vector transforms to sharding spec wanted in this
strategy.(default to None)
input_shardings(List(ShardingSpec)): The ShardingSpecs of the input nodes.
'''
name: str
# TODO: output of fx node,such as torch.var_mean, could be a tuple, so we cannot simply suppose it is a tensor.
output_sharding_spec: Union[ShardingSpec, Tuple[ShardingSpec]]
compute_cost: float = 0.
communication_cost: float = 0.
memory_cost: float = 0.
resharding_costs: Dict[Node, List[float]] = None
# sometimes the input node could be a tuple of nodes, but most of op won't accept tuple of node as input.
# Therefore, we could process them at the specific op(operator.getitem)
input_shardings: List[ShardingSpec] = None
class OperationDataType(Enum): class OperationDataType(Enum):
...@@ -111,7 +81,7 @@ class MemoryCost: ...@@ -111,7 +81,7 @@ class MemoryCost:
@dataclass @dataclass
class ShardingStrategy_V2: class ShardingStrategy:
""" """
ShardingStrategy is a dataclass to store the meta information on tensor sharding for a node. ShardingStrategy is a dataclass to store the meta information on tensor sharding for a node.
...@@ -178,7 +148,7 @@ class ShardingStrategy_V2: ...@@ -178,7 +148,7 @@ class ShardingStrategy_V2:
communication_cost = deepcopy(self.communication_cost) communication_cost = deepcopy(self.communication_cost)
memory_cost = deepcopy(self.memory_cost) memory_cost = deepcopy(self.memory_cost)
return ShardingStrategy_V2(name=self.name, return ShardingStrategy(name=self.name,
sharding_specs=sharding_specs, sharding_specs=sharding_specs,
compute_cost=compute_cost, compute_cost=compute_cost,
communication_cost=communication_cost, communication_cost=communication_cost,
......
from .strategy_generator import StrategyGenerator_V2 from .strategy_generator import StrategyGenerator
from .matmul_strategy_generator import DotProductStrategyGenerator, MatVecStrategyGenerator, LinearProjectionStrategyGenerator, BatchedMatMulStrategyGenerator from .matmul_strategy_generator import DotProductStrategyGenerator, MatVecStrategyGenerator, LinearProjectionStrategyGenerator, BatchedMatMulStrategyGenerator
from .conv_strategy_generator import ConvStrategyGenerator from .conv_strategy_generator import ConvStrategyGenerator
from .batch_norm_generator import BatchNormStrategyGenerator from .batch_norm_generator import BatchNormStrategyGenerator
...@@ -11,11 +11,10 @@ from .normal_pooling_generator import NormalPoolStrategyGenerator ...@@ -11,11 +11,10 @@ from .normal_pooling_generator import NormalPoolStrategyGenerator
from .placeholder_generator import PlaceholderGenerator from .placeholder_generator import PlaceholderGenerator
from .output_generator import OutputGenerator from .output_generator import OutputGenerator
__all__ = [ __all__ = [
'StrategyGenerator_V2', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator', 'StrategyGenerator', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator', 'LinearProjectionStrategyGenerator',
'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator', 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator', 'UnaryElementwiseGenerator',
'UnaryElementwiseGenerator', 'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', 'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator',
'TensorTupleStrategyGenerator', 'LayerNormGenerator', 'ReshapeGenerator', 'PlaceholderGenerator', 'OutputGenerator', 'LayerNormGenerator', 'ReshapeGenerator', 'PlaceholderGenerator', 'OutputGenerator', 'WhereGenerator',
'WhereGenerator', 'ReshapeGenerator', 'NormalPoolStrategyGenerator' 'ReshapeGenerator', 'NormalPoolStrategyGenerator'
] ]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment