from copy import deepcopy
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Tuple, Union
import torch
from torch.fx.node import Node
from colossalai.tensor.comm_spec import CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec
from .constants import (
BCAST_FUNC_OP,
ELEMENTWISE_FUNC_OP,
ELEMENTWISE_METHOD_OP,
ELEMENTWISE_MODULE_OP,
RESHAPE_FUNC_OP,
RESHAPE_METHOD_OP,
)
__all__ = ['OperationDataType', 'OperationData', 'TrainCycleItem', 'MemoryCost', 'ShardingStrategy', 'StrategiesVector']
class OperationDataType(Enum):
"""
An operation can come from the argument list of an operator or the parameter list of a module.
"""
INPUT = 0
ARG = 1
PARAM = 2
BUFFER = 3
OUTPUT = 4
@dataclass
class OperationData:
"""
OperationData is the data related to an operator; it can be an operand or the output.
Args:
name (str): the name of the operation-related data
type (OperationDataType): the type of the operation data
data (Any): the value for this data, usually it is a meta tensor.
logical_shape (Tuple[int]): the logical shape of the data, which can be different from its actual shape in memory.
"""
name: str
type: OperationDataType
data: Any
logical_shape: Tuple[int] = None
def __post_init__(self):
# if no logical shape is specified, use the data shape as the logical shape
if self.logical_shape is None:
def _infer_logical_shape(data: Any):
"""
This function is used to infer the logical shape of the data.
"""
if isinstance(data, torch.Tensor):
return data.shape
elif isinstance(data, torch.Size):
return None
elif isinstance(data, (tuple, list)):
data_type = type(data)
return data_type([_infer_logical_shape(d) for d in data])
else:
return None
self.logical_shape = _infer_logical_shape(self.data)
def __repr__(self) -> str:
return f'OperationData(name={self.name}, type={self.type})'
def __eq__(self, other) -> bool:
return other.name == self.name
def __hash__(self) -> int:
return hash(f'{self.name}')
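# A minimal usage sketch (illustrative, not part of the library): logical_shape
# is inferred from the data shape in __post_init__ when it is not given explicitly.
#
#   op_data = OperationData(name='input', type=OperationDataType.ARG,
#                           data=torch.empty(4, 8, device='meta'))
#   assert op_data.logical_shape == torch.Size([4, 8])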
@dataclass
class TrainCycleItem:
"""
TrainCycleItem is a dataclass to store the items which have different values for the forward and backward pass
in a training iteration.
Args:
fwd (Any): the item for the forward pass
bwd (Any): the item for the backward pass
total (Any): the sum of the forward and backward items
"""
fwd: Any
bwd: Any
total: Any
@dataclass
class MemoryCost:
"""
MemoryCost is a dataclass which stores the memory usage in the program.
Args:
activation (int): the memory cost incurred by the activations in bytes.
parameter (int): the memory cost incurred by the module parameter in bytes.
temp (int): the memory cost incurred by the temporary tensors in bytes.
buffer (int): the memory cost incurred by the module buffer in bytes.
"""
activation: int = 0
parameter: int = 0
temp: int = 0
buffer: int = 0
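# Illustrative sketch: a memory TrainCycleItem typically wraps one MemoryCost per
# phase, as the solver below does when it reads memory_cost_item.fwd / .total.
#
#   fwd_mem = MemoryCost(activation=1024, parameter=512)
#   bwd_mem = MemoryCost(activation=2048, parameter=512)
#   total_mem = MemoryCost(activation=3072, parameter=1024)
#   mem_item = TrainCycleItem(fwd=fwd_mem, bwd=bwd_mem, total=total_mem)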
class CommType(Enum):
"""
CommType describes the sequential order of a communication action and a computation action.
Meaning:
BEFORE: the communication action happens just before the computation operation.
AFTER: the communication action happens after the computation operation.
HOOK: the communication action is used to perform the gradient all-reduce.
IMPLICIT: the communication action happens during the kernel execution, such as in SyncBatchNorm.
"""
BEFORE = 0
AFTER = 1
HOOK = 2
IMPLICIT = 3
@dataclass
class CommAction:
"""
CommAction is used to record the communication action.
Args:
comm_spec: expresses the communication pattern and the process groups used to execute the communication action.
comm_type: describes the sequential order of a communication action and a computation action.
arg_index: records the position of the tensor that joins the communication. We cannot use the name of the node or op_data at runtime,
because the args of a node may be changed by graph transform passes.
"""
comm_spec: CommSpec = None
comm_type: CommType = None
arg_index: int = -1
key_for_kwarg: Any = None
@dataclass
class ShardingStrategy:
"""
ShardingStrategy is a dataclass to store the meta information on tensor sharding for a node.
Args:
name (str): expresses the sharding strategy in string form, such as 'S0S1 = S0R x RS1'.
sharding_specs (Dict[OperationData, Union[ShardingSpec, Tuple[ShardingSpec]]]): the ShardingSpec for each operation data of the node.
compute_cost (TrainCycleItem): computation cost to complete this strategy. (default to None)
communication_cost (TrainCycleItem): communication cost to complete this strategy. (default to None)
memory_cost (TrainCycleItem): memory cost of the output node using this strategy. (default to None)
communication_actions (Dict[OperationData, CommAction]): the communication actions required by this strategy. (default to None)
resharding_costs (Dict[Node, List[TrainCycleItem]]): the resharding costs from the strategies of each predecessor node. (default to None)
"""
name: str
sharding_specs: Dict[OperationData, Union[ShardingSpec, Tuple[ShardingSpec]]] = None
compute_cost: TrainCycleItem = None
communication_cost: TrainCycleItem = None
memory_cost: TrainCycleItem = None
communication_actions: Dict[OperationData, CommAction] = None
resharding_costs: Dict[Node, List[TrainCycleItem]] = None
@property
def input_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
specs = {}
specs.update(self._get_sharding_spec(OperationDataType.ARG))
specs.update(self._get_sharding_spec(OperationDataType.PARAM))
return specs
@property
def argument_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
return self._get_sharding_spec(OperationDataType.ARG)
@property
def param_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
return self._get_sharding_spec(OperationDataType.PARAM)
@property
def output_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
return self._get_sharding_spec(OperationDataType.OUTPUT)
def _get_sharding_spec(self, operation_data_type: OperationDataType):
specs = {k: v for k, v in self.sharding_specs.items() if k.type == operation_data_type}
return specs
def get_op_data_by_name(self, name: str):
for op_data in self.sharding_specs.keys():
if op_data.name == name:
return op_data
raise KeyError(f"Could not find the OperationData with name {name}")
def get_sharding_spec_by_name(self, name: str):
for op_data, sharding_spec in self.sharding_specs.items():
if op_data.name == name:
return sharding_spec
raise KeyError(f"Could not find the ShardingSpec for OperationData with name {name}")
def clone(self):
def _deepcopy_dict_vals(data: Dict):
return {k: deepcopy(v) for k, v in data.items()}
sharding_specs = _deepcopy_dict_vals(self.sharding_specs) if self.sharding_specs is not None else None
# We deliberately check `is not None` instead of the __bool__ value here.
# If self.communication_actions is an empty dictionary {}, it is not None, but its __bool__ value is False.
# In that case, setting None on the new object would crash the program when we later access communication_actions.items().
communication_actions = _deepcopy_dict_vals(
self.communication_actions) if self.communication_actions is not None else None
# same reason as communication_actions
resharding_costs = _deepcopy_dict_vals(self.resharding_costs) if self.resharding_costs is not None else None
compute_cost = deepcopy(self.compute_cost)
communication_cost = deepcopy(self.communication_cost)
memory_cost = deepcopy(self.memory_cost)
return ShardingStrategy(name=self.name,
sharding_specs=sharding_specs,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
communication_actions=communication_actions,
resharding_costs=resharding_costs)
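# Hedged usage sketch (`meta_tensor` and `spec_for_x` are assumed to be a meta
# tensor and a ShardingSpec built elsewhere):
#
#   x = OperationData(name='x', type=OperationDataType.ARG, data=meta_tensor)
#   strategy = ShardingStrategy(name='S0R = S0R x RR', sharding_specs={x: spec_for_x})
#   strategy.get_sharding_spec_by_name('x')   # -> spec_for_x
#   cloned = strategy.clone()                 # deep-copies specs, costs and actions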
class StrategiesVector(list):
'''
Each node in an fx graph has a corresponding StrategiesVector, which stores all the possible
sharding strategies of the node.
Argument:
node (Node): the node for which the list of sharding strategies is generated.
'''
def __init__(self, node: Node):
super().__init__()
self.node = node
# fetch its input and output nodes
# TODO: placeholder input nodes
self.predecessor_nodes = list(node._input_nodes.keys())
self.successor_nodes = list(node.users.keys())
def check_merge(self):
merge_label = False
if self.node.op == 'call_module':
target = self.node.target
root_module = self.node.graph.owning_module
submod = root_module.get_submodule(target)
submod_type = type(submod)
# merge elementwise module node into source nodes
# we could merge an element-wise op, because its output sharding spec is always the same as its input sharding spec.
if submod_type in ELEMENTWISE_MODULE_OP:
merge_label = True
if self.node.op == 'call_function':
# we could merge an element-wise op, because its output sharding spec is always the same as its input sharding spec.
if self.node.target in ELEMENTWISE_FUNC_OP:
merge_label = True
# we could merge bcast op if the rhs is a scalar, because it will fall back to the element-wise case.
# TODO: remove this after we support the fall back logic.
# if self.node.target in BCAST_FUNC_OP and len(self.predecessor_nodes) == 1:
# merge_label = True
# we could merge reshape op, because their computation costs are negligible.
if self.node.target in RESHAPE_FUNC_OP:
merge_label = True
if self.node.op == 'call_method':
# we could merge reshape op, because their computation costs are negligible.
method = getattr(self.node.args[0]._meta_data.__class__, self.node.target)
if method in RESHAPE_METHOD_OP:
merge_label = True
if method in ELEMENTWISE_METHOD_OP:
merge_label = True
return merge_label
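# Illustrative sketch using torch.fx tracing (assumes nn.ReLU is registered in
# ELEMENTWISE_MODULE_OP, so check_merge() returns True for the ReLU node):
#
#   import torch.nn as nn
#   from torch.fx import symbolic_trace
#   gm = symbolic_trace(nn.Sequential(nn.Linear(4, 4), nn.ReLU()))
#   relu_node = list(gm.graph.nodes)[-2]    # the call_module node of the ReLU
#   vec = StrategiesVector(relu_node)
#   vec.predecessor_nodes, vec.successor_nodes, vec.check_merge()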
from .cost_graph import CostGraph
from .graph_analysis import GraphAnalyser
from .solver import Solver
from .strategies_constructor import StrategiesConstructor
__all__ = ['GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph']
import torch
from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST
class CostGraph:
'''
A graph data structure to simplify the edge cost graph. It has two main functions:
1. To feed the quadratic resharding costs into the solver, we need to linearize them. We build edge_cost in
CostGraph, which stores every combination of strategies for a src-dst node pair in a 1D list.
2. To reduce the search space, we merge computationally-trivial operators, such as
element-wise operators, transpose, and reduction, into their following nodes. The merging information is
given by the StrategiesVector depending on the types of the target node and its following nodes.
Argument:
leaf_strategies(List[StrategiesVector]): It stores the StrategiesVector of every node in the graph.
simplify(bool, optional): The generated cost graph will be simplified if it is true. (default to True)
'''
def __init__(self, leaf_strategies, simplify=True, forward_only=False):
self.leaf_strategies = leaf_strategies
self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
# stores number of strategies in each node
self.node_lens = {strategies_vector.node: len(strategies_vector) for strategies_vector in self.leaf_strategies}
# extra_node_costs will store the extra costs introduced by merging nodes
self.extra_node_costs = {}
self.following_dict = {}
self.simplify = simplify
self.forward_only = forward_only
self._build_cost_graph()
def _remove_invalid_node(self, node, attr_name):
remove_list = []
target_node_list = getattr(node, attr_name, [])
for target_node in target_node_list:
if target_node not in self.nodes:
remove_list.append(target_node)
for element in remove_list:
target_node_list.remove(element)
def _build_cost_graph(self):
'''
This method generates edge_cost for each adjacent node pair. Additionally, 'parents' and 'children' attributes are
set on each node.
'''
self.edge_costs = {}
if self.simplify:
self.merge_pair = []
for strategies_vector in self.leaf_strategies:
# build edge_cost
dst_node = strategies_vector.node
for src_node in strategies_vector.predecessor_nodes:
if src_node not in self.nodes:
continue
node_pair = (src_node, dst_node)
edge_cost = {}
for i in range(len(strategies_vector)):
for j in range(len(src_node.strategies_vector)):
resharding_cost_item = strategies_vector[i].resharding_costs[src_node][j]
if self.forward_only:
edge_cost[(j, i)] = resharding_cost_item.fwd
else:
edge_cost[(j, i)] = resharding_cost_item.total
self.edge_costs[node_pair] = edge_cost
parent_nodes = []
children_nodes = []
def _check_tensor_in_node(data):
"""
This method is used to check whether the data has a tensor inside or not.
"""
has_tensor_flag = False
if isinstance(data, torch.Tensor):
return True
elif isinstance(data, (tuple, list)):
for d in data:
has_tensor_flag = has_tensor_flag or _check_tensor_in_node(d)
return has_tensor_flag
for node in strategies_vector.predecessor_nodes:
if _check_tensor_in_node(node._meta_data):
parent_nodes.append(node)
for node in strategies_vector.successor_nodes:
if _check_tensor_in_node(node._meta_data):
children_nodes.append(node)
setattr(dst_node, 'parents', parent_nodes)
setattr(dst_node, 'children', children_nodes)
if self.simplify and strategies_vector.check_merge():
for followed_node in strategies_vector.predecessor_nodes:
# we only merge node pairs whose src node has a tensor element inside.
# This is necessary because a node without a tensor element inside will not
# be assigned any strategy.
if _check_tensor_in_node(followed_node._meta_data):
self.merge_pair.append((followed_node, dst_node))
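# Worked example of the edge cost layout: for a src node with 2 strategies and a
# dst node with 3, edge_cost holds 6 scalar entries keyed by (src_index, dst_index):
# {(0, 0): c00, (0, 1): c01, (0, 2): c02, (1, 0): c10, (1, 1): c11, (1, 2): c12}.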
def get_edge_cost(self, src_node, dst_node):
return self.edge_costs[(src_node, dst_node)]
def merge_node(self, src_node, dst_node):
'''
To merge dst_node into src_node, we follow these steps:
1. For each strategy in dst_node, we need to pick an appropriate strategy
of src_node to merge. This is important because the logical resharding costs
between the parent nodes of src_node and the merged node depend on which src_node
strategy is dispatched. For example, for the graph 0->1->2, after merging node 1
into node 2, edge_costs[(node 0, node 2)][(0, 0)] = edge_costs[(node 0, node 1)][(0, x)],
where x is the strategy of node 1 picked for strategy 0 of node 2.
2. We need to accumulate the extra costs introduced by merging nodes. The extra costs
contain two parts: one is the resharding cost between the src_node strategy and the dst_node strategy,
the other is the extra cost already recorded for the chosen dst_node strategy.
3. Build connections between new node pairs, and detach dst_node from the graph once all
its consumer nodes have been reconnected to src_node.
Argument:
src_node(Node): The node into which dst_node will be merged.
dst_node(Node): The node to be merged into src_node.
'''
# build merge_map
merge_map = {}
for src_index, _ in enumerate(src_node.strategies_vector):
min_cost = INFINITY_COST
lowest_cost_index = -1
for dst_index, dst_strategy in enumerate(dst_node.strategies_vector):
resharding_cost_item = dst_strategy.resharding_costs[src_node][src_index]
if self.forward_only:
resharding_cost = resharding_cost_item.fwd
else:
resharding_cost = resharding_cost_item.total
if resharding_cost <= min_cost:
min_cost = resharding_cost
lowest_cost_index = dst_index
merge_map[src_index] = lowest_cost_index
# extra_node_cost for src node
self.extra_node_costs[src_node] = [0.0] * self.node_lens[src_node]
for src_index, strategy in enumerate(src_node.strategies_vector):
target_strategy_index = merge_map[src_index]
target_strategy = dst_node.strategies_vector[target_strategy_index]
resharding_cost_item = target_strategy.resharding_costs[src_node][src_index]
if self.forward_only:
resharding_cost_to_add = resharding_cost_item.fwd
else:
resharding_cost_to_add = resharding_cost_item.total
self.extra_node_costs[src_node][src_index] += resharding_cost_to_add
if dst_node in self.extra_node_costs:
self.extra_node_costs[src_node][src_index] += self.extra_node_costs[dst_node][target_strategy_index]
# add new node pair to cost graph
for child_node in dst_node.children:
new_node_pair = (src_node, child_node)
old_node_pair = (dst_node, child_node)
if new_node_pair in self.edge_costs:
continue
edge_cost = {}
for i in range(self.node_lens[src_node]):
for j in range(self.node_lens[child_node]):
dst_strategy_index = merge_map[i]
edge_cost[(i, j)] = self.edge_costs[old_node_pair][(dst_strategy_index, j)]
if new_node_pair not in self.edge_costs:
self.edge_costs[new_node_pair] = edge_cost
else:
# we should accumulate the resharding costs if args of child node contain
# both src node and dst node.
for index_pair, resharding_cost in self.edge_costs[new_node_pair].items():
self.edge_costs[new_node_pair][index_pair] += edge_cost[index_pair]
# connect src node and children of dst node
dst_node.parents.remove(src_node)
src_node.children.remove(dst_node)
self.edge_costs.pop((src_node, dst_node))
for child_node in dst_node.children:
if child_node not in src_node.children:
src_node.children.append(child_node)
if src_node not in child_node.parents:
child_node.parents.append(src_node)
# remove dst node from cost graph when dst node has no producer.
if len(dst_node.parents) == 0:
child_node.parents.remove(dst_node)
node_pair = (dst_node, child_node)
self.edge_costs.pop(node_pair)
if len(dst_node.parents) == 0:
self.following_dict[dst_node] = src_node
dst_node.children = []
def _reindexing_src(self, src):
if src not in self.following_dict:
return src
return self._reindexing_src(self.following_dict[src])
def simplify_graph(self):
if not self.simplify:
return
self.merge_pair.reverse()
for (src_node, dst_node) in self.merge_pair:
self.merge_node(src_node, dst_node)
self.merge_pair.reverse()
reindexing_following_dict = {}
for dst, src in self.following_dict.items():
reindexing_following_dict[dst] = self._reindexing_src(src)
self.following_dict = reindexing_following_dict
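# Worked example of the reindexing above: if merging produced
# following_dict = {node1: node2, node2: node3}, _reindexing_src collapses the
# chain so that both node1 and node2 map directly to node3.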
from dataclasses import dataclass
from typing import List
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from colossalai.fx.passes.utils import get_node_module
__all__ = ['LiveVariable', 'LiveVariableVector', 'LiveStage', 'GraphAnalyser']
@dataclass
class LiveVariable:
"""
LiveVariable is a data structure to store the meta information of a variable for liveness analysis.
"""
name: str
node: Node
is_inplace: bool
class LiveVariableVector(list):
"""
LiveVariableVector is a data structure to store the list of LiveVariable objects.
"""
def exists(self, name) -> bool:
"""
Check whether a variable with the given name already exists in the current list.
"""
for var in self:
if name == var.name:
return True
return False
def get(self, name) -> LiveVariable:
for var in self:
if name == var.name:
return var
raise KeyError(f"Variable {name} is not found")
def copy(self) -> "LiveVariableVector":
"""
Create a copy of this vector
"""
vector = LiveVariableVector()
for var in self:
vector.append(var)
return vector
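# Minimal sketch (node=None is a stand-in for illustration; real usage passes a
# torch.fx Node):
#
#   vec = LiveVariableVector()
#   vec.append(LiveVariable(name='x1', node=None, is_inplace=False))
#   vec.exists('x1')   # -> True
#   vec.get('x1')      # -> the LiveVariable above; raises KeyError if missing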
@dataclass
class LiveStage:
"""
LiveStage is a data structure to record the live variables at the current node.
"""
name: str
node: Node
all_live_vars: LiveVariableVector
unique_live_vars: LiveVariableVector
class GraphAnalyser:
def __init__(self, gm: GraphModule):
self._gm = gm
self._graph = gm.graph
@property
def gm(self) -> GraphModule:
"""
Return the GraphModule object associated with this analyser.
"""
return self._gm
@property
def graph(self) -> Graph:
"""
Return the Graph object associated with this analyser.
"""
return self._graph
def liveness_analysis(self) -> List[LiveStage]:
"""
Analyse the graph to obtain the variable liveness information. This function returns
a list of LiveStage objects, one per compute stage, in topological order.
"""
compute_nodes = self.graph.nodes
liveness_list = []
# checked: records all variables created since the first stage
# all: records the variables that are still live at the current stage.
# this can differ from the `checked` list as some variables may be destroyed prior to this stage.
# unique: records the unique live variables at the current stage.
# this differs from the `all` list as some variables are duplicated.
checked_variables = LiveVariableVector()
all_live_variables = LiveVariableVector()
unique_live_vars = LiveVariableVector()
for idx, node in enumerate(compute_nodes):
#############################
# find new living variables #
#############################
# detect whether the current op is an in-place op
# if it is an in-place op, we deem it a duplicate var
is_inplace = False
if node.op == 'call_function':
# check if this is an inplace op such as torch.nn.functional.relu(x, inplace=True)
if node.kwargs.get('inplace', False):
is_inplace = True
elif node.op == 'call_module':
# check if this is an inplace op such as torch.nn.ReLU(inplace=True)
module = get_node_module(node)
if getattr(module, 'inplace', False):
is_inplace = True
# add the output var
meta = getattr(node, '_meta_data', None)
live_var = LiveVariable(name=node.name, node=node, is_inplace=is_inplace)
if not is_inplace:
unique_live_vars.append(live_var)
checked_variables.append(live_var)
all_live_variables.append(live_var)
# check if any input is not checked yet
for arg in node.args:
if not isinstance(arg, Node):
continue
arg_name = arg.name
if not checked_variables.exists(arg_name):
live_var_from_arg = LiveVariable(name=arg_name, node=node, is_inplace=False)
all_live_variables.append(live_var_from_arg)
checked_variables.append(live_var_from_arg)
unique_live_vars.append(live_var_from_arg)
# TODO: add the logic to remove live variables
# this should be completed if we are able to trace the backward compute graph
# add this stage to liveness dict
stage = LiveStage(name=node.name,
node=node,
all_live_vars=all_live_variables.copy(),
unique_live_vars=unique_live_vars.copy())
# if a LiveStage is covered by another LiveStage, we just keep the larger one.
replace = False
for index, prev_stage in enumerate(liveness_list):
all_covered = True
for ele in prev_stage.unique_live_vars:
if ele not in stage.unique_live_vars:
all_covered = False
break
if all_covered:
replace = True
break
if replace:
liveness_list[index] = stage
else:
liveness_list.append(stage)
return liveness_list
def get_alias_set(self):
pass
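# Hedged usage sketch (`model` is an assumed nn.Module that the tracer can handle
# and whose nodes carry `_meta_data`):
#
#   from torch.fx import symbolic_trace
#   gm = symbolic_trace(model)
#   analyser = GraphAnalyser(gm)
#   for stage in analyser.liveness_analysis():
#       print(stage.name, len(stage.unique_live_vars))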
"""This code is adapted from Alpa
https://github.com/alpa-projects/alpa/
with some changes. """
import multiprocessing
import time
import warnings
from typing import Dict
import numpy as np
from torch.fx.graph import Graph
from torch.fx.node import Node
from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST
from .cost_graph import CostGraph
from .graph_analysis import GraphAnalyser
from .strategies_constructor import StrategiesConstructor
try:
import pulp
from pulp import LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum
except ImportError:
warnings.warn('please install pulp: pip install pulp')
__all__ = ['Solver']
class Solver:
def __init__(self,
graph: Graph,
strategies_constructor: StrategiesConstructor,
cost_graph: CostGraph,
graph_analyser: GraphAnalyser = None,
memory_budget: float = -1.0,
solution_numbers: int = 1,
forward_only: bool = False,
memory_increasing_coefficient: float = 1.3,
verbose=False):
'''
The Solver class integrates the information provided by the other components and uses an ILP solver to find an optimal strategy combination for the target computation graph.
Argument:
graph: The computing graph to be optimized.
strategies_constructor: It will provide all the possible strategies for each node in the computing graph.
cost_graph: A graph data structure to simplify the edge cost graph.
graph_analyser: graph_analyser analyses the graph to obtain variable liveness information, which is used to generate memory constraints.
memory_budget: Memory constraint for the solution.
solution_numbers: If solution_numbers is larger than one, the solver will return a series of solutions based on different memory budgets.
memory_increasing_coefficient: If solution_numbers is larger than one, we use this coefficient to generate new memory budgets.
'''
self.graph = graph
self.strategies_constructor = strategies_constructor
self.cost_graph = cost_graph
self.graph_analyser = graph_analyser
self.leaf_strategies = self.strategies_constructor.leaf_strategies
self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
self.strategy_map = self.strategies_constructor.strategy_map
self.memory_budget = memory_budget
self.solution_numbers = solution_numbers
self.forward_only = forward_only
if self.solution_numbers > 1:
self.memory_increasing_coefficient = memory_increasing_coefficient
else:
self.memory_increasing_coefficient = 1
# Temporarily we use all nodes as the liveness list. We count the backward memory cost together with the
# forward memory cost in the node memory cost, and no activation checkpointing is used in this phase.
# self.liveness_list = self.graph_analyser.liveness_analysis()
self.liveness_list = self.nodes
self.node_index_dict = self._generate_node_index_dict()
# The last solution vector of auto sharding.
self.last_s_val = None
# The last objective value of the best ILP solution.
self.last_objective = None
self.verbose = verbose
def _recover_merged_node_strategy(self):
'''
During cost graph construction, some nodes, such as unary element-wise nodes or ReshapeOps, were merged into the previous node.
Therefore, the indices of their strategies are copied from the previous node. This method recovers the strategy index of those
merged nodes.
'''
for node_index, node in enumerate(self.nodes):
if node.strategies_vector.check_merge():
# the merged node has only one input, and its strategies follow the input sharding strategy
input_strategies_vector = node.args[0].strategies_vector
input_best_strategy_index = self.last_s_val[node_index - 1]
input_sharding_spec = input_strategies_vector[input_best_strategy_index].output_sharding_spec
for strategy_index, strategy in enumerate(node.strategies_vector):
if strategy.input_shardings[0].sharding_sequence == input_sharding_spec.sharding_sequence:
self.last_s_val[node_index] = strategy_index
break
def _generate_node_index_dict(self) -> Dict[Node, int]:
node_index_dict = {}
for index, strategies_vector in enumerate(self.leaf_strategies):
node_index_dict[strategies_vector.node] = index
return node_index_dict
def _prepare_data_for_solver(self):
'''
Extract information from components for solver.
'''
node_nums = len(self.leaf_strategies)
memory_budget = self.memory_budget
# prepare strategies_len
strategies_len = []
for node in self.nodes:
strategies_len.append(self.cost_graph.node_lens[node])
strategies_len = np.array(strategies_len)
# prepare following_nodes
following_nodes = self.cost_graph.following_dict
index_following_nodes = {}
for src, target in following_nodes.items():
src_index = self.node_index_dict[src]
target_index = self.node_index_dict[target]
index_following_nodes[src_index] = target_index
following_nodes = index_following_nodes
for index in range(node_nums):
if index not in following_nodes:
following_nodes[index] = -1
# prepare edge_pairs and resharding costs
edge_pairs = []
resharding_costs = []
for pairs, edge_cost in self.cost_graph.edge_costs.items():
src_node = pairs[0]
dst_node = pairs[1]
src_node_index = self.node_index_dict[src_node]
dst_node_index = self.node_index_dict[dst_node]
edge_pairs.append(src_node_index)
edge_pairs.append(dst_node_index)
for i in range(strategies_len[src_node_index]):
for j in range(strategies_len[dst_node_index]):
resharding_costs.append(edge_cost[(i, j)])
edge_pairs = np.array(edge_pairs)
resharding_costs = np.array(resharding_costs)
# prepare liveness_set
liveness_set = self.liveness_list
# omit alias_set now
alias_set = self.strategies_constructor.alias_set
alias_convert_costs = None
# prepare compute_costs, communication_costs and memory_costs
compute_costs = []
communication_costs = []
memory_costs = []
extra_node_costs = self.cost_graph.extra_node_costs
for strategies_vector in self.leaf_strategies:
node = strategies_vector.node
for index, strategy in enumerate(strategies_vector):
compute_cost_item = strategy.compute_cost
communication_cost_item = strategy.communication_cost
memory_cost_item = strategy.memory_cost
if self.forward_only:
origin_communication_cost = communication_cost_item.fwd
compute_cost = compute_cost_item.fwd
# extract MemoryCost item from the memory TrainCycleItem
memory_cost = memory_cost_item.fwd
else:
origin_communication_cost = communication_cost_item.total
compute_cost = compute_cost_item.total
# extract MemoryCost item from the memory TrainCycleItem
memory_cost = memory_cost_item.total
# extract the memory cost in float from MemoryCost item and sum them up
memory_cost = memory_cost.parameter + memory_cost.activation + memory_cost.buffer
compute_costs.append(compute_cost)
# a node present in extra_node_costs has some extra communication
# cost from node merging, so we need to add that extra communication
# cost to its communication cost
if node in extra_node_costs:
extra_node_cost = extra_node_costs[node][index]
communication_cost = origin_communication_cost + extra_node_cost
communication_costs.append(communication_cost)
else:
communication_costs.append(origin_communication_cost)
memory_costs.append(memory_cost)
compute_costs = np.array(compute_costs)
communication_costs = np.array(communication_costs)
memory_costs = np.array(memory_costs)
# omit initial value for nodes
s_init_np = None
return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np, self.verbose
def _call_solver_serialized_args(self,
node_nums,
memory_budget,
strategies_len,
following_nodes,
edge_pairs,
alias_set,
liveness_set,
compute_costs,
communication_costs,
memory_costs,
resharding_costs,
alias_convert_costs,
s_init_np=None,
verbose=True):
"""
Call the solver with serialized arguments.
"""
tic = time.time()
for x in [strategies_len, edge_pairs, compute_costs, communication_costs, memory_costs, resharding_costs]:
assert isinstance(x, np.ndarray)
assert len(strategies_len) == node_nums, "strategies_len"
def get_non_zero_index(binary_vector):
"""
Get the index of non-zero item in a vector.
"""
ct = 0
ret = None
for i, elem in enumerate(binary_vector):
if pulp.value(elem):
ret = i
ct += 1
assert ct == 1
return ret
# 0. Unpack flattened numpy arrays
s_follow = following_nodes
s_alias = alias_set
E = edge_pairs.reshape((-1, 2)) # noqa
r = []
pt = 0
edge_set = set()
for (i, j) in E:
prod_length = strategies_len[i] * strategies_len[j]
if (i, j) in edge_set:
raise ValueError(f"Duplicated edges: {(i, j)}")
edge_set.add((i, j))
r.append(resharding_costs[pt:pt + prod_length])
pt += prod_length
assert pt == len(resharding_costs)
######################
# omit alias set now #
######################
# A = alias_set.reshape((-1, 2)) # noqa
# for (i, j) in A:
# prod_length = strategies_len[i] * strategies_len[j]
# v.append(alias_convert_costs[pt:pt + prod_length])
# pt += prod_length
# assert pt == len(alias_convert_costs)
# L = [] # noqa
# pt = node_nums
# for i in range(node_nums):
# length = liveness_set[i]
# L.append(liveness_set[pt:pt + length])
# pt += length
# assert pt == len(liveness_set)
v = []
pt = 0
c = []
d = []
m = []
pt = 0
for i in range(node_nums):
length = strategies_len[i]
c.append(compute_costs[pt:pt + length])
d.append(communication_costs[pt:pt + length])
m.append(memory_costs[pt:pt + length])
pt += length
assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}"
assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}"
assert pt == len(memory_costs), f"{pt} == {len(memory_costs)}"
# 1. Create variables
#############################
# create variables for node #
#############################
s = []
num_nodes = 0
reverse_follow_backpatch = []
for i in range(node_nums):
if s_follow[i] < 0:
if strategies_len[i] == 1:
s.append([1])
else:
if i not in s_alias:
num_nodes += 1
s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary"))
else:
s.append(s[s_alias[i]])
else:
if s_follow[i] < len(s):
s.append(s[s_follow[i]])
else:
s.append(None)
reverse_follow_backpatch.append(i)
for i in reverse_follow_backpatch:
s[i] = s[s_follow[i]]
#############################
# create variables for edge #
#############################
e = []
num_edges = 0
map_edge_to_idx = {}
for (idx, (i, j)) in enumerate(E):
if len(s[i]) == 1:
e.append(s[j])
elif len(s[j]) == 1:
e.append(s[i])
else:
if i in s_alias and j in s_alias and (s_alias[i], s_alias[j]) in map_edge_to_idx:
e.append(e[map_edge_to_idx[(s_alias[i], s_alias[j])]])
else:
num_edges += 1
e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
assert len(e[idx]) == len(r[idx])
map_edge_to_idx[(i, j)] = idx
for element in s:
assert len(element) > 0
# 2. Set initial value
######################################
# set an initial value for warm start #
######################################
if s_init_np is not None:
s_init = s_init_np.reshape((-1, 3))
for (idx, value, fix) in s_init:
for i in range(len(s[idx])):
s[idx][i].setInitialValue(i == value)
if fix:
s[idx][i].fixValue()
# 3. Objective
prob = LpProblem("myProblem", LpMinimize)
###################################################################
# computing the node cost (compute cost and communication cost) #
###################################################################
obj = 0
for i in range(node_nums):
assert len(s[i]) == len(c[i])
assert len(s[i]) == len(d[i])
obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i])
#############################################
# computing the edge cost (resharding cost) #
#############################################
for i in range(len(E)):
assert len(e[i]) == len(r[i])
obj += lpDot(e[i], r[i])
prob += obj
# 4. Constraints
# (a). specified by `cat="Binary"`
# (b)
#################################################
# make sure each node only chooses one strategy #
#################################################
for i in range(node_nums):
if s_follow[i] < 0:
prob += lpSum(s[i]) == 1
# (c)
#################################################
# compute memory consumption with liveness set #
#################################################
if memory_budget > 0:
mem = 0
for node in liveness_set:
if node not in self.node_index_dict:
continue
node_index = self.node_index_dict[node]
mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
prob += mem <= memory_budget
# (d). specified by `cat="Binary"`
for (idx, (i, j)) in enumerate(E):
if strategies_len[i] == 1 or strategies_len[j] == 1:
continue
# (e)
prob += lpSum(e[idx]) == 1
# (f)
for row in range(len(s[i])):
C = len(s[j]) # noqa
prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row]
# (g)
for col in range(len(s[j])):
R = len(s[i]) # noqa
C = len(s[j]) # noqa
prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col]
# (h)
######################
# omit alias set now #
######################
# alias_set = set()
# for (idx, (i, j)) in enumerate(A):
# R = len(s[i]) # noqa
# C = len(s[j]) # noqa
# if (i, j) in alias_set:
# raise ValueError(f"Duplicated edges: {(i, j)}")
# alias_set.add((i, j))
# alias_set.add((j, i))
# for row in range(len(s[i])):
# for col in range(len(s[j])):
# if v[idx][row * C + col] > 0.5:
# prob += s[i][row] + s[j][col] <= 1
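# In summary, the ILP assembled above is (s[i] and e[idx] are binary vectors):
#   minimize    sum_i s[i] . (c[i] + d[i]) + sum_(i,j) e[idx] . r[idx]
#   subject to  lpSum(s[i]) == 1                  for every free node i     (b)
#               lpSum(e[idx]) == 1                for every edge (i, j)     (e)
#               sum_col e[idx][row, col] <= s[i][row]                       (f)
#               sum_row e[idx][row, col] <= s[j][col]                       (g)
#               sum over live nodes of s . m <= memory_budget               (c)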
msg = verbose
time_limit = 600
assert "COIN_CMD" in pulp.listSolvers(
onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'")
solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count())
# solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit)
prob.solve(solver)
status = prob.status
objective = pulp.value(prob.objective)
objective = float(objective) if objective is not None else -1.0
if verbose:
print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t"
f"Time: {time.time() - tic}")
print(f"#nodes: {num_nodes}, #edges: {num_edges}")
if prob.status in [pulp.LpStatusInfeasible]:
raise RuntimeError("Cannot run the function under the given memory budget. "
"Please increase the memory budget.")
# Get and check results
s_val = np.full((node_nums,), -1, dtype=np.int32)
for i in range(node_nums):
s_val[i] = get_non_zero_index(s[i])
e_val = np.full((len(E),), -1, dtype=np.int32)
for (idx, (i, j)) in enumerate(E):
e_val[idx] = get_non_zero_index(e[idx])
i_spec_index = e_val[idx] // len(s[j])
j_spec_index = e_val[idx] % len(s[j])
assert i_spec_index == s_val[i], f"e_val[{i}][{j}]"
assert j_spec_index == s_val[j], f"e_val[{i}][{j}]"
if verbose and r[idx][e_val[idx]] > 0:
print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}")
self.last_s_val = list(s_val)
# self._recover_merged_node_strategy()
self.last_objective = objective
if objective > INFINITY_COST:
warnings.warn("Detect unexpected behaviors in the auto-sharding pass.")
return self.last_s_val, e_val, self.last_objective, status
def call_solver_serialized_args(self):
"""
Call the solver with serialized arguments and handle Python errors. Additionally,
we can produce a series of solutions with different memory budgets.
"""
if self.solution_numbers == 1:
args = self._prepare_data_for_solver()
ret = self._call_solver_serialized_args(*args)
return ret
origin_memory_budget = self.memory_budget
memory_budget_list = [
origin_memory_budget * self.memory_increasing_coefficient**i for i in range(self.solution_numbers)
]
ret_list = []
for memory_budget in memory_budget_list:
self.memory_budget = memory_budget
args = self._prepare_data_for_solver()
ret = self._call_solver_serialized_args(*args)
ret_list.append(ret)
return ret_list
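# Typical end-to-end usage sketch (the variable names are assumptions; the class
# names match the imports above):
#
#   constructor = StrategiesConstructor(graph, device_mesh, solver_options)
#   constructor.build_strategies_and_cost()
#   cost_graph = CostGraph(constructor.leaf_strategies)
#   cost_graph.simplify_graph()
#   solver = Solver(graph, constructor, cost_graph, GraphAnalyser(gm))
#   s_val, e_val, objective, status = solver.call_solver_serialized_args()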
import builtins
import math
import operator
from copy import deepcopy
from typing import Dict, List
import torch
from torch.fx import Graph, Node
from colossalai.auto_parallel.tensor_shard.node_handler import (
GetattrHandler,
OutputHandler,
PlaceholderHandler,
operator_registry,
)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks
from colossalai.device.device_mesh import DeviceMesh
from ..options import DataloaderOption, SolverOptions
__all__ = ['StrategiesConstructor']
class StrategiesConstructor:
"""
StrategiesConstructor is used to construct the parallelization plan for the model execution.
Args:
graph (Graph): a Graph object used for analysis and strategy generation.
device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
solver_options (SolverOptions): a SolverOptions object which specifies the preferences for plan searching.
"""
def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions):
self.graph = graph
assert graph.owning_module is not None, 'The given graph is not associated with an owning_module'
self.root_module = self.graph.owning_module
self.nodes = list(graph.nodes)
self.device_mesh = device_mesh
self.leaf_strategies = []
self.strategy_map = {}
self.solver_options = solver_options
self.no_strategy_nodes = []
self.alias_set = None
def remove_duplicated_strategy(self, strategies_vector):
'''
In the build_strategies_and_cost method, we may produce some duplicated strategies.
In this method, we remove the duplicated strategies based on the strategy name.
Note that this operation is in-place.
'''
name_checklist = []
remove_list = []
for strategy in strategies_vector:
if strategy.name not in name_checklist:
name_checklist.append(strategy.name)
else:
remove_list.append(strategy)
for strategy in remove_list:
strategies_vector.remove(strategy)
def generate_alias_set(self):
node_list = [strategy_vector.node for strategy_vector in self.leaf_strategies]
common_blocks = find_repeat_blocks(node_list, self.root_module, common_length_threshold=10)
repeat_block_nums = len(common_blocks)
alias_set = {}
if repeat_block_nums == 0:
return alias_set
for index, common_node in enumerate(common_blocks[0]):
for i in range(1, repeat_block_nums):
alias_set[node_list.index(common_blocks[i][index])] = node_list.index(common_node)
return alias_set
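# Worked example: with two repeated blocks of nodes, generate_alias_set maps the
# index of every node in the second block to the index of the corresponding node
# in the first, e.g. {12: 2, 13: 3, ...}, so the solver can share their variables.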
def build_strategies_and_cost(self):
"""
This method builds the strategy vector for each node in the computation graph.
"""
def _check_no_strategy_for_node(node):
if node.op in ('placeholder', 'get_attr', 'output'):
return False
def _check_no_strategy_for_data(data):
label = True
if isinstance(data, torch.Tensor):
return False
elif isinstance(data, (tuple, list)):
for d in data:
label = label and _check_no_strategy_for_data(d)
return label
return _check_no_strategy_for_data(node._meta_data)
for node in self.nodes:
strategies_vector = StrategiesVector(node)
if _check_no_strategy_for_node(node):
self.no_strategy_nodes.append(node)
# placeholder node
elif node.op == 'placeholder':
if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED:
placeholder_option = 'distributed'
else:
assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'dataloader_option {self.solver_options.dataloader_option} is not supported'
placeholder_option = 'replicated'
placeholder_handler = PlaceholderHandler(node,
self.device_mesh,
strategies_vector,
placeholder_option=placeholder_option)
placeholder_handler.register_strategy()
# get_attr node
elif node.op == 'get_attr':
getattr_handler = GetattrHandler(node,
self.device_mesh,
strategies_vector,
shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference)
getattr_handler.register_strategy()
# call_module node
elif node.op == 'call_module':
target = node.target
submod = self.root_module.get_submodule(target)
submod_type = type(submod)
handler = operator_registry.get(submod_type)(node,
self.device_mesh,
strategies_vector,
shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference)
handler.register_strategy()
# attach strategies_info to node
if hasattr(handler, 'strategies_info'):
setattr(node, 'strategies_info', handler.strategies_info)
# call_function node
elif node.op == 'call_function':
target = node.target
handler = operator_registry.get(target)(node,
self.device_mesh,
strategies_vector,
shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference)
handler.register_strategy()
# attach strategies_info to node
if hasattr(handler, 'strategies_info'):
setattr(node, 'strategies_info', handler.strategies_info)
# call_method node
elif node.op == 'call_method':
method = getattr(node.args[0]._meta_data.__class__, node.target)
handler = operator_registry.get(method)(node,
self.device_mesh,
strategies_vector,
shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference)
handler.register_strategy()
# attach strategies_info to node
if hasattr(handler, 'strategies_info'):
setattr(node, 'strategies_info', handler.strategies_info)
# output node
elif node.op == 'output':
if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED:
output_option = 'distributed'
else:
assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'dataloader_option {self.solver_options.dataloader_option} is not supported'
output_option = 'replicated'
output_handler = OutputHandler(node, self.device_mesh, strategies_vector, output_option=output_option)
output_handler.register_strategy()
self.remove_duplicated_strategy(strategies_vector)
setattr(node, 'strategies_vector', strategies_vector)
self.leaf_strategies.append(strategies_vector)
self.strategy_map[node] = strategies_vector
# remove no strategy nodes
remove_list = []
for strategies_vector in self.leaf_strategies:
if len(strategies_vector) == 0:
remove_list.append(strategies_vector.node)
for node in remove_list:
if node.strategies_vector in self.leaf_strategies:
self.leaf_strategies.remove(node.strategies_vector)
if node in self.strategy_map:
self.strategy_map.pop(node)
alias_set = self.generate_alias_set()
self.alias_set = alias_set
from .broadcast import (
BroadcastType,
comm_actions_for_oprands,
get_broadcast_shape,
is_broadcastable,
recover_sharding_spec_for_broadcast_shape,
)
from .factory import generate_resharding_costs, generate_sharding_spec
from .misc import check_sharding_spec_validity, ignore_sharding_exception, pytree_map
from .reshape import check_keep_sharding_status, detect_reshape_mapping, infer_output_dim_partition_dict
from .sharding import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
generate_sharding_size,
transpose_partition_dim,
update_partition_dim,
)
__all__ = [
'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity',
'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands', 'pytree_map',
'detect_reshape_mapping', 'check_keep_sharding_status', 'infer_output_dim_partition_dict'
]
from enum import Enum, auto
from typing import List
import torch
from torch.fx.node import Node
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
OperationData,
OperationDataType,
)
from colossalai.tensor.comm_spec import CollectiveCommPattern, CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = [
'BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape',
'comm_actions_for_oprands'
]
class BroadcastType(Enum):
EQUAL = auto()
PADDDING = auto()
MULTIPLE = auto()
def is_broadcastable(shape1: torch.Size, shape2: torch.Size) -> bool:
"""
Check if two shapes are broadcastable to each other.
"""
for s1, s2 in zip(shape1[::-1], shape2[::-1]):
if s1 == 1 or s2 == 1 or s1 == s2:
pass
else:
return False
return True
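# Worked examples:
#   is_broadcastable(torch.Size([4, 1, 8]), torch.Size([8]))    # -> True
#   is_broadcastable(torch.Size([4, 3]), torch.Size([2, 3]))    # -> False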
def get_broadcast_shape(shape1: torch.Size, shape2: torch.Size) -> List[int]:
"""
Compute the broadcast shape given two shapes.
"""
assert is_broadcastable(shape1, shape2), f'{shape1} and {shape2} are not broadcastable'
shape1_reverse = shape1[::-1]
shape2_reverse = shape2[::-1]
min_common_dim = min(len(shape1), len(shape2))
dims = []
for s1, s2 in zip(shape1_reverse, shape2_reverse):
dims.append(max(s1, s2))
# append the remaining dims
dims.extend(shape1_reverse[min_common_dim:])
dims.extend(shape2_reverse[min_common_dim:])
return dims[::-1]
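# Worked example: get_broadcast_shape(torch.Size([4, 1, 8]), torch.Size([3, 1]))
# returns [4, 3, 8], matching torch's right-aligned broadcasting rule.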
def get_broadcast_dim_info(logical_shape, physical_shape):
# get the number of dimensions
logical_num_dims = len(logical_shape)
physical_num_dims = len(physical_shape)
assert logical_num_dims >= physical_num_dims, \
'The logical shape has fewer dimensions than the physical shape, so this tensor cannot be the result of a broadcast!'
# track the dim and its broadcasting type
logical_dim_broadcast_info = {}
for i in range(logical_num_dims):
# get the trailing dim size
logical_dim_idx = logical_num_dims - i - 1
physical_dim_idx = physical_num_dims - i - 1
logical_dim_size = logical_shape[logical_dim_idx]
if physical_dim_idx >= 0:
physical_dim_size = physical_shape[physical_dim_idx]
if physical_dim_size == logical_dim_size:
logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.EQUAL
elif physical_dim_size == 1 and physical_dim_size != logical_dim_size:
logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.MULTIPLE
else:
logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.PADDDING
return logical_dim_broadcast_info
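# Worked example with logical shape (4, 3, 8) and physical shape (3, 1):
#   dim 2: physical size 1 vs logical size 8 -> BroadcastType.MULTIPLE
#   dim 1: physical size 3 == logical size 3 -> BroadcastType.EQUAL
#   dim 0: no physical counterpart           -> BroadcastType.PADDDING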
def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size,
physical_shape: torch.Size) -> ShardingSpec:
"""
This function computes the sharding spec for the physical shape of a broadcast tensor.
Args:
logical_sharding_spec (ShardingSpec): the sharding spec for the broadcast tensor
logical_shape (torch.Size): logical shape is the broadcast shape of a tensor
physical_shape (torch.Size): the shape of the tensor before broadcasting
"""
# record the sharding dimensions removed when converting the logical shape to the physical one
removed_dims = []
# if the two shapes are the same, no broadcast occurs
# and we directly return the current sharding spec
if list(logical_shape) == list(physical_shape):
return logical_sharding_spec, removed_dims
# get the number of dimensions
logical_num_dims = len(logical_shape)
physical_num_dims = len(physical_shape)
# get the broadcast info
logical_dim_broadcast_info = get_broadcast_dim_info(logical_shape, physical_shape)
# generate the sharding spec for the physical shape
physical_dim_partition = {}
logical_dim_partition = logical_sharding_spec.dim_partition_dict
for shape_dim, mesh_dim in logical_dim_partition.items():
logical_broadcast_type = logical_dim_broadcast_info[shape_dim]
if logical_broadcast_type == BroadcastType.PADDDING or logical_broadcast_type == BroadcastType.MULTIPLE:
removed_dims.extend(mesh_dim)
else:
# get the corresponding physical dim
physical_dim = physical_num_dims - (logical_num_dims - shape_dim)
physical_dim_partition[physical_dim] = mesh_dim
physical_sharding_spec = ShardingSpec(device_mesh=logical_sharding_spec.device_mesh,
entire_shape=physical_shape,
dim_partition_dict=physical_dim_partition)
return physical_sharding_spec, removed_dims
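# Continuing the example above: if the logical (4, 3, 8) tensor is sharded with
# dim_partition_dict {0: [0], 2: [1]}, dim 0 is PADDDING and dim 2 is MULTIPLE,
# so mesh dims 0 and 1 both land in removed_dims and the physical (3, 1) tensor
# ends up fully replicated.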
def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: OperationData,
sharding_spec: ShardingSpec) -> CommAction:
"""
This method generates communication actions for operands that lose information
during the conversion from the logical shape to the physical shape.
"""
if len(removed_dims) == 1:
# if the list length is 1, extract the element from the list to avoid using a flattened device mesh
removed_dims = removed_dims[0]
comm_spec = CommSpec(comm_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
sharding_spec=sharding_spec,
logical_process_axis=removed_dims)
if op_data.type == OperationDataType.PARAM:
comm_type = CommType.HOOK
else:
comm_type = CommType.BEFORE
arg_index = -1
for index, arg in enumerate(node.args):
if op_data.name == str(arg):
arg_index = index
assert arg_index >= 0, 'op_data should be an argument of the node.'
comm_action = CommAction(
comm_spec=comm_spec,
comm_type=comm_type,
arg_index=arg_index,
)
return comm_action
import copy
import operator
import warnings
from functools import reduce
from typing import Dict, List, Optional, Union
import torch
from torch.fx.node import Node
from torch.utils._pytree import tree_map
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from ..constants import INFINITY_COST
__all__ = ['generate_sharding_spec', 'generate_resharding_costs']
def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,
dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
"""
Generate the sharding spec of the tensor based on the given dim_partition_dict.
Args:
input_ (Union[Node, torch.Tensor]): the input can be a Node object or a PyTorch tensor. If a node is used, it will look for its meta data associated with this node.
device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
dim_partition_dict (Dict[int, List[int]]): a dictionary to specify the sharding specs, the key is the tensor dimension and the value is the mesh dimension for sharding.
"""
if isinstance(input_, Node):
assert hasattr(input_, '_meta_data'), 'The given node has no attribute _meta_data'
meta_tensor = input_._meta_data
assert meta_tensor is not None, "The given node's _meta_data attribute is None"
shape = meta_tensor.shape
elif isinstance(input_, torch.Tensor):
shape = input_.shape
else:
raise TypeError(
f'We cannot generate sharding spec for {type(input_)} type, only torch.fx.Node or torch.Tensor is expected.'
)
for dim_index, sharding_index_list in dim_partition_dict.items():
sharding_list = [device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list]
sharding_size = reduce(operator.mul, sharding_list, 1)
assert shape[
dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.'
sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict)
return sharding_spec
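# Hedged usage sketch (the DeviceMesh constructor arguments are an assumption):
#
#   mesh = DeviceMesh(torch.arange(4).reshape(2, 2), (2, 2))
#   spec = generate_sharding_spec(torch.empty(8, 16), mesh, {0: [0], 1: [1]})
#   # shards dim 0 over mesh axis 0 and dim 1 over mesh axis 1, i.e. 'S0S1'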
def generate_resharding_costs(nodes: List[Node],
sharding_specs: List[ShardingSpec],
count_backward: Optional[bool] = True,
dtype: Optional[torch.dtype] = None,
index=None):
'''
Compute the resharding costs with this specific strategy.
Argument:
nodes (List[Node]): a list of nodes
sharding_specs (List[ShardingSpec]): a list of ShardingSpecs for the nodes.
count_backward (Optional[bool]): whether to include the cost of resharding in the backward pass, default is True. False can be used for inference.
dtype (Optional[torch.dtype]): the data type for cost calculation, default is None.
'''
# The resharding cost of the weight is counted to handle weight-sharing cases.
resharding_costs = {}
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
# shape consistency manager is a singleton class
shape_consistency_manager = ShapeConsistencyManager()
for input_node, input_spec in zip(nodes, sharding_specs):
resharding_costs[input_node] = []
for strategy in input_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec
if not isinstance(input_sharding_spec, ShardingSpec):
assert isinstance(input_sharding_spec, list), 'only ShardingSpec or List[ShardingSpec] is expected.'
input_sharding_spec = input_sharding_spec[index]
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
try:
# compute the resharding cost
_, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
input_sharding_spec, input_spec)
# we need to multiply by the element size of the dtype to get the correct communication cost
resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes
except AssertionError as e:
warnings.warn(f'{e}')
resharding_cost = INFINITY_COST
resharding_costs[input_node].append(resharding_cost)
return resharding_costs
def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_length_threshold: int = 20):
'''
Find the largest repeated blocks in the graph whose length is larger than the threshold.
Args:
node_list (List[torch.fx.Node]): the list of nodes to be analyzed.
root_module: the root module that owns the nodes, used to resolve call_module targets.
common_length_threshold (int): the threshold of the repeated block length.
'''
def _process_args(args):
new_args = []
for arg in args:
if hasattr(arg, '_meta_data'):
meta_data = arg._meta_data
else:
meta_data = arg
def _process_arg(data):
if isinstance(data, torch.Tensor):
data = data.size()
elif isinstance(data, slice):
data = (data.start, data.step, data.stop)
return data
new_meta_data = tree_map(_process_arg, meta_data)
new_args.append(new_meta_data)
return new_args
def _all_equal(check_list, check_fn):
base_value = check_list[-1]
for e in check_list:
if not check_fn(e, base_value):
return False
return True
def _check_node_list_equal(l1, l2):
if len(l1) != len(l2):
return False
for node1, node2 in zip(l1, l2):
if hash(node1.hash_key) != hash(node2.hash_key):
return False
return True
def _check_node_equal(node1, node2):
if hash(node1.hash_key) == hash(node2.hash_key):
return True
return False
for index, node in enumerate(node_list):
if node.op == 'call_module':
target = node.target
submod = root_module.get_submodule(target)
submod_type = type(submod)
target = submod_type
else:
target = node.target
new_args = _process_args(node.args)
if node.op != 'get_attr':
hash_key = (node.op, target, *new_args)
else:
hash_key = (node.op,)
setattr(node, 'hash_key', hash_key)
hash_value_to_node_dict = {}
for index, node in enumerate(node_list):
hash_value = hash(node.hash_key)
if hash_value not in hash_value_to_node_dict:
hash_value_to_node_dict[hash_value] = []
hash_value_to_node_dict[hash_value].append(index)
node_list_start = 0
max_common_length = common_length_threshold
common_blocks_index = []
for index, node in enumerate(node_list):
# the comparison will be triggered if a common node appears
if len(hash_value_to_node_dict[hash(node.hash_key)]) >= 2:
start_index_list = hash_value_to_node_dict[hash(node.hash_key)]
check_block_list = [node_list[start:start + max_common_length] for start in start_index_list]
common_label = True
if not _all_equal(check_block_list, _check_node_list_equal):
common_label = False
if common_label:
common_blocks_index = copy.deepcopy(start_index_list)
max_step = len(node_list) - common_blocks_index[-1] - max_common_length - 1
for i in range(max_step):
# add assertion to avoid out of index
next_node_list = [node_list[index + max_common_length + i] for index in start_index_list]
if not _all_equal(next_node_list, _check_node_equal):
max_step = i
break
max_common_length += max_step
node_list_start += max_common_length
# recover common subgraph from the index
common_blocks = []
for start in common_blocks_index:
common_blocks.append(node_list[start:start + max_common_length])
return common_blocks
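# Hedged usage sketch (`gm` is an assumed traced GraphModule):
#
#   blocks = find_repeat_blocks(list(gm.graph.nodes), gm, common_length_threshold=10)
#   # each entry of `blocks` is one occurrence of the largest repeated block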
import functools
from typing import Any, Callable, Dict, List, Tuple, Type, Union
import torch
from colossalai.logging import get_dist_logger
from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException
__all__ = ['ignore_sharding_exception', 'pytree_map']
def ignore_sharding_exception(func):
"""
A function wrapper to handle the ShardingSpecException in the function.
If a ShardingSpecException occurs, this function will return None.
Usage:
# mute the sharding spec exception in the function
@ignore_sharding_exception
def do_something():
...
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
logger = get_dist_logger()
rst = func(*args, **kwargs)
return rst
except ShardingSpecException as e:
logger.debug(e)
return None
return wrapper
def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tensor):
"""
This function checks whether the ShardingSpec is valid for the physical tensor.
The check covers 3 items:
1. the sharding spec covers all dimensions of the physical tensor
2. each sharded dimension's size is divisible by the number of devices it is sharded over
3. the sharding spec's entire shape matches the tensor shape
"""
# make sure all dims are covered in sharding spec
sharding_len = len(sharding_spec.sharding_sequence)
tensor_num_dim = tensor.dim()
num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0]
num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1]
assert sharding_len == tensor_num_dim, \
f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).'
# make sure the sharding is valid for each dim
for i in range(tensor_num_dim):
dim_size = tensor.shape[i]
dim_spec = sharding_spec.sharding_sequence[i]
if str(dim_spec).startswith('S'):
devices_str = str(dim_spec).lstrip('S')
num_devices = 1
if '0' in devices_str:
num_devices *= num_devices_in_col
if '1' in devices_str:
num_devices *= num_devices_in_row
assert dim_size >= num_devices and dim_size % num_devices == 0, \
f'The dimension at index {i} has size {dim_size}, which cannot be evenly sharded over {num_devices} devices.'
# make sure the entire shape matches the physical tensor shape
assert sharding_spec.entire_shape == tensor.shape, \
f'The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}'
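# Illustrative check (a sketch; `mesh` is an assumed 2x4 DeviceMesh, i.e. mesh_shape == (2, 4)):
#     spec = ShardingSpec(device_mesh=mesh, entire_shape=torch.Size([8, 6]), dim_partition_dict={0: [0]})
#     check_sharding_spec_validity(spec, torch.empty(8, 6))
# passes all three items: both dims are covered, 8 % 2 == 0 for the S0-sharded dim, and the shapes match;
# a tensor of shape (7, 6) would instead fail item 2, since 7 is not divisible by the 2 devices.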
def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
"""process object recursively, like pytree
Args:
obj (:class:`Any`): object to process
fn (:class:`Callable`): a function to process subobject in obj
process_types (:class: `type | tuple[type]`): types to determine the type to process
map_all (:class: `bool`): if map_all is True, then any type of element will use fn
Returns:
:class:`Any`: returns have the same structure of `obj` and type in process_types after map of `fn`
"""
if isinstance(obj, dict):
return {k: pytree_map(obj[k], fn, process_types, map_all) for k in obj}
elif isinstance(obj, tuple):
return tuple(pytree_map(o, fn, process_types, map_all) for o in obj)
elif isinstance(obj, list):
return list(pytree_map(o, fn, process_types, map_all) for o in obj)
elif isinstance(obj, process_types):
return fn(obj)
else:
return fn(obj) if map_all else obj
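# Example (illustrative):
#     data = {'a': (1, [2, 3]), 'b': 4}
#     pytree_map(data, lambda x: x * 2, process_types=int)    # -> {'a': (2, [4, 6]), 'b': 8}
#     pytree_map(data, str, map_all=True)                     # -> {'a': ('1', ['2', '3']), 'b': '4'}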
from enum import Enum
from typing import Dict, List, Tuple
import torch
class PreviousStatus(Enum):
"""
This class shows the status of the previous comparison.
"""
RESET = 0
# ORIGIN means the dimension size of the original tensor was larger in the previous comparison.
ORIGIN = 1
# TGT means the dimension size of the target tensor was larger in the previous comparison.
TGT = 2
def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> Dict[Tuple[int], Tuple[int]]:
"""
This method is used to detect the reshape mapping between original tensor and target tensor.
Returns:
reshape_mapping_dict: The dictionary shows how a tuple of origin dims(keys) mapping to the related
target dims(values) during reshaping operation.
Examples:
import torch
origin_shape = torch.Size([4, 4, 4])
tgt_shape = torch.Size([2, 8, 2, 2])
reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape)
print(reshape_mapping_dict)
Output:
{(2,): (3, 2), (1, 0): (1,), (0,): (0, 1)}
"""
# reverse the shape object
origin_shape = list(origin_shape)
tgt_shape = list(tgt_shape)
origin_shape.reverse()
tgt_shape.reverse()
# initialize arguments
reshape_mapping_dict = {}
origin_len = len(origin_shape)
tgt_len = len(tgt_shape)
origin_index = 0
tgt_index = 0
original_dimension_size = origin_shape[origin_index]
tgt_dimension_size = tgt_shape[tgt_index]
tgt_dims = [tgt_len - tgt_index - 1]
origin_dims = [origin_len - origin_index - 1]
previous_label = PreviousStatus.RESET
while origin_index != len(origin_shape) or tgt_index != len(tgt_shape):
if original_dimension_size == tgt_dimension_size:
reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims)
# if the origin_dims has no element, it means the original tensor has been fully matched.
# Therefore, we do not have to increase the origin_index for that case.
if len(origin_dims) > 0:
origin_index += 1
# if the tgt_dims has no element, it means the target tensor has been fully matched.
# Therefore, we do not have to increase the tgt_index for that case.
if len(tgt_dims) > 0:
tgt_index += 1
# the last iteration of the loop always ends in this branch, so we
# manually skip the preparation for the next step once both shapes
# have been fully consumed.
if origin_index == len(origin_shape) and tgt_index == len(tgt_shape):
continue
# If origin_index equals origin_len, we just need to set the original_dimension_size
# to 1 to match the remaining '1's in the target tensor shape.
if origin_index == len(origin_shape):
original_dimension_size = 1
origin_dims = []
else:
original_dimension_size = origin_shape[origin_index]
origin_dims = [origin_len - origin_index - 1]
# If tgt_index equals tgt_len, we just need to set the tgt_dimension_size
# to 1 to match the remaining '1's in the original tensor shape.
if tgt_index == len(tgt_shape):
tgt_dimension_size = 1
tgt_dims = []
else:
tgt_dimension_size = tgt_shape[tgt_index]
tgt_dims = [tgt_len - tgt_index - 1]
previous_label = PreviousStatus.RESET
elif original_dimension_size > tgt_dimension_size:
tgt_index += 1
if previous_label == PreviousStatus.TGT:
# if the target dimension size was larger in the previous comparison, the origin
# dimension size has now accumulated past the target dimension size, so
# we record the origin dims and tgt dims in the reshape_mapping_dict.
reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims)
original_dimension_size = original_dimension_size // tgt_dimension_size
origin_dims = [origin_len - origin_index - 1]
tgt_dimension_size = tgt_shape[tgt_index]
tgt_dims = [tgt_len - tgt_index - 1, tgt_len - tgt_index]
# reset the previous_label after recording the origin dims and tgt dims
previous_label = PreviousStatus.RESET
else:
# accumulate tgt_dimension_size until it is larger than original_dimension_size
tgt_dimension_size *= tgt_shape[tgt_index]
tgt_dims.append(tgt_len - tgt_index - 1)
previous_label = PreviousStatus.ORIGIN
else:
origin_index += 1
if previous_label == PreviousStatus.ORIGIN:
# if the origin dimension size was larger in the previous comparison, the target
# dimension size has now accumulated past the origin dimension size, so
# we record the origin dims and tgt dims in the reshape_mapping_dict.
reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims)
tgt_dimension_size = tgt_dimension_size // original_dimension_size
tgt_dims = [tgt_len - tgt_index - 1]
original_dimension_size = origin_shape[origin_index]
origin_dims = [origin_len - origin_index - 1, origin_len - origin_index]
# reset the previous_label after recording the origin dims and tgt dims
previous_label = PreviousStatus.RESET
else:
# accumulate original_dimension_size until it is larger than tgt_dimension_size
original_dimension_size *= origin_shape[origin_index]
origin_dims.append(origin_len - origin_index - 1)
previous_label = PreviousStatus.TGT
return reshape_mapping_dict
def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]],
reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> bool:
"""
This method checks whether the reshape operation can be implemented without converting
the input to the fully replicated status.
Rule:
If a sharded dimension of the input tensor is not the minimum element of its input tuple,
the function returns False.
To illustrate this rule, there are two cases to analyze:
1. no sharded dims in the input tuple: we can do the reshape operation safely, just as for
a normal tensor without distribution.
2. sharded dims in the input tuple: the sharded dim must be the minimum element; then, during the
shape consistency process, torch.cat is applied along the sharded dim, and everything after the
sharded dim is recovered.
Examples:
# the second dimension of the input has been sharded.
input_dim_partition_dict = {1: [1]}
origin_shape = torch.Size([8, 4, 2])
tgt_shape = torch.Size([2, 4, 8])
reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape)
# {(2, 1): (2,), (0,): (1, 0)}
# the sharded dim of input is 1, which is the minimum element of the tuple (2, 1),
# so we do not have to convert the input to fully replicated status.
print(check_keep_sharding_status(input_dim_partition_dict, reshape_mapping_dict))
Output:
True
"""
sharded_dims = list(input_dim_partition_dict.keys())
for input_dims in reshape_mapping_dict.keys():
# if input_dims has no element, we could just skip this iteration.
if len(input_dims) == 0:
continue
min_element = min(input_dims)
for dim in input_dims:
if dim in sharded_dims and dim != min_element:
return False
return True
def infer_output_dim_partition_dict(input_dim_partition_dict: Dict[int, List[int]],
reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> Dict[Tuple[int], Tuple[int]]:
"""
This method is used to infer the output dim partition dict for a reshape operation,
given the input dim partition dict and reshape mapping dict.
"""
assert check_keep_sharding_status(input_dim_partition_dict, reshape_mapping_dict), \
'the output dim partition dict can only be inferred for reshape operations that keep the sharding status.'
sharded_dims = list(input_dim_partition_dict.keys())
output_dim_partition_dict = {}
for input_dims, output_dims in reshape_mapping_dict.items():
for dim in input_dims:
if dim in sharded_dims:
output_dim_partition_dict[min(output_dims)] = input_dim_partition_dict[dim]
# we can break because the input dims cannot contain two sharded dims; otherwise
# the keep-sharding-status check above would have failed.
break
return output_dim_partition_dict
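# Example (illustrative, reusing the mapping from the check_keep_sharding_status docstring):
#     input_dim_partition_dict = {1: [1]}                    # dim 1 sharded over mesh dim 1
#     reshape_mapping_dict = {(2, 1): (2,), (0,): (1, 0)}    # torch.Size([8, 4, 2]) -> torch.Size([2, 4, 8])
#     infer_output_dim_partition_dict(input_dim_partition_dict, reshape_mapping_dict)    # -> {2: [1]}
# the sharding moves to the minimum dim of the matching output tuple.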
import operator
from copy import deepcopy
from functools import reduce
from typing import Dict
import torch
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = [
'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
]
def transpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec:
"""
Switch the sharding mesh dimensions for two tensor dimensions. This operation is in-place.
Args:
sharding_spec (ShardingSpec): the sharding spec for which the partition dims are switched
dim1 (int): the first tensor dimension to switch
dim2 (int): the second tensor dimension to switch
"""
assert len(sharding_spec.entire_shape) >= 2, \
'The entire_shape of the sharding spec must have at least 2 dimensions'
dim_partition_dict = sharding_spec.dim_partition_dict
# transpose the dim partition
dim1_partition = dim_partition_dict.pop(dim1, None)
dim2_partition = dim_partition_dict.pop(dim2, None)
if dim1_partition:
dim_partition_dict[dim2] = dim1_partition
if dim2_partition:
dim_partition_dict[dim1] = dim2_partition
# get the transposed shape
new_shape = list(sharding_spec.entire_shape[:])
new_shape[dim2], new_shape[dim1] = new_shape[dim1], new_shape[dim2]
new_shape = torch.Size(new_shape)
# re-init the sharding spec
sharding_spec.__init__(sharding_spec.device_mesh, new_shape, dim_partition_dict)
return sharding_spec
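# Example (illustrative; `mesh` is an assumed DeviceMesh):
#     spec = ShardingSpec(device_mesh=mesh, entire_shape=torch.Size([8, 4]), dim_partition_dict={0: [0]})    # [S0, R]
#     transpose_partition_dim(spec, 0, 1)
#     # spec.entire_shape -> torch.Size([4, 8]); spec.dim_partition_dict -> {1: [0]}, i.e. [R, S0]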
def update_partition_dim(sharding_spec: ShardingSpec,
dim_mapping: Dict[int, int],
physical_shape: torch.Size,
inplace: bool = False):
"""
This method is used to update the partition dim dict from the logical one to the physical one.
Args:
sharding_spec (ShardingSpec): the sharding spec for which partition dims are updated
dim_mapping (Dict[int, int]): the mapping from the logical tensor dimension to the physical tensor dimension
physical_shape (torch.Size): the physical shape for the tensor
inplace (bool): whether to update the given sharding spec in place. Defaults to False.
"""
if inplace:
current_sharding_spec = sharding_spec
else:
current_sharding_spec = deepcopy(sharding_spec)
old_dim_partition_dict = current_sharding_spec.dim_partition_dict
new_dim_partition_dict = {}
# assign new dim
for old_dim, new_dim in dim_mapping.items():
mesh_dims = old_dim_partition_dict.pop(old_dim)
new_dim_partition_dict[new_dim] = mesh_dims
for tensor_dim, mesh_dims in old_dim_partition_dict.items():
if tensor_dim in new_dim_partition_dict:
raise KeyError(f"There are duplicated entries for the tensor sharding dimension {tensor_dim}")
else:
new_dim_partition_dict[tensor_dim] = mesh_dims
# update sharding spec
current_sharding_spec.__init__(device_mesh=sharding_spec.device_mesh,
entire_shape=physical_shape,
dim_partition_dict=new_dim_partition_dict)
return current_sharding_spec
def enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size):
dim_partition_list = []
# enumerate all the 2D sharding cases
for i in range(dim_size):
for j in range(i + 1, dim_size):
dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]}
dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]}
dim_partition_list.append(dim_partition_dict_0)
dim_partition_list.append(dim_partition_dict_1)
for i in range(dim_size):
dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]}
dim_partition_list.append(dim_partition_dict_flatten)
return dim_partition_list
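# Example (illustrative): for a 2-dim tensor on mesh dims 0 and 1,
#     enumerate_all_possible_2d_sharding(0, 1, 2)
#     # -> [{0: [0], 1: [1]}, {0: [1], 1: [0]}, {0: [0, 1]}, {1: [0, 1]}]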
def enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size):
dim_partition_list = []
# enumerate all the 1D sharding cases
for i in range(dim_size):
dim_partition_dict_0 = {i: [mesh_dim_0]}
dim_partition_list.append(dim_partition_dict_0)
return dim_partition_list
def generate_sharding_size(dim_partition_dict, device_mesh):
total_sharding_size = 1
for mesh_dim_list in dim_partition_dict.values():
mesh_dim_sharding_size = [device_mesh.shape[mesh_dim] for mesh_dim in mesh_dim_list]
sharding_size = reduce(operator.mul, mesh_dim_sharding_size)
total_sharding_size *= sharding_size
return total_sharding_size
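# Example (illustrative; assumes `device_mesh.shape == (2, 4)`):
#     generate_sharding_size({0: [0, 1]}, device_mesh)    # -> 8, i.e. dim 0 is flattened over 2 * 4 devices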
from typing import Any, Callable, Dict, Iterable, List, Tuple
import torch
import colossalai
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
AUTOCHUNK_AVAILABLE = CODEGEN_AVAILABLE and is_compatible_with_meta()
if AUTOCHUNK_AVAILABLE:
from torch.fx.graph import CodeGen, PythonCode, _custom_builtins, _CustomBuiltin, _format_target, _is_from_torch, _Namespace, _origin_type_map, inplace_methods, magic_methods
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
from .search_chunk import SearchChunk
from .utils import delete_free_var_from_last_use, get_logger, get_node_name, get_node_shape
def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) -> str:
"""
Generate a chunk slice string, e.g. [:, :, chunk_idx_name:chunk_idx_name + chunk_size, :]
Args:
chunk_dim (int): the dimension to chunk
chunk_indice_name (str): name of the chunk index variable
shape (List): node shape
Returns:
new_shape (str): the generated slice string
"""
new_shape = "["
for idx, _ in enumerate(shape):
if idx == chunk_dim:
new_shape += "%s:%s + chunk_size" % (chunk_indice_name, chunk_indice_name)
else:
new_shape += ":"
new_shape += ", "
new_shape = new_shape[:-2] + "]"
return new_shape
def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_output_dim: List[int], chunk_size=2) -> str:
"""
Generate chunk loop start
eg. chunk_result = torch.empty([100, 100], dtype=input_node.dtype, device=input_node.device)
chunk_size = 32
for chunk_idx in range(0, 100, 32):
......
Args:
chunk_input (List[Node]): chunk input nodes
chunk_output (List[Node]): chunk output nodes
chunk_output_dim (List[int]): chunk dim of each chunk output node
chunk_size (int): chunk size. Defaults to 2.
Returns:
context (str): generated str
"""
input_node = chunk_input[0]
context = ""
for i in range(len(chunk_output)):
shape_str = str(list(get_node_shape(chunk_output[i])))
if get_node_name(chunk_output[i]) in ["split", "unbind"]:
tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (shape_str, input_node.name,
input_node.name)
tensor_str = tensor_str * len(chunk_output[i].meta['tensor_meta'])
tensor_str = "[" + tensor_str[:-2] + "]"
context += "%s = %s; " % (chunk_output[i].name, tensor_str)
else:
context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % (chunk_output[i].name, shape_str,
input_node.name, input_node.name)
out_shape = get_node_shape(chunk_output[0])
chunk_shape = out_shape[chunk_output_dim[0]]
context += "chunk_size = %d\nfor chunk_idx in range(0, %d, chunk_size):\n" % (chunk_size, chunk_shape)
return context
def _gen_loop_end(chunk_inputs: List[Node], chunk_non_compute_inputs: List[Node], node_list: List[Node],
chunk_outputs_idx: int, chunk_outputs_non_tensor: Dict, search_chunk: SearchChunk) -> str:
"""
Generate chunk loop end
eg. chunk_result[chunk_idx:chunk_idx + chunk_size] = output_node
output_node = chunk_result; xx = None; xx = None
Args:
chunk_inputs (List[Node]): chunk input nodes
chunk_non_compute_inputs (List[Node]): input nodes that are not chunked
node_list (List[Node]): the full node list
chunk_outputs_idx (int): index of the chunk output node
chunk_outputs_non_tensor (Dict): non-tensor chunk outputs and their recorded values
Returns:
context (str): generated str
"""
context = "chunk_size = None"
# determine if its the last use for chunk input
for chunk_input in chunk_inputs + chunk_non_compute_inputs:
if all([search_chunk.node_mgr.find_node_idx(user) <= chunk_outputs_idx for user in chunk_input.users.keys()]):
context += "; %s = None" % chunk_input.name
for chunk_output_non_tensor, chunk_output_non_tensor_val in chunk_outputs_non_tensor.items():
context += "; %s = %s" % (chunk_output_non_tensor.name, chunk_output_non_tensor_val)
context += "\n"
return context
def _replace_name(context: str, name_from: str, name_to: str) -> str:
"""
replace a node name in a generated line of code, matching on the surrounding delimiters
"""
patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")"), (" ", ""), ("", " ")]
for p in patterns:
source = p[0] + name_from + p[1]
target = p[0] + name_to + p[1]
if source in context:
context = context.replace(source, target)
break
return context
def _replace_reshape_size(context: str, node_name: str, reshape_size_dict: Dict) -> str:
"""
replace reshape sizes, as some may have changed due to chunking
"""
if node_name not in reshape_size_dict:
return context
context = context.replace(reshape_size_dict[node_name][0], reshape_size_dict[node_name][1])
return context
def _replace_new_tensor_like_shape(
search_chunk: SearchChunk,
chunk_infos: List[Dict],
region_idx: int,
node_idx: int,
node: Node,
body: List[str],
) -> List[str]:
"""
add a chunk slice for new-tensor-like ops such as ones_like
"""
if get_node_name(node) in ["ones_like", "zeros_like", "empty_like"]:
meta_node = search_chunk.node_mgr.get_node_by_idx(node_idx)
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
if get_node_shape(meta_node)[chunk_dim] != 1:
source_node = meta_node.args[0].args[0]
if (source_node not in chunk_infos[region_idx]["node_chunk_dim"]
or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] is None):
chunk_slice = _gen_chunk_slice_dim(chunk_dim, "chunk_idx", get_node_shape(node))
body[-1] = _replace_name(body[-1], node.args[0].name, node.args[0].name + chunk_slice)
return body
def _replace_new_tensor_shape(
search_chunk: SearchChunk,
chunk_infos: List[Dict],
region_idx: int,
node_idx: int,
node: Node,
body: List[str],
) -> List[str]:
"""
add a chunk slice for new-tensor ops such as ones
"""
if get_node_name(node) in ["ones", "zeros", "empty"]:
meta_node = search_chunk.node_mgr.get_node_by_idx(node_idx)
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
if chunk_dim is None:
return
if get_node_shape(meta_node)[chunk_dim] == 1:
return
origin_shape = str(node.args)
new_shape = list(node.args)
new_shape[chunk_dim] = "min(chunk_size, %d - chunk_idx)" % get_node_shape(meta_node)[chunk_dim]
new_shape = str(new_shape)
new_shape = new_shape.replace("'", "")
body[-1] = _replace_name(body[-1], origin_shape[1:-1], new_shape[1:-1])
return body
def _add_node_slice(
chunk_nodes: List[Node],
region_idx: int,
chunk_nodes_dim: Dict,
node_idx: int,
body: List[str],
node: Node,
) -> List[str]:
"""
add chunk slice for input nodes
"""
for chunk_node_idx, chunk_node in enumerate(chunk_nodes[region_idx]):
# inputs node
if isinstance(chunk_nodes_dim[region_idx][chunk_node_idx], dict):
for idx, dim in chunk_nodes_dim[region_idx][chunk_node_idx].items():
if idx == node_idx:
chunk_slice = _gen_chunk_slice_dim(dim[0], "chunk_idx", get_node_shape(chunk_node))
body[-1] = _replace_name(body[-1], chunk_node.name, chunk_node.name + chunk_slice)
# outputs node
else:
if chunk_node.name == node.name or (chunk_node.name in [i.name for i in node.all_input_nodes]):
chunk_slice = _gen_chunk_slice_dim(chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx",
get_node_shape(chunk_node))
if get_node_name(chunk_node) in ["split", "unbind"]:
split_chunk_slice = ""
for i in range(len(chunk_node.meta['tensor_meta'])):
split_chunk_slice += "%s[%d]%s, " % (chunk_node.name, i, chunk_slice)
split_chunk_slice = split_chunk_slice[:-2]
body[-1] = _replace_name(body[-1], chunk_node.name, split_chunk_slice)
else:
body[-1] = _replace_name(body[-1], chunk_node.name, chunk_node.name + chunk_slice)
return body
def emit_code_with_chunk(body: List[str],
nodes: Iterable[Node],
emit_node_func: Callable,
delete_unused_value_func: Callable,
search_chunk: SearchChunk,
chunk_infos: List,
eval_mem: bool = False):
"""
Emit code with chunk according to chunk_infos.
It will generate a for loop in chunk regions, and
replace inputs and outputs of regions with chunked variables.
Args:
body: forward code
nodes: graph.nodes
emit_node_func: function to emit node
delete_unused_value_func: function to remove the unused value
search_chunk: the class to search all chunks
chunk_infos: store all information about all chunks.
eval_mem: whether to emit memory profiling statements into the generated code.
"""
node_list = list(nodes)
# chunk region
chunk_starts = [i["region"][0] for i in chunk_infos]
chunk_ends = [i["region"][1] for i in chunk_infos]
# chunk inputs
chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] # input without chunk
chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim
chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i]
# chunk outputs
chunk_outputs = [i["outputs"] for i in chunk_infos]
chunk_outputs_non_tensor = [i["outputs_non_tensor"] for i in chunk_infos]
chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos]
node_list = search_chunk.reorder_graph.reorder_node_list(node_list)
node_idx = 0
region_idx = 0
within_chunk_region = False
if eval_mem:
body.append("init_memory = torch.cuda.memory_allocated() / 1024**2\n")
while node_idx < len(node_list):
node = node_list[node_idx]
# if is chunk start, generate for loop start
if node_idx in chunk_starts:
within_chunk_region = True
region_idx = chunk_starts.index(node_idx)
body.append(
_gen_loop_start(
chunk_inputs[region_idx],
chunk_outputs[region_idx],
chunk_outputs_dim[region_idx],
chunk_infos[region_idx]["chunk_size"],
))
if within_chunk_region:
emit_node_func(node, body)
# replace input var with chunk var
body = _add_node_slice(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body, node)
# replace output var with chunk var
body = _add_node_slice(chunk_outputs, region_idx, chunk_outputs_dim, node_idx, body, node)
# new tensor like
body = _replace_new_tensor_like_shape(search_chunk, chunk_infos, region_idx, node_idx, node, body)
# new tensor
body = _replace_new_tensor_shape(search_chunk, chunk_infos, region_idx, node_idx, node, body)
# reassign reshape size
body[-1] = _replace_reshape_size(body[-1], node.name, chunk_infos[region_idx]["reshape_size"])
body[-1] = " " + body[-1]
delete_unused_value_func(node, body, chunk_inputs_names)
if eval_mem:
body.append(
" if chunk_idx == 0:\n print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n"
% (node.name))
else:
emit_node_func(node, body)
if node_idx not in chunk_inputs:
delete_unused_value_func(node, body, chunk_inputs_names)
if eval_mem:
body.append(
"print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n"
% (node.name))
# generate chunk region end
if node_idx in chunk_ends:
body.append(
_gen_loop_end(chunk_inputs[region_idx], chunk_inputs_non_chunk[region_idx], node_list,
chunk_ends[region_idx], chunk_outputs_non_tensor[region_idx], search_chunk))
within_chunk_region = False
node_idx += 1
if AUTOCHUNK_AVAILABLE:
class AutoChunkCodeGen(CodeGen):
def __init__(self,
meta_graph,
max_memory: int = None,
print_mem: bool = False,
print_progress: bool = False,
eval_mem: bool = False) -> None:
super().__init__()
self.eval_mem = eval_mem
# find the chunk regions
self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem, print_progress)
self.chunk_infos = self.search_chunk.search_region()
if print_progress:
get_logger().info("AutoChunk start codegen")
def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
free_vars: List[str] = []
body: List[str] = []
globals_: Dict[str, Any] = {}
wrapped_fns: Dict[str, None] = {}
# Wrap string in list to pass by reference
maybe_return_annotation: List[str] = [""]
def add_global(name_hint: str, obj: Any):
"""Add an obj to be tracked as a global.
We call this for names that reference objects external to the
Graph, like functions or types.
Returns: the global name that should be used to reference 'obj' in generated source.
"""
if (_is_from_torch(obj) and obj != torch.device): # to support registering torch.device
# HACK: workaround for how torch custom ops are registered. We
# can't import them like normal modules so they must retain their
# fully qualified name.
return _get_qualified_name(obj)
# normalize the name hint to get a proper identifier
global_name = namespace.create_name(name_hint, obj)
if global_name in globals_:
assert globals_[global_name] is obj
return global_name
globals_[global_name] = obj
return global_name
# set _custom_builtins here so that we needn't import colossalai in forward
_custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai)
# Pre-fill the globals table with registered builtins.
for name, (_, obj) in _custom_builtins.items():
add_global(name, obj)
def type_repr(o: Any):
if o == ():
# Empty tuple is used for empty tuple type annotation Tuple[()]
return "()"
typename = _type_repr(o)
if hasattr(o, "__origin__"):
# This is a generic type, e.g. typing.List[torch.Tensor]
origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
origin_typename = add_global(_type_repr(origin_type), origin_type)
if hasattr(o, "__args__"):
# Assign global names for each of the inner type variables.
args = [type_repr(arg) for arg in o.__args__]
if len(args) == 0:
# Bare type, such as `typing.Tuple` with no subscript
# This code-path used in Python < 3.9
return origin_typename
return f'{origin_typename}[{",".join(args)}]'
else:
# Bare type, such as `typing.Tuple` with no subscript
# This code-path used in Python 3.9+
return origin_typename
# Common case: this is a regular module name like 'foo.bar.baz'
return add_global(typename, o)
def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
def _get_repr(arg):
# Handle NamedTuples (if it has `_fields`) via add_global.
if isinstance(arg, tuple) and hasattr(arg, "_fields"):
qualified_name = _get_qualified_name(type(arg))
global_name = add_global(qualified_name, type(arg))
return f"{global_name}{repr(tuple(arg))}"
return repr(arg)
args_s = ", ".join(_get_repr(a) for a in args)
kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items())
if args_s and kwargs_s:
return f"{args_s}, {kwargs_s}"
return args_s or kwargs_s
# Run through reverse nodes and record the first instance of a use
# of a given node. This represents the *last* use of the node in the
# execution order of the program, which we will use to free unused
# values
node_to_last_use: Dict[Node, Node] = {}
user_to_last_uses: Dict[Node, List[Node]] = {}
def register_last_uses(n: Node, user: Node):
if n not in node_to_last_use:
node_to_last_use[n] = user
user_to_last_uses.setdefault(user, []).append(n)
for node in reversed(nodes):
map_arg(node.args, lambda n: register_last_uses(n, node))
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
delete_free_var_from_last_use(user_to_last_uses)
# NOTE: we add a variable to distinguish body and ckpt_func
def delete_unused_values(user: Node, body, to_keep=[]):
"""
Delete values after their last use. This ensures that values that are
not used in the remainder of the code are freed and the memory usage
of the code is optimal.
"""
if user.op == "placeholder":
return
if user.op == "output":
body.append("\n")
return
nodes_to_delete = user_to_last_uses.get(user, [])
nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep]
if len(nodes_to_delete):
to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"])
body.append(f"; {to_delete_str}\n")
else:
body.append("\n")
# NOTE: we add a variable to distinguish body and ckpt_func
def emit_node(node: Node, body):
maybe_type_annotation = ("" if node.type is None else f" : {type_repr(node.type)}")
if node.op == "placeholder":
assert isinstance(node.target, str)
maybe_default_arg = ("" if not node.args else f" = {repr(node.args[0])}")
free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
raw_name = node.target.replace("*", "")
if raw_name != repr(node):
body.append(f"{repr(node)} = {raw_name}\n")
return
elif node.op == "call_method":
assert isinstance(node.target, str)
body.append(
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
f"({_format_args(node.args[1:], node.kwargs)})")
return
elif node.op == "call_function":
assert callable(node.target)
# pretty print operators
if (node.target.__module__ == "_operator" and node.target.__name__ in magic_methods):
assert isinstance(node.args, tuple)
body.append(f"{repr(node)}{maybe_type_annotation} = "
f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}")
return
# pretty print inplace operators; required for jit.script to work properly
# not currently supported in normal FX graphs, but generated by torchdynamo
if (node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods):
body.append(f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}")
return
qualified_name = _get_qualified_name(node.target)
global_name = add_global(qualified_name, node.target)
# special case for getattr: node.args could be 2-argument or 3-argument
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
if (global_name == "getattr" and isinstance(node.args, tuple) and isinstance(node.args[1], str)
and node.args[1].isidentifier() and len(node.args) == 2):
body.append(
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}")
return
body.append(
f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})")
if node.meta.get("is_wrapped", False):
wrapped_fns.setdefault(global_name)
return
elif node.op == "call_module":
assert isinstance(node.target, str)
body.append(f"{repr(node)}{maybe_type_annotation} = "
f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})")
return
elif node.op == "get_attr":
assert isinstance(node.target, str)
body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}")
return
elif node.op == "output":
if node.type is not None:
maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
body.append(self.generate_output(node.args[0]))
return
raise NotImplementedError(f"node: {node.op} {node.target}")
# Modified for activation checkpointing
ckpt_func = []
# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
emit_code_with_chunk(body, nodes, emit_node, delete_unused_values, self.search_chunk, self.chunk_infos,
self.eval_mem)
if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body
# have been emitted. To continue to have valid Python code, emit a
# single pass statement
body.append("pass\n")
if len(wrapped_fns) > 0:
wrap_name = add_global("wrap", torch.fx.wrap)
wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns])
else:
wrap_stmts = ""
if self._body_transformer:
body = self._body_transformer(body)
for name, value in self.additional_globals():
add_global(name, value)
# as we need colossalai.utils.checkpoint, we need to import colossalai
# in forward function
prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
prologue = "".join(ckpt_func) + prologue
prologue = prologue
code = "".join(body)
code = "\n".join(" " + line for line in code.split("\n"))
fn_code = f"""
{wrap_stmts}
{prologue}
{code}"""
# print(fn_code)
return PythonCode(fn_code, globals_)
import copy
from typing import Any, Callable, Dict, Iterable, List, Tuple
import torch
from torch.fx.node import Node
from colossalai.fx.profiler import activation_size, parameter_size
from .utils import NodeMgr, get_node_shape, is_non_memory_node
class EstimateMemory(object):
"""
Estimate memory with chunk
"""
def __init__(self) -> None:
pass
def _get_node_size(self, x: Node) -> float:
"""
return node size in MB
"""
x = x.meta["tensor_meta"]
if not hasattr(x, "numel"):
out = sum([i.numel * torch.tensor([], dtype=i.dtype).element_size() for i in x])
else:
out = x.numel * torch.tensor([], dtype=x.dtype).element_size()
out = float(out) / 1024**2
return out
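# Illustrative: for a node whose meta["tensor_meta"] describes a float32 tensor of shape
# (1024, 1024), the size is 1048576 * 4 bytes, so this returns 4.0 (MB).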
def _add_active_node(self, n: Node, active_nodes: Dict, chunk_ratio: float) -> None:
"""
add an active node and its shape to active node dict
"""
if get_node_shape(n) is None:
return
if n.op == "placeholder":
return
if n not in active_nodes:
node_size = self._get_node_size(n) * chunk_ratio
active_nodes[n] = node_size
def _build_delete_node_dict(self, node_mgr: NodeMgr) -> Dict:
"""
build the delete-node dict, which records the index after which each node can be deleted
"""
delete_node_dict = {}
for idx, node in enumerate(node_mgr.get_node_list()):
# skip non shape node
if get_node_shape(node) is None:
continue
# don't remove free nodes
elif node.op == "placeholder":
delete_node_dict[node] = len(node_mgr.get_node_list())
# node has no user
elif len(node.users) == 0:
delete_node_dict[node] = idx
# record the last use
else:
node_user_idx = [node_mgr.find_node_idx(i) for i in node.users.keys()]
delete_node_dict[node] = max(node_user_idx)
return delete_node_dict
def _remove_deactive_node(self,
user_idx: int,
user: Node,
active_nodes: Dict,
delete_node_dict: Dict,
kept_nodes: List = None) -> None:
"""
remove deactivated nodes from the active node dict
"""
if kept_nodes is None:
kept_nodes = []
if user.op in ("output",):
return
for node in list(active_nodes.keys()):
# dont delete kept nodes
if node in kept_nodes:
continue
# should be deleted
if delete_node_dict[node] <= user_idx:
active_nodes.pop(node)
def _get_tmp_memory(self, node, not_contiguous_list, delete=False):
mem = 0
not_contiguous_ops = ["permute"]
if node.op == "call_function" and any(n in node.name for n in ["matmul", "reshape"]):
for n in node.args:
if n in not_contiguous_list:
# matmul won't change the origin tensor, but creates a tmp copy
mem += self._get_node_size(n)
elif node.op == "call_module":
for n in node.args:
if n in not_contiguous_list:
# the module will just make the origin tensor contiguous
if delete:
not_contiguous_list.remove(n)
elif node.op == "call_method" and any(i in node.name for i in not_contiguous_ops):
if node not in not_contiguous_list:
not_contiguous_list.append(node)
return mem
def _get_chunk_ratio(self, node, chunk_node_dim, chunk_size):
if node not in chunk_node_dim:
return 1.0
node_shape = get_node_shape(node)
chunk_dim = chunk_node_dim[node]["chunk_dim"]
if chunk_dim is None:
return 1.0
else:
return chunk_size / float(node_shape[chunk_dim])
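# Illustrative: for a node of shape (4, 64) chunked along dim 1 with chunk_size 16,
# the ratio is 16 / 64 = 0.25, i.e. only a quarter of the tensor is alive per chunk iteration.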
def _print_compute_op_mem_log(self, log, nodes, title=None):
if title:
print(title)
for idx, (l, n) in enumerate(zip(log, nodes)):
if n.op in ["placeholder", "get_attr", "output"]:
continue
if any(i in n.name for i in ["getitem", "getattr"]):
continue
print("%s:%.2f \t" % (n.name, l), end="")
if (idx + 1) % 3 == 0:
print("")
print("\n")
def _add_active_nodes_from_list(self, active_nodes: Dict, nodes: List) -> None:
"""
add the given nodes to the active node dict
"""
for n in nodes:
self._add_active_node(n, active_nodes, 1)
def _get_memory_from_active_nodes(self, active_nodes: Dict) -> float:
"""
sum all memory of active nodes
"""
out = [i for i in active_nodes.values()]
out = sum(out)
return out
def estimate_chunk_inference_mem(self, node_list: List, chunk_infos: Dict = None, print_mem: bool = False):
"""
Estimate inference memory with chunk
Args:
node_list (List): the list of graph nodes to estimate memory for
chunk_infos (Dict): Chunk information. Defaults to None.
print_mem (bool): Whether to print peak memory of every node. Defaults to False.
Returns:
act_memory_peak_log (List): peak memory of every node
act_memory_after_node_log (List): memory after executing every node
active_node_list_log (List): active nodes of every node. active nodes refer to
nodes generated but not deleted.
"""
act_memory = 0.0
act_memory_peak_log = []
act_memory_after_node_log = []
active_nodes = {}
active_nodes_log = []
not_contiguous_list = []
node_mgr = NodeMgr(node_list)
delete_node_dict = self._build_delete_node_dict(node_mgr)
use_chunk = chunk_infos is not None
chunk_within = False
chunk_region_idx = None
chunk_ratio = 1 # use it to estimate chunk mem
chunk_inputs_all = []
if use_chunk:
chunk_regions = [i["region"] for i in chunk_infos]
chunk_starts = [i[0] for i in chunk_regions]
chunk_ends = [i[1] for i in chunk_regions]
chunk_inputs = [i["inputs"] for i in chunk_infos]
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
chunk_inputs_all = [j for i in chunk_inputs for j in i] + [j for i in chunk_inputs_non_chunk for j in i]
chunk_outputs = [i["outputs"] for i in chunk_infos]
chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos]
chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos]
for idx, node in enumerate(node_mgr.get_node_list()):
# if node in chunk start nodes, change chunk ratio and add chunk_tensor
if use_chunk and idx in chunk_starts:
chunk_within = True
chunk_region_idx = chunk_starts.index(idx)
self._add_active_nodes_from_list(active_nodes, chunk_outputs[chunk_region_idx])
# determine chunk ratio for current node
if chunk_within:
chunk_ratio = self._get_chunk_ratio(node, chunk_node_dim[chunk_region_idx],
chunk_sizes[chunk_region_idx])
# add current node as active node
self._add_active_node(node, active_nodes, chunk_ratio)
act_memory = self._get_memory_from_active_nodes(active_nodes)
# if node is placeholder, just add the size of the node
if node.op == "placeholder":
act_memory_peak_log.append(act_memory)
# skip output
elif node.op == "output":
continue
# no change for non compute node
elif is_non_memory_node(node):
act_memory_peak_log.append(act_memory)
# node is a compute op, calculate tmp
else:
# forward memory
# TODO: contiguous_memory still not accurate for matmul, view, reshape and transpose
tmp_memory = self._get_tmp_memory(node, not_contiguous_list, delete=True) * chunk_ratio
# record max act memory
act_memory_peak_log.append(act_memory + tmp_memory)
# remove_deactive_node
self._remove_deactive_node(idx, node, active_nodes, delete_node_dict, kept_nodes=chunk_inputs_all)
# if node in chunk end nodes, restore chunk settings
if use_chunk and idx in chunk_ends:
self._remove_deactive_node(idx, node, active_nodes, delete_node_dict) # dont provide kept nodes now
chunk_within = False
chunk_ratio = 1
chunk_region_idx = None
act_memory = self._get_memory_from_active_nodes(active_nodes)
act_memory_after_node_log.append(act_memory)
active_nodes_log.append(active_nodes.copy())
if print_mem:
print("with chunk" if use_chunk else "without chunk")
self._print_compute_op_mem_log(act_memory_peak_log, node_mgr.get_node_list(), "peak")
# param_memory = parameter_size(gm)
# all_memory = act_memory + param_memory
return act_memory_peak_log, act_memory_after_node_log, active_nodes_log
from .trace_indice import TraceIndice
from .utils import NodeMgr
class ReorderGraph(object):
"""
Reorder node list and indice trace list
"""
def __init__(self, trace_indice: TraceIndice, node_mgr: NodeMgr) -> None:
self.trace_indice = trace_indice
self.node_mgr = node_mgr
self.all_reorder_map = {i: i for i in range(len(self.node_mgr.get_node_list()))}
def _get_reorder_map(self, chunk_info):
reorder_map = {i: i for i in range(len(self.node_mgr.get_node_list()))}
chunk_region_start = chunk_info["region"][0]
chunk_region_end = chunk_info["region"][1]
chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"]
chunk_prepose_nodes_idx = [self.node_mgr.find_node_idx(i) for i in chunk_prepose_nodes]
# put prepose nodes ahead
for idx, n in enumerate(chunk_prepose_nodes):
n_idx = chunk_prepose_nodes_idx[idx]
reorder_map[n_idx] = chunk_region_start + idx
# put other nodes after prepose nodes
for n in self.node_mgr.get_node_slice_by_idx(chunk_region_start, chunk_region_end + 1):
if n in chunk_prepose_nodes:
continue
n_idx = self.node_mgr.find_node_idx(n)
pos = sum([n_idx < i for i in chunk_prepose_nodes_idx])
reorder_map[n_idx] = n_idx + pos
return reorder_map
def _reorder_chunk_info(self, chunk_info, reorder_map):
# update chunk info
chunk_info["region"] = (
chunk_info["region"][0] + len(chunk_info["args"]["prepose_nodes"]),
chunk_info["region"][1],
)
new_inputs_dim = []
for _, input_dim in enumerate(chunk_info["inputs_dim"]):
new_input_dim = {}
for k, v in input_dim.items():
new_input_dim[reorder_map[k]] = v
new_inputs_dim.append(new_input_dim)
chunk_info["inputs_dim"] = new_inputs_dim
return chunk_info
def _update_all_reorder_map(self, reorder_map):
for origin_idx, map_idx in self.all_reorder_map.items():
self.all_reorder_map[origin_idx] = reorder_map[map_idx]
def _reorder_self_node_list(self, reorder_map):
new_node_list = [None for _ in range(len(self.node_mgr.get_node_list()))]
for old_idx, new_idx in reorder_map.items():
new_node_list[new_idx] = self.node_mgr.get_node_by_idx(old_idx)
self.node_mgr.update_node_list(new_node_list)
def _reorder_idx_trace(self, reorder_map):
# reorder list
new_idx_trace_list = [None for _ in range(len(self.trace_indice.indice_trace_list))]
for old_idx, new_idx in reorder_map.items():
new_idx_trace_list[new_idx] = self.trace_indice.indice_trace_list[old_idx]
self.trace_indice.indice_trace_list = new_idx_trace_list
# update compute
for idx_trace in self.trace_indice.indice_trace_list:
compute = idx_trace["compute"]
for dim_compute in compute:
for idx, i in enumerate(dim_compute):
dim_compute[idx] = reorder_map[i]
# update source
for idx_trace in self.trace_indice.indice_trace_list:
source = idx_trace["source"]
for dim_idx, dim_source in enumerate(source):
new_dim_source = {}
for k, v in dim_source.items():
new_dim_source[reorder_map[k]] = v
source[dim_idx] = new_dim_source
def reorder_all(self, chunk_info):
if chunk_info is None:
return chunk_info
if len(chunk_info["args"]["prepose_nodes"]) == 0:
return chunk_info
reorder_map = self._get_reorder_map(chunk_info)
self._update_all_reorder_map(reorder_map)
self._reorder_idx_trace(reorder_map)
self._reorder_self_node_list(reorder_map)
chunk_info = self._reorder_chunk_info(chunk_info, reorder_map)
return chunk_info
def reorder_node_list(self, node_list):
new_node_list = [None for _ in range(len(node_list))]
for old_idx, new_idx in self.all_reorder_map.items():
new_node_list[new_idx] = node_list[old_idx]
return new_node_list
def tmp_reorder(self, node_list, chunk_info):
if len(chunk_info["args"]["prepose_nodes"]) == 0:
return node_list, chunk_info
reorder_map = self._get_reorder_map(chunk_info)
# new tmp node list
new_node_list = [None for _ in range(len(node_list))]
for old_idx, new_idx in reorder_map.items():
new_node_list[new_idx] = node_list[old_idx]
chunk_info = self._reorder_chunk_info(chunk_info, reorder_map)
return new_node_list, chunk_info
import copy
from typing import Dict, List, Tuple
from torch.fx.node import Node
from .estimate_memory import EstimateMemory
from .reorder_graph import ReorderGraph
from .select_chunk import SelectChunk
from .trace_flow import TraceFlow
from .trace_indice import TraceIndice
from .utils import NodeMgr, get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
class SearchChunk(object):
"""
This is the core class for AutoChunk.
It defines the framework of AutoChunk's search strategy.
Chunks will be selected one by one until search stops.
The chunk search is as follows:
1. find the peak memory node
2. find the max chunk region according to the peak memory node
3. find all possible chunk regions in the max chunk region
4. find the best chunk region for current status
5. goto 1
Attributes:
gm: graph model
print_mem (bool): print estimated memory
trace_indice: trace the flow of every dim of every node to find all free dims
trace_flow: determine the region chunk strategy
reorder_graph: reorder nodes to improve chunk efficiency
estimate_memory: estimate memory with chunk
select_chunk: select the best chunk region
Args:
gm: graph model
max_memory (int): max memory in MB
print_mem (bool): print estimated memory
"""
def __init__(self, gm, max_memory=None, print_mem=False, print_progress=False) -> None:
self.print_mem = print_mem
self.max_memory = max_memory
self.print_progress = print_progress
self.node_mgr = NodeMgr(list(gm.graph.nodes))
self.trace_indice = TraceIndice(self.node_mgr)
self.estimate_memory = EstimateMemory()
self._init_trace()
self.trace_flow = TraceFlow(self.trace_indice, self.node_mgr)
self.reorder_graph = ReorderGraph(self.trace_indice, self.node_mgr)
self.select_chunk = SelectChunk(
self.trace_indice,
self.estimate_memory,
self.reorder_graph,
self.node_mgr,
max_memory=max_memory,
)
def _init_trace(self) -> None:
"""
find the max trace range for every node to
reduce the computation complexity of trace_indice
"""
# find all max ranges
active_nodes = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list())[2]
# set trace range and do the trace
if self.print_progress:
get_logger().info("AutoChunk start tracing indice")
self.trace_indice.set_active_nodes(active_nodes)
self.trace_indice.trace_indice()
def _find_peak_region(self, mem_peak: List) -> int:
"""
find the peak memory node, together with the neighboring nodes that also exceed max memory
"""
max_value = max(mem_peak)
max_idx = mem_peak.index(max_value)
peak_region = [max_idx, max_idx]
if self.max_memory is None:
return peak_region
# to left
count = 0
for i in range(max_idx - 1, -1, -1):
if mem_peak[i] > self.max_memory:
peak_region[0] = i
else:
count += 1
if count >= 3:
break
# to right
count = 0
for i in range(max_idx + 1, len(mem_peak) - 1):
if mem_peak[i] > self.max_memory:
peak_region[1] = i
count = 0
else:
count += 1
if count >= 3:
break
return peak_region
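# Illustrative trace: with mem_peak = [1, 5, 9, 6, 2] and max_memory = 4, the peak is at index 2;
# extending left takes in index 1 (5 > 4) and extending right takes in index 3 (6 > 4),
# so the returned peak_region is [1, 3].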
def _search_max_chunk_region(self, active_node: List, peak_region: List, chunk_regions: List = None) -> Tuple:
"""
Search the max chunk region according to the peak memory node.
The chunk region starts extending from the peak node and stops where the free var num is minimal.
Args:
active_node (List): active node status for every node
peak_region (List): start and end indices of the peak memory region
chunk_regions (List): chunk region infos
Returns:
chunk_region_start (int)
chunk_region_end (int)
"""
# check if peak node already in chunk info
if chunk_regions is not None:
for i in chunk_regions:
if i["region"][0] < peak_region[0] <= i["region"][1] or \
i["region"][0] < peak_region[1] <= i["region"][1]:
return None
active_node_num = [len(i) for i in active_node]
window_size = 100
# search min for start
min_num = 1e4
for i in range(peak_region[0], max(peak_region[0] - window_size, -1), -1):
if active_node_num[i] < min_num:
min_num = active_node_num[i]
chunk_region_start = i
# search min for end
min_num = 1e4
for i in range(peak_region[1], min(peak_region[1] + window_size, len(active_node_num))):
if active_node_num[i] < min_num:
min_num = active_node_num[i]
chunk_region_end = i
# avoid chunk regions overlap
if chunk_regions is not None:
for i in chunk_regions:
region = i["region"]
if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
return None
elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
chunk_region_start = region[1] + 1
elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
chunk_region_end = region[0] - 1
return chunk_region_start, chunk_region_end
def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
"""
Find chunk info for a region.
We are given the region start and region end, and need to find out all chunk info for it.
We first loop every dim of start node and end node, to see if we can find dim pair,
which is linked in a flow and not computed.
If found, we then search flow in the whole region to find out all chunk infos.
Args:
input_trace (List): node's input trace in region
output_trace (List): node's output trace in region
start_idx (int): region start node index
end_idx (int): region end node index
Returns:
chunk_infos: possible regions found
"""
start_traces = input_trace[start_idx]
if len(start_traces) > 1: # TODO need to be removed
return []
end_trace = output_trace[end_idx]
end_node = self.node_mgr.get_node_by_idx(end_idx)
chunk_infos = []
for end_dim, _ in enumerate(end_trace["indice"]):
for start_node, start_trace in start_traces.items():
for start_dim, _ in enumerate(start_trace["indice"]):
if not self.trace_flow.check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim,
end_idx):
continue
# flow search
chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
if chunk_info is None:
continue
chunk_infos.append(chunk_info)
return chunk_infos
def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_region: List) -> List:
"""
Search every possible region within the max chunk region.
Args:
max_chunk_region (Tuple)
peak_region (List): start and end indices of the peak memory region
Returns:
possible_chunk_region (List)
"""
possible_chunk_region = []
output_trace = copy.deepcopy(self.trace_indice.indice_trace_list)
input_trace = [] # trace of a node's input nodes
for _, n in enumerate(self.node_mgr.get_node_list()):
cur_trace = {}
for arg in n.args:
if type(arg) == type(n) and not is_non_compute_node_except_placeholder(arg):
cur_trace[arg] = self.trace_indice._find_trace_from_node(arg)
input_trace.append(cur_trace)
for start_idx in range(max_chunk_region[0], peak_region[0] + 1):
for end_idx in range(peak_region[1], max_chunk_region[1] + 1):
# skip non compute nodes
if is_non_compute_node(self.node_mgr.get_node_by_idx(start_idx)) or is_non_compute_node(
self.node_mgr.get_node_by_idx(end_idx)):
continue
# select free dim
chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
if len(chunk_info) > 0:
possible_chunk_region.extend(chunk_info)
return possible_chunk_region
def _step_search(
self,
mem_peak: List[float],
active_node: List[List[Node]],
chunk_infos: List[Dict],
) -> Dict:
"""
Find one chunk region
The chunk search is as follows:
1. find the peak memory node
2. find the max chunk region according to the peak memory node
3. find all possible chunk regions in the max chunk region
4. find the best chunk region for current status
Args:
mem_peak (List): peak memory for every node
active_node (List[List[Node]]): active node for every node
chunk_infos (List[Dict]): all chunk info
Returns:
best_chunk_region (Dict)
"""
peak_region = self._find_peak_region(mem_peak)
max_chunk_region = self._search_max_chunk_region(active_node, peak_region, chunk_infos)
if max_chunk_region is None:
return None
possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_region)
best_chunk_region = self.select_chunk._select_best_chunk_region(possible_chunk_regions, chunk_infos, mem_peak)
best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
return best_chunk_region
def search_region(self) -> Dict:
"""
Search all chunk regions:
1. Estimate current memory
2. Find best chunk for current memory
3. goto 1
Returns:
chunk_infos (Dict)
"""
if self.print_progress:
get_logger().info("AutoChunk start searching chunk regions")
chunk_infos = []
init_mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list())
mem_peak = init_mem_peak
while True:
chunk_info = self._step_search(mem_peak, active_node, chunk_infos)
if chunk_info is None:
break
chunk_infos.append(chunk_info)
mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem(
self.node_mgr.get_node_list(), chunk_infos)
if self.print_progress:
get_logger().info("AutoChunk find chunk region %d = (%d, %d)" %
(len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1]))
if self.print_mem:
self.print_mem = False
self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(),
chunk_infos,
print_mem=True)
return chunk_infos
from .estimate_memory import EstimateMemory
from .reorder_graph import ReorderGraph
from .trace_indice import TraceIndice
from .utils import NodeMgr, is_non_compute_node
class SelectChunk(object):
def __init__(
self,
trace_indice: TraceIndice,
estimate_memory: EstimateMemory,
reorder_graph: ReorderGraph,
node_mgr: NodeMgr,
max_memory=None,
):
self.trace_indice = trace_indice
self.estimate_memory = estimate_memory
self.reorder_graph = reorder_graph
self.node_mgr = node_mgr
if max_memory is not None:
self.strategy = "fit_memory"
self.max_memory = max_memory # MB
else:
self.strategy = "min_memory"
def _select_best_chunk_region(self, possible_chunk_regions, chunk_infos, mem_peak):
if self.strategy == "min_memory":
best_region = self._select_min_memory_chunk_region(possible_chunk_regions, chunk_infos)
elif self.strategy == "fit_memory":
best_region = self._select_fit_memory_chunk_region(possible_chunk_regions, chunk_infos, mem_peak)
else:
raise RuntimeError(f"unknown chunk selection strategy {self.strategy}")
return best_region
def _select_fit_memory_chunk_region(self, possible_chunk_regions, chunk_infos, mem_peak):
# stop chunking if the peak memory already satisfies the memory limit
if max(mem_peak) < self.max_memory:
return None
# remove illegal regions
illegal_regions = []
for i in possible_chunk_regions:
if not self._is_legal_region(i, chunk_infos):
illegal_regions.append(i)
for i in illegal_regions:
if i in possible_chunk_regions:
possible_chunk_regions.remove(i)
if len(possible_chunk_regions) == 0:
return None
# get mem for chunk region
regions_dict = []
for region in possible_chunk_regions:
cur_region = region.copy()
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
cur_chunk_infos = chunk_infos + [cur_region]
cur_mem = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
cur_chunk_region_peak = cur_mem[cur_region["region"][0]:cur_region["region"][1] + 1]
cur_chunk_region_max_peak = max(cur_chunk_region_peak)
if cur_chunk_region_max_peak < self.max_memory:
regions_dict.append({
"chunk_info": region,
"chunk_max_mem": cur_chunk_region_max_peak,
"chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
"reorder_chunk_info": cur_region,
"reorder_node_list": cur_node_list,
})
# no region found
if len(regions_dict) == 0:
raise RuntimeError("Search failed. Try a larger memory threshold.")
# select the min chunk len
chunk_len = [i["chunk_len"] for i in regions_dict]
best_region_idx = chunk_len.index(min(chunk_len))
best_region = regions_dict[best_region_idx]
# get max chunk size
best_region = self._get_fit_chunk_size(best_region, chunk_infos)
return best_region
def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos):
chunk_size = 1
reorder_chunk_info = chunk_region_dict["reorder_chunk_info"]
reorder_chunk_info["chunk_size"] = chunk_size
cur_chunk_max_mem = 0
# double the chunk size until the chunk region exceeds the memory limit
while cur_chunk_max_mem < self.max_memory:
chunk_size *= 2
reorder_chunk_info["chunk_size"] = chunk_size
cur_chunk_infos = chunk_infos + [reorder_chunk_info]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"],
cur_chunk_infos)[0]
cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0]:reorder_chunk_info["region"][1] + 1])
# search exact size
chunk_info = chunk_region_dict["chunk_info"]
chunk_info["chunk_size"] = self._chunk_size_binary_search(chunk_size // 2, chunk_size, chunk_region_dict,
chunk_infos)
return chunk_info
def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos):
if left >= 16:
gap = 4
else:
gap = 1
chunk_info = chunk_region_dict["reorder_chunk_info"]
while right >= left + gap:
mid = int((left + right) / 2 + 0.5)
chunk_info["chunk_size"] = mid
cur_chunk_infos = chunk_infos + [chunk_info]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"],
cur_chunk_infos)[0]
cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0]:chunk_info["region"][1] + 1])
if cur_chunk_max_mem >= self.max_memory:
right = mid - gap
else:
left = mid + gap
return left
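# Note: when left >= 16 the binary search steps with gap=4, trading an exact
# boundary for fewer memory estimations; small regions use gap=1.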
def _get_compute_node_num(self, start, end):
count = 0
for i in self.node_mgr.get_node_slice_by_idx(start, end + 1):
if not is_non_compute_node(i):
count += 1
return count
def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos):
# remove illegal regions
illegal_regions = []
for i in possible_chunk_regions:
if not self._is_legal_region(i, chunk_infos):
illegal_regions.append(i)
for i in illegal_regions:
if i in possible_chunk_regions:
possible_chunk_regions.remove(i)
if len(possible_chunk_regions) == 0:
return None
# get max possible chunk region
max_possible_chunk_region = (min([i["region"][0] for i in possible_chunk_regions]),
max([i["region"][1] for i in possible_chunk_regions]))
# get mem for chunk region
regions_dict_list = []
for region in possible_chunk_regions:
cur_region = region.copy()
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
cur_chunk_infos = chunk_infos + [cur_region]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]
cur_chunk_region_max_peak = max(cur_chunk_region_peak)
regions_dict_list.append({
"chunk_info": region,
"chunk_max_mem": cur_chunk_region_max_peak,
"chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
"reorder_chunk_info": cur_region,
"reorder_node_list": cur_node_list,
})
# select the min mem
chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict_list]
best_region_idx = chunk_max_mem.index(min(chunk_max_mem))
best_region = regions_dict_list[best_region_idx]["chunk_info"]
if best_region is not None:
best_region["chunk_size"] = 1
return best_region
def _is_legal_region(self, cur_chunk_info, chunk_infos):
(chunk_region_start, chunk_region_end) = cur_chunk_info["region"]
if cur_chunk_info in chunk_infos:
return False
if chunk_region_end < chunk_region_start:
return False
for i in chunk_infos:
region = i["region"]
if not ((chunk_region_start > region[1] and chunk_region_end > region[1]) or
(chunk_region_start < region[0] and chunk_region_end < region[0])):
return False
return True
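# A region is legal only if it is not a duplicate of an existing chunk, is
# non-empty, and does not overlap any previously selected chunk region.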
from typing import Dict, List, Tuple
from torch.fx.node import Node
from .trace_indice import TraceIndice
from .utils import (
NodeMgr,
find_chunk_all_input_nodes,
find_chunk_compute_input_and_output_nodes,
find_tensor_shape_node,
flat_list,
get_node_name,
get_node_shape,
is_non_compute_node,
)
class TraceFlow(object):
def __init__(self, trace_indice: TraceIndice, node_mgr: NodeMgr) -> None:
self.trace_indice = trace_indice
self.node_mgr = node_mgr
def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node):
"""
Check 2 given index: one index should be source of the other
Args:
start_idx(int): start node chunk dim
start_node(node): start node
end_idx(int): end node chunk dim
end_node(node): end node
Returns:
bool: True if check pass
"""
# we use start_node_idx instead of real chunk index
start_node_idx = self.node_mgr.find_node_idx(start_node)
end_node_trace = self.trace_indice._find_trace_from_node(end_node)
end_node_trace_source = end_node_trace["source"][end_dim]
sorted_source = sorted(end_node_trace_source.items(), key=lambda d: d[0], reverse=True)
for node_idx, node_dim in sorted_source:
if node_idx == start_node_idx and start_dim in node_dim:
return True
# we met a node outside the chunk region that is not an input node
if node_idx < start_node_idx:
return False
return False
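# Sources are scanned from the largest node index downwards; once the scan
# passes start_node_idx without a match, no remaining source can be the
# start node, so the region is rejected.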
def check_index_compute(self, start_idx, end_dim, end_node, end_idx):
"""
Check 2 given index: check they haven't been computed in the source trace.
Args:
start_idx(int): start node chunk dim
start_node(node): start node
end_idx(int): end node chunk dim
end_node(node): end node
Returns:
bool: True if check pass
"""
end_node_trace = self.trace_indice._find_trace_from_node(end_node)
end_node_compute = end_node_trace["compute"][end_dim]
if any(start_idx <= i <= end_idx for i in end_node_compute):
return False
return True
def _assign_single_node_flow(
self,
arg_node: Node,
start_idx: int,
end_idx: int,
cur_node: Node,
cur_node_dim: int,
cur_node_compute: Dict,
cur_node_source: Dict,
cur_node_fix_dim: List,
all_node_info: Dict,
next_node_list: List,
) -> bool:
"""
Given the current node and one of its arg node,
this function finds out arg node's chunk dim and fix dim
Args:
arg_node (Node): input node
start_idx (int): chunk region start
end_idx (int): chunk region end
cur_node_dim (int): current node chunk dim
cur_node_compute (Dict): current node compute dict
cur_node_source (Dict): current node source dict
cur_node_fix_dim (List): current node fix dim
all_node_info (Dict): all node chunk info in the chunk region
next_node_list (List): nodes to visit in the next flow step
Returns:
bool: True if this node can be added to the flow, False otherwise.
"""
arg_idx = self.node_mgr.find_node_idx(arg_node)
# args outside the chunk region are treated as inputs and need no check
if not (start_idx <= arg_idx < end_idx):
return True
# get fix dim
arg_fix_dim = []
if cur_node_dim is not None:
for i in cur_node_fix_dim:
fix_dim_source = cur_node_source[i]
if arg_idx in fix_dim_source:
arg_fix_dim.append(fix_dim_source[arg_idx][0])
if arg_node in all_node_info:
arg_fix_dim = list(set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim))
# find arg dim
if cur_node_dim is not None:
# dim is computed
if arg_idx in cur_node_compute[cur_node_dim]:
return False
if arg_idx not in cur_node_source[cur_node_dim]:
arg_dim = None
else:
arg_dim = cur_node_source[cur_node_dim][arg_idx][0]
# chunk dim cannot be in fix dims
if arg_dim in arg_fix_dim:
return False
# chunk dim should be None if shape size is 1
if get_node_shape(arg_node)[arg_dim] == 1:
arg_dim = None
# chunk shape should equal cur node
elif get_node_shape(arg_node)[arg_dim] != 1:
if cur_node_dim is not None and get_node_shape(cur_node)[cur_node_dim] != 1:
if get_node_shape(arg_node)[arg_dim] != get_node_shape(cur_node)[cur_node_dim]:
return False
else:
arg_dim = None
# add arg rest dim as fix dim
arg_fix_dim = list(range(len(get_node_shape(arg_node))))
if arg_dim is not None:
arg_fix_dim.remove(arg_dim)
# if already in node_info, arg dim must be same
if arg_node in all_node_info:
if all_node_info[arg_node]["chunk_dim"] != arg_dim:
return False
all_node_info[arg_node]["fix_dim"] = arg_fix_dim
# else add it to list
else:
all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim}
next_node_list.append(arg_node)
return True
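# In short: an arg joins the flow only if its inferred chunk dim is
# consistent with the current node's source trace, is not one of its fix
# dims, and agrees with any chunk dim already recorded for that arg.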
def _get_all_node_info(self, end_dim, start_idx, end_idx):
cur_node_list = [self.node_mgr.get_node_by_idx(end_idx)] # start from the last node
all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
while len(cur_node_list) > 0:
next_node_list = []
for cur_node in cur_node_list:
# get cur node info
cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
if cur_node_chunk_dim is not None:
cur_node_compute = self.trace_indice._find_compute_trace_from_node(cur_node)
cur_node_source = self.trace_indice._find_source_trace_from_node(cur_node)
else:
cur_node_compute = cur_node_source = None
# get all valid args
arg_list = []
for arg in cur_node.all_input_nodes:
if type(arg) != type(cur_node):
continue
if is_non_compute_node(arg):
continue
if get_node_shape(arg) is None:
continue
arg_list.append(arg)
flow_flag = self._assign_single_node_flow(
arg,
start_idx,
end_idx,
cur_node,
cur_node_chunk_dim,
cur_node_compute,
cur_node_source,
cur_node_fix_dim,
all_node_info,
next_node_list,
)
if not flow_flag:
return None
cur_node_list = next_node_list
return all_node_info
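# The loop above is a backward BFS from the region's last node, propagating
# chunk/fix dims to every arg; it returns None as soon as one arg conflicts.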
def _get_input_nodes_dim(self, inputs: List[Node], start_idx: int, end_idx: int, all_node_info: Dict) -> Tuple:
"""
Get chunk dim for every input node for their every entry, remove unchunked nodes
Args:
inputs (List[Node]): input nodes
all_node_info (Dict): describe all node's chunk dim and fix dim
start_idx (int): chunk start idx
end_idx (int): chunk end idx
Returns:
inputs (List(Node)): new inputs
inputs_dim (List): chunk dim for inputs
"""
inputs_dim = []
remove_inputs = []
for input_node in inputs:
input_dict = {}
input_node_idx = self.node_mgr.find_node_idx(input_node)
for user in input_node.users.keys():
# skip non compute
if is_non_compute_node(user):
continue
# untraced node, mostly non compute
if user not in all_node_info:
continue
user_idx = self.node_mgr.find_node_idx(user)
if start_idx <= user_idx <= end_idx:
chunk_dim = all_node_info[user]["chunk_dim"]
if chunk_dim is not None:
user_source = self.trace_indice._find_source_trace_from_node(user)[chunk_dim]
if input_node_idx in user_source:
if get_node_shape(input_node)[user_source[input_node_idx][0]] == 1:
input_dict[user_idx] = [None]
else:
input_dict[user_idx] = user_source[input_node_idx]
else:
return None, None
if len(input_dict) == 0:
remove_inputs.append(input_node)
else:
inputs_dim.append(input_dict)
# remove unchunked inputs
for i in remove_inputs:
if i in inputs:
inputs.remove(i)
return inputs, inputs_dim
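# inputs_dim ends up as one dict per surviving input node, mapping each
# user's node index to the input dims that feed that user's chunk dim.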
def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int, chunk_info) -> List[Node]:
"""
get all useless nodes in chunk region and prepose them
Args:
all_node_info (Dict): describe all node's chunk dim and fix dim
start_idx (int): chunk start idx
end_idx (int): chunk end idx
Returns:
List[Node]: all nodes to be preposed
"""
# get all possible prepose nodes
maybe_prepose_nodes = []
for node, node_info in all_node_info.items():
if node_info["chunk_dim"] is None:
maybe_prepose_nodes.append(node)
for node in self.node_mgr.get_node_slice_by_idx(start_idx, end_idx):
if node not in all_node_info and node not in chunk_info["outputs"]:
maybe_prepose_nodes.append(node)
maybe_prepose_nodes.sort(
key=lambda x: self.node_mgr.find_node_idx(x),
reverse=True,
) # from last node to first node
prepose_nodes = []
# set every node as root, search its args, if all legal, turn root and args as prepose nodes
while len(maybe_prepose_nodes) > 0:
tmp_cur_prepose_nodes = [maybe_prepose_nodes[0]]
tmp_cur_related_prepose_nodes = []
prepose_flag = True
# loop cur node's all arg until out of chunk
while len(tmp_cur_prepose_nodes) > 0:
if not prepose_flag:
break
tmp_next_prepose_nodes = []
tmp_cur_related_prepose_nodes.extend(tmp_cur_prepose_nodes)
for cur_prepose_node in tmp_cur_prepose_nodes:
if not prepose_flag:
break
for cur_prepose_node_arg in cur_prepose_node.all_input_nodes:
if type(cur_prepose_node_arg) != type(cur_prepose_node):
continue
# out of loop
if not (start_idx <= self.node_mgr.find_node_idx(cur_prepose_node_arg) < end_idx):
continue
# compute op in loop
elif cur_prepose_node_arg in all_node_info:
if all_node_info[cur_prepose_node_arg]["chunk_dim"] is None:
tmp_next_prepose_nodes.append(cur_prepose_node_arg)
else:
prepose_flag = False
break
# non compute op
else:
tmp_next_prepose_nodes.append(cur_prepose_node_arg)
tmp_cur_prepose_nodes = tmp_next_prepose_nodes
if not prepose_flag:
maybe_prepose_nodes.remove(maybe_prepose_nodes[0])
continue
else:
for n in tmp_cur_related_prepose_nodes:
if n not in prepose_nodes:
prepose_nodes.append(n)
if n in maybe_prepose_nodes:
maybe_prepose_nodes.remove(n)
# sort by index
prepose_nodes.sort(key=lambda x: self.node_mgr.find_node_idx(x))
chunk_info["args"]["prepose_nodes"] = prepose_nodes
def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
# we need to log input nodes to avoid deleting them in the loop
chunk_node_list = self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1)
# drop preposed nodes first so that args used only by them do not end up in non_chunk_inputs
for n in chunk_info["args"]["prepose_nodes"]:
chunk_node_list.remove(n)
non_chunk_inputs = find_chunk_all_input_nodes(chunk_node_list)
for i in non_chunk_inputs:
if i not in chunk_info["inputs"]:
chunk_info["inputs_non_chunk"].append(i)
return chunk_info
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
inputs, outputs = find_chunk_compute_input_and_output_nodes(
self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1))
# get every node's chunk dim and fix dim
all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
if all_node_info is None:
return None
chunk_info = {
"region": (start_idx, end_idx),
"inputs": [],
"inputs_non_chunk": [],
"inputs_dim": [],
"outputs": [self.node_mgr.get_node_by_idx(end_idx)],
"outputs_non_tensor": {},
"outputs_dim": [end_dim],
"node_chunk_dim": all_node_info,
"args": {},
}
# find chunk info for other outputs
if len(find_tensor_shape_node(outputs)) > 1:
chunk_info = self._get_other_output_info(outputs, start_idx, start_dim, end_idx, end_dim, chunk_info)
if chunk_info is None:
return None
# get input nodes' chunk dim
inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info)
if inputs is None:
return None
chunk_info["inputs"] = inputs
chunk_info["inputs_dim"] = inputs_dim
# move useless nodes ahead of loop
self._get_prepose_nodes(all_node_info, start_idx, end_idx, chunk_info)
# find non chunk inputs
chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx)
# reassign reshape sizes; some sizes may have changed due to chunking
chunk_info = self._reassign_reshape_size(chunk_info)
return chunk_info
def _get_other_output_info(self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int,
chunk_info: Dict):
start_node = self.node_mgr.get_node_by_idx(start_idx)
# loop all outputs
for output in outputs:
output_legal = False
output_idx = self.node_mgr.find_node_idx(output)
# skip the origin output
if output_idx == end_idx:
continue
# skip non tensor
if get_node_shape(output) is None:
# log shape tensor
if len(output.meta['fwd_out']) > 0 and isinstance(output.meta['fwd_out'][0], int):
chunk_info["outputs_non_tensor"][output] = str(output.meta['fwd_out'])
continue
# loop every dim of outputs, try to find a legal one
for output_dim in range(len(get_node_shape(output))):
if not self.check_region_start_end(start_node, start_dim, start_idx, output, output_dim, output_idx):
continue
new_all_node_info = self._get_all_node_info(output_dim, start_idx, output_idx)
if new_all_node_info is None:
continue
# check node info legal
if self._update_chunk_info(chunk_info, new_all_node_info, output, output_dim):
output_legal = True
break
# not legal
if not output_legal:
return None
return chunk_info
def _update_chunk_info(self, chunk_info: Dict, new_all_node_info: Dict, output: Node, output_dim: int) -> bool:
"""
check if there is conflict between new node info and old chunk info. If not, update old chunk info
"""
# check if conflict
overlap_flag = False
for k, v in new_all_node_info.items():
if k in chunk_info["node_chunk_dim"]:
overlap_flag = True
if chunk_info["node_chunk_dim"][k]["chunk_dim"] != v["chunk_dim"]:
return False
# if there is no overlap, treat these nodes as prepose nodes instead of a new output
if not overlap_flag:
return True
# update chunk info
for k, v in new_all_node_info.items():
if k in chunk_info["node_chunk_dim"]:
chunk_info["node_chunk_dim"][k]["fix_dim"] = list(
set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"]))
else:
chunk_info["node_chunk_dim"][k] = v
chunk_info["outputs"].append(output)
chunk_info["outputs_dim"].append(output_dim)
return True
def _reassign_reshape_size(self, chunk_info):
"""
Some shape args of reshape may have changed due to chunking;
reassign those changed shapes.
"""
chunk_region = chunk_info["region"]
reshape_size = {}
chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"][0]]
for node in self.node_mgr.get_node_slice_by_idx(chunk_region[0], chunk_region[1] + 1):
if any(i == get_node_name(node) for i in ["reshape", "view"]):
if node in chunk_info["args"]["prepose_nodes"]:
continue
if node.args[0] in chunk_info["inputs_non_chunk"]:
continue
reshape_args = flat_list(node.args[1:])
if len(reshape_args) == 1 and get_node_shape(reshape_args[0]) is None and len(
reshape_args[0].meta['fwd_out']) > 1:
continue
chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
new_shape = ""
for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
if reshape_arg_dim == chunk_dim:
new_shape += "min(chunk_size, %d - chunk_idx), " % chunk_shape
else:
if isinstance(reshape_arg, int):
new_shape += "%s, " % str(reshape_arg)
else:
new_shape += "%s, " % reshape_arg.name
new_shape = new_shape[:-2]
origin_shape = str(reshape_args)[1:-1]
reshape_size[node.name] = [origin_shape, new_shape]
chunk_info["reshape_size"] = reshape_size
return chunk_info
def check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int,
end_idx: int) -> bool:
"""
check if region start and end is legal
"""
# dim cannot be None
if (get_node_shape(end_node) is None or get_node_shape(start_node) is None):
return False
# dim size cannot be 1
if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
return False
# must have users
if len(end_node.users) == 0:
return False
# check index source align
if not self.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
return False
# check index compute
if not self.check_index_compute(start_idx, end_dim, end_node, end_idx):
return False
return True
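# Illustrative usage (hypothetical driver code, not part of this file):
# flow = TraceFlow(trace_indice, node_mgr)
# chunk_info = flow.flow_search(start_idx, start_dim, end_idx, end_dim)
# if chunk_info is not None, SelectChunk scores it by estimated peak memory.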
import copy
from typing import Dict, List, Tuple
from torch.fx.node import Node
from .utils import NodeMgr, find_first_tensor_arg, flat_list, get_module_node_name, get_node_name, get_node_shape
class TraceIndice(object):
"""
Trace all indice information for every node.
Indice is a logical concept. Equal dims can be treated as one indice.
eg. dim(x1) = [a, b, c]
dim(x2) = [d, e, f]
and we have x3 = x1 * x2.
then a=d, b=e, c=f, due to the broadcast property,
dim(x1)=dim(x2)=dim(x3)=[a, b, c]
This class will record every node's dims' indice, compute and source.
Attributes:
node_list (List)
indice_trace_list (List): [{"indice": [...], "compute": [...], "source": [...]}, {...}]
indice_view_list (Dict): not used for now
indice_count (int): record indice number
Args:
node_list (List)
"""
def __init__(self, node_mgr: NodeMgr) -> None:
self.node_mgr = node_mgr
self.indice_trace_list = self._init_indice_trace_list()
self.indice_view_list = {}
self.indice_count = -1
self.active_node_list = []
def _init_indice_trace_list(self) -> List:
indice_trace_list = []
for n in self.node_mgr.get_node_list():
if get_node_shape(n) is not None:
cur_trace = {
"indice": [None for _ in range(len(get_node_shape(n)))],
"compute": [[] for _ in range(len(get_node_shape(n)))],
"source": [{} for _ in range(len(get_node_shape(n)))],
}
else:
cur_trace = {"indice": [], "compute": [], "source": []}
indice_trace_list.append(cur_trace)
return indice_trace_list
def set_active_nodes(self, active_node_list: List) -> None:
self.active_node_list = active_node_list
def _add_indice(self) -> int:
"""
Update the count and return it. To record the idx number.
Returns:
indice_count: int
"""
self.indice_count += 1
return self.indice_count
def _del_dim(self, idx: int, dim_idx: int) -> None:
"""
delete a dim for indice, compute and source
"""
self.indice_trace_list[idx]["indice"].pop(dim_idx)
self.indice_trace_list[idx]["compute"].pop(dim_idx)
self.indice_trace_list[idx]["source"].pop(dim_idx)
def _add_dim(self, node_idx: int, dim_idx: int) -> None:
"""
add a dim for indice, compute and source
"""
# need to remap if dim_idx < 0, e.g. -1
if dim_idx < 0:
dim_idx = list(range(len(self.indice_trace_list[node_idx]["indice"]) + 1))[dim_idx]
self.indice_trace_list[node_idx]["indice"].insert(dim_idx, self._add_indice())
self.indice_trace_list[node_idx]["compute"].insert(dim_idx, [])
self.indice_trace_list[node_idx]["source"].insert(dim_idx, {})
def _add_source(
self,
node_from: Node,
node_from_dim: int,
node_to: Node,
node_to_dim: int,
init=False,
) -> None:
node_from_dim = self._transform_indice(node_from, node_from_dim)
node_from_trace_source = self._find_source_trace_from_node(node_from)
node_to_dim = self._transform_indice(node_to, node_to_dim)
node_to_trace_source = self._find_source_trace_from_node(node_to)
node_from_idx = self.node_mgr.find_node_idx(node_from)
if init:
node_to_trace_source[node_to_dim] = {}
# add dim to cur new source
if node_from_idx not in node_to_trace_source[node_to_dim]:
node_to_trace_source[node_to_dim][node_from_idx] = [node_from_dim]
else:
if node_from_dim not in node_to_trace_source[node_to_dim][node_from_idx]:
node_to_trace_source[node_to_dim][node_from_idx].append(node_from_dim)
# update inputs source
for node_idx, node_dim in node_from_trace_source[node_from_dim].items():
if node_idx not in node_to_trace_source[node_to_dim]:
node_to_trace_source[node_to_dim][node_idx] = copy.deepcopy(node_dim)
else:
for d in node_dim:
if d not in node_to_trace_source[node_to_dim][node_idx]:
node_to_trace_source[node_to_dim][node_idx].append(d)
def _transform_indice(self, node: Node, node_dim: int) -> int:
node_idx = self._find_indice_trace_from_node(node)
dims = list(range(len(node_idx)))
return dims[node_dim]
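# _transform_indice normalizes negative dims, e.g. dim -1 of a 3-dim node
# maps to 2, so traces are always stored under non-negative dim indices.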
def _inherit_indice(
self,
node_from: Node,
node_from_dim: int,
node_to: Node,
node_to_dim: int,
init: bool = True,
) -> None:
"""
node_to's node_to_dim inherit node_from's node_from_dim by indice, compute and source
"""
node_from_dim = self._transform_indice(node_from, node_from_dim)
node_to_dim = self._transform_indice(node_to, node_to_dim)
node_from_trace = self._find_trace_from_node(node_from)
node_to_trace = self._find_trace_from_node(node_to)
if init:
node_to_trace["indice"][node_to_dim] = node_from_trace["indice"][node_from_dim]
node_to_trace["compute"][node_to_dim] = copy.deepcopy(node_from_trace["compute"][node_from_dim])
else:
for j in node_from_trace["compute"][node_from_dim]:
if j not in node_to_trace["compute"][node_to_dim]:
node_to_trace["compute"][node_to_dim].append(j)
self._add_source(node_from, node_from_dim, node_to, node_to_dim, init)
def _inherit_all_indice(self, node_from: Node, node_to: Node) -> None:
"""
inherit all dims with init
"""
# find indice just for assert length
node_from_indice = self._find_indice_trace_from_node(node_from)
node_to_indice = self._find_indice_trace_from_node(node_to)
assert len(node_from_indice) == len(node_to_indice)
for i in range(len(node_from_indice)):
self._inherit_indice(node_from, i, node_to, i, init=True)
def _inherit_more_indice_from_node_with_exclude(self, node_from: Node, node_to: Node, exclude: List = None) -> None:
"""
inherit indice from node without init
"""
if exclude is None:
exclude = []
else:
exclude = [self._transform_indice(node_to, i) for i in exclude]
node_from_compute = self._find_compute_trace_from_node(node_from)
node_to_compute = self._find_compute_trace_from_node(node_to)
# assert len(node_from_compute) == len(node_to_compute)
for i in range(-1, -min(len(node_from_compute), len(node_to_compute)) - 1, -1):
if self._transform_indice(node_to, i) in exclude:
continue
self._inherit_indice(node_from, i, node_to, i, init=False)
def _mark_computation(self, node: Node, idx: int, dim: int) -> None:
"""
Mark some dims of node as computed.
Args:
node (node)
idx (int): node index
dim (list or int): dims to be marked as computed
"""
if isinstance(dim, int):
dim = [dim]
dims = list(range(len(get_node_shape(node))))
for d in dim:
cur_dim = dims[d]
if idx not in self.indice_trace_list[idx]["compute"][cur_dim]:
self.indice_trace_list[idx]["compute"][cur_dim].append(idx)
def _find_trace_from_node(self, node: Node) -> Dict:
"""
Find the trace dict (indice, compute and source) of the node.
Args:
node (node)
Returns:
trace (dict): the node's indice, compute and source traces
"""
node_idx = self.node_mgr.find_node_idx(node)
node_dict = self.indice_trace_list[node_idx]
return node_dict
def _find_source_trace_from_node(self, node: Node) -> List:
"""
Find node source trace by the node.
Args:
node (node)
Returns:
source (list): source trace of the node
"""
node_idx = self.node_mgr.find_node_idx(node)
node_dict = self.indice_trace_list[node_idx]
return node_dict["source"]
def _find_indice_trace_from_node(self, node) -> List:
"""
Find node idx trace by the node.
Args:
node (node)
Returns:
indice (list): indice trace of the node
"""
node_idx = self.node_mgr.find_node_idx(node)
return self.indice_trace_list[node_idx]["indice"]
def _find_compute_trace_from_node(self, node: Node) -> List:
"""
Find node compute trace by the node.
Args:
node (node)
Returns:
compute (list): computed idx of the node.
"""
node_idx = self.node_mgr.find_node_idx(node)
return self.indice_trace_list[node_idx]["compute"]
def _assign_indice_as_input(self, node: Node, node_idx: int, input_node=None) -> None:
"""
Assign node's trace as its input node.
Args:
node (node)
node_idx (int)
"""
if input_node is None:
input_node = find_first_tensor_arg(node)
self._inherit_all_indice(input_node, node)
def _assign_all_indice(self, node: Node, node_idx: int) -> None:
"""
Add new indice for all node's dims.
Args:
node (node)
node_idx (int)
"""
shape = node.meta["tensor_meta"].shape
if shape is None:
return
new_trace = []
for _ in shape:
new_trace.append(self._add_indice())
self.indice_trace_list[node_idx]["indice"] = new_trace
def _assign_transpose_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for transpose op.
1. swap input's dim according to transpose args
2. inherit input's computation
Args:
node (node)
node_idx (int)
"""
input_node = node.args[0]
transpose_dim = node.args[1:]
self._assign_indice_as_input(node, node_idx, input_node)
self._inherit_indice(input_node, transpose_dim[1], node, transpose_dim[0])
self._inherit_indice(input_node, transpose_dim[0], node, transpose_dim[1])
def _assign_permute_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for permute op.
1. swap input's dim according to permute args
2. inherit input's computation
Args:
node (node)
node_idx (int)
"""
permute_dim = flat_list(node.args[1:])
input_node = node.args[0]
self._assign_indice_as_input(node, node_idx, input_node)
for idx, d in enumerate(permute_dim):
self._inherit_indice(input_node, d, node, idx)
def _assign_linear_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for linear op.
1. copy trace from input node and change last indice according to weight
2. mark equal for input node last indice, weight first dim and bias dim.
3. inherit input's computation, mark computation for last dim.
Args:
node (node)
node_idx (int)
"""
self._assign_indice_as_input(node, node_idx)
if len(node.args) >= 2:
weight = node.args[1]
self._inherit_indice(weight, 1, node, -1)
else:
self._del_dim(node_idx, -1)
self._add_dim(node_idx, -1)
self._mark_computation(node, node_idx, [-1])
def _assign_addmm_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for addmm op.
Args:
node (node)
node_idx (int)
"""
bias, input_node, weight = node.args
assert len(get_node_shape(bias)) == 1 and len(get_node_shape(weight)) == 2
self._assign_indice_as_input(node, node_idx, input_node)
self._inherit_indice(weight, 1, node, -1)
self._inherit_more_indice_from_node_with_exclude(bias, node)
self._mark_computation(node, node_idx, [-1])
def _assign_baddbmm_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for baddbmm(batch add and batch matmul) op.
add, matmul_left, matmul_right = args
out = add + (matmul_left x matmul_right)
Args:
node (node)
node_idx (int)
"""
add, matmul_left, matmul_right = node.args
assert get_node_shape(add) == get_node_shape(node)
assert len(get_node_shape(matmul_left)) == len(get_node_shape(matmul_right))
self._assign_indice_as_input(node, node_idx, matmul_left)
# matmul
self._inherit_indice(matmul_right, -1, node, -1)
self._inherit_more_indice_from_node_with_exclude(matmul_right, node, [-2, -1])
self._mark_computation(node, node_idx, [-1])
# add
self._inherit_more_indice_from_node_with_exclude(add, node)
def _assign_matmul_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for matmul op.
1. copy trace from matmul_left and change last indice according to matmul_right. (assert they have same length)
2. mark equal for input matmul_left -1 indice and matmul_right -2 dim.
3. inherit matmul_left and matmul_right computation, mark computation for last dim.
Args:
node (node)
node_idx (int)
"""
matmul_left, matmul_right = node.args
assert len(get_node_shape(matmul_left)) == len(get_node_shape(matmul_right))
self._assign_indice_as_input(node, node_idx, matmul_left)
self._inherit_indice(matmul_right, -1, node, -1)
self._inherit_more_indice_from_node_with_exclude(matmul_right, node, [-1, -2])
self._mark_computation(node, node_idx, [-1])
def _assign_conv2d_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for conv2d op.
Args:
node (node)
node_idx (int)
"""
# get conv module
node_targets = node.target.split(".")
conv_module = node.graph.owning_module
for i in node_targets:
conv_module = getattr(conv_module, i)
assert conv_module.dilation == (1, 1), "dilation for conv2d not implemented"
# get conv input
assert len(node.args) == 1
input_node = node.args[0]
assert len(get_node_shape(input_node)) == 4
# assign indice
self._assign_indice_as_input(node, node_idx, input_node)
self._del_dim(node_idx, 1)
self._add_dim(node_idx, 1)
self._mark_computation(node, node_idx, [1, 2, 3])
def _assign_interpolate_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for interpolate op.
Args:
node (node)
node_idx (int)
"""
# interpolate with an explicit size is not supported
assert node.kwargs['size'] is None
assert len(get_node_shape(node)) == 4
# assign indice
self._assign_indice_as_input(node, node_idx)
self._mark_computation(node, node_idx, [-1, -2])
def _assign_layernorm_indice(self, node, idx):
"""
Assign indice for layernorm op.
1. assign indice as input node
2. inherit computation and mark the last dim as computed.
Args:
node (node)
node_idx (int)
"""
self._assign_indice_as_input(node, idx)
self._mark_computation(node, idx, [-1])
def _assign_groupnorm_indice(self, node, idx):
"""
Assign indice for groupnorm op.
Args:
node (node)
node_idx (int)
"""
assert len(get_node_shape(node)) == 4
self._assign_indice_as_input(node, idx)
self._mark_computation(node, idx, [-1, -2, -3])
def _assign_elementwise_indice(self, node, idx):
"""
Assign indice for element-wise op (eg. relu sigmoid add mul).
1. assign indice as input node
2. inherit computation from all input nodes.
Args:
node (node)
node_idx (int)
"""
self._assign_indice_as_input(node, idx)
nodes_in = []
for node_in in node.args:
if type(node_in) == type(node):
nodes_in.append(node_in)
self._inherit_more_indice_from_node_with_exclude(node_in, node)
def _assign_no_change_indice(self, node, idx):
self._assign_indice_as_input(node, idx)
for node_in in node.args:
if type(node_in) == type(node):
self._inherit_more_indice_from_node_with_exclude(node_in, node)
def _assign_einsum_indice(self, node, idx):
"""
Assign indice for einsum op.
Args:
node (node)
node_idx (int)
"""
patterns = node.args[0]
input_nodes = node.args[1:]
patterns = patterns.replace(" ", "")
left, right = patterns.split("->")
left = left.split(",")
if "..." in right:
replace_list = "!@#$%^&*"
target_len = len(get_node_shape(node))
add_len = target_len - len(right) + 3
replace_str = replace_list[:add_len]
right = right.replace("...", replace_str)
for ll in range(len(left)):
left[ll] = left[ll].replace("...", replace_str)
all_index = []
for i in left:
for c in i:
all_index.append(c)
all_index = set(all_index)
for right_idx, right_indice in enumerate(right):
for left_idx, left_str in enumerate(left):
if right_indice in left_str:
source_idx = left_str.index(right_indice)
self._inherit_indice(input_nodes[left_idx], source_idx, node, right_idx)
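# Example: for "bij,bjk->bik", output dim 0 ('b') inherits from dim 0 of
# both inputs, dim 1 ('i') from dim 1 of the first input, and dim 2 ('k')
# from dim 2 of the second input.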
def _assign_softmax_indice(self, node, idx):
"""
Assign indice for softmax op.
1. assign indice as input node
2. inherit computation and mark softmax dim as computed.
Args:
node (node)
node_idx (int)
"""
self._assign_indice_as_input(node, idx)
self._mark_computation(node, idx, [node.kwargs["dim"]])
def _assign_split_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for split op.
Args:
node (node)
node_idx (int)
"""
self._assign_indice_as_input(node, node_idx)
dim_idx = node.kwargs["dim"]
self._del_dim(node_idx, dim_idx)
self._add_dim(node_idx, dim_idx)
def _assign_unsqueeze_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for unsqueeze op.
1. assign new indice for unsqueeze dim
Args:
node (node)
node_idx (int)
"""
self._del_dim(node_idx, -1)
self._assign_indice_as_input(node, node_idx)
dim_idx = node.args[1]
# remap negative dims, e.g. unsqueeze(-1) inserts at output_ndim - 1
if dim_idx < 0:
dim_idx = list(range(len(get_node_shape(node))))[dim_idx]
self._add_dim(node_idx, dim_idx)
def _assign_cat_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for cat op.
Args:
node (node)
node_idx (int)
"""
nodes_in = flat_list(node.args[0])
self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0])
for n in nodes_in[1:]:
self._inherit_more_indice_from_node_with_exclude(n, node)
cat_dim = node.kwargs["dim"]
self._del_dim(node_idx, cat_dim)
self._add_dim(node_idx, cat_dim)
def _assign_sum_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for sum op.
Args:
node (node)
node_idx (int)
"""
nodes_in = flat_list(node.args[0])
self._add_dim(node_idx, 0)
self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0])
for n in nodes_in[1:]:
self._inherit_more_indice_from_node_with_exclude(n, node)
sum_dim = node.kwargs["dim"]
self._del_dim(node_idx, sum_dim)
def _assign_flatten_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for flatten op.
Args:
node (node)
node_idx (int)
"""
nodes_in = node.args[0]
nodes_in_shape = get_node_shape(nodes_in)
flatten_start_dim = node.args[1]
flatten_dim_num = len(nodes_in_shape) - flatten_start_dim - 1
assert flatten_dim_num > 0
for _ in range(flatten_dim_num):
self._add_dim(node_idx, 0)
self._assign_indice_as_input(node, node_idx, nodes_in)
for _ in range(flatten_dim_num + 1):
self._del_dim(node_idx, -1)
self._add_dim(node_idx, -1)
def _assign_expand_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for expand op.
Args:
node (node)
node_idx (int)
"""
expand_shape = node.args[1:]
node_in_shape = get_node_shape(node.args[0])
assert len(expand_shape) == len(node_in_shape)
self._assign_indice_as_input(node, node_idx)
for i in range(len(node_in_shape)):
if expand_shape[i] == node_in_shape[i] or expand_shape[i] == -1:
continue
elif expand_shape[i] > node_in_shape[i]:
self._del_dim(node_idx, i)
self._add_dim(node_idx, i)
else:
raise RuntimeError()
def _assign_unbind_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for unbind op.
Args:
node (node)
node_idx (int)
"""
unbind_dim = node.args[1]
self._add_dim(node_idx, unbind_dim)
self._assign_indice_as_input(node, node_idx)
self._del_dim(node_idx, unbind_dim)
def _assign_embedding_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for embedding op.
Args:
node (node)
node_idx (int)
"""
self._del_dim(node_idx, -1)
self._assign_indice_as_input(node, node_idx)
self._add_dim(node_idx, -1)
def _assign_getitem_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for getitem.
getitem can act like slice sometimes
Args:
node (node)
node_idx (int)
"""
node_args = flat_list(node.args[1:])
# deal with split
if get_node_name(node.args[0]) == "split":
self._assign_indice_as_input(node, node_idx)
self._del_dim(node_idx, node.args[0].kwargs["dim"])
self._add_dim(node_idx, node.args[0].kwargs["dim"])
return
# skip non tensor
if get_node_shape(node) is None:
return
# find if slice
flag = False
for node_arg in node_args:
node_arg_str = str(node_arg)
if any(i == node_arg_str for i in ["None", "Ellipsis"]):
flag = True
break
if "slice" in node_arg_str:
flag = True
break
if not flag:
return
# node args should be like [Ellipsis, slice(start, stop, step), None]
node_shape = get_node_shape(node)
origin_idx_count = 0
new_idx_count = 0
new_dim_num = sum([1 if str(i) == "None" else 0 for i in node_args])
for _ in range(new_dim_num):
self._del_dim(node_idx, 0)
delete_dim_num = sum([1 if str(i) == "0" else 0 for i in node_args])
for _ in range(delete_dim_num):
self._add_dim(node_idx, 0)
self._assign_indice_as_input(node, node_idx)
for _, node_arg in enumerate(node_args):
node_arg_str = str(node_arg)
# Ellipsis means [..., ]
if "Ellipsis" == node_arg_str:
shape_gap = len(node_shape) - len(node_args) + 1
origin_idx_count += shape_gap
new_idx_count += shape_gap
# slice(None, None, None) means all indexes
elif "slice" in node_arg_str:
if "slice(None, None, None)" != node_arg_str:
self._del_dim(node_idx, new_idx_count)
self._add_dim(node_idx, new_idx_count)
origin_idx_count += 1
new_idx_count += 1
# None means a new dim
elif "None" == node_arg_str:
self._add_dim(node_idx, new_idx_count)
new_idx_count += 1
elif "0" == node_arg_str:
self._del_dim(node_idx, new_idx_count)
origin_idx_count += 1
else:
raise NotImplementedError()
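# Example: x[..., 0:2, None] keeps the leading dims via Ellipsis, assigns a
# fresh indice to the sliced dim, and inserts a new dim for None.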
def _assign_view_reshape_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for view and reshape op.
1. get origin shape and target shape by meta info.
2. compute the real value of -1 in target shape.
3. determine changed dim, and assign indice for generated dim.
4. log changed dim and generated dim for restore
5. inherit computation.
6. look into view list to see whether the view is associated with other,
if so assign equal dim according to previous view.
Args:
node (node)
node_idx (int)
"""
# get data, turn into number
origin_node = node.args[0]
origin_shape = origin_node.meta["tensor_meta"].shape
target_shape = []
flat_args = flat_list(node.args)
for i in range(1, len(flat_args)):
if isinstance(flat_args[i], int):
target_shape.append(flat_args[i])
else:
target_shape.extend(flat_args[i].meta["fwd_out"])
# compute the value of -1
if -1 in target_shape:
origin_product = 1
for i in origin_shape:
origin_product *= i
target_product = -1
for i in target_shape:
target_product *= i
shape_idx = target_shape.index(-1)
target_shape[shape_idx] = origin_product // target_product
# find same dim
dim_to_same_dim = []
dim_from_same_dim = []
for i in range(len(origin_shape)):
if origin_shape[i] == target_shape[i]:
dim_to_same_dim.append(i)
dim_from_same_dim.append(i)
else:
break
for i in range(-1, -len(origin_shape), -1):
if origin_shape[i] == target_shape[i]:
dim_to_same_dim.append(len(target_shape) + i)
dim_from_same_dim.append(len(origin_shape) + i)
else:
break
dim_from = list(set(range(len(origin_shape))) - set(dim_from_same_dim))
dim_to = list(set(range(len(target_shape))) - set(dim_to_same_dim))
assert len(dim_from) == 1 or len(dim_to) == 1 or len(dim_from) == len(dim_to)
dim_diff = len(dim_from) - len(dim_to)
if dim_diff > 0:
# dim merge
for i in range(dim_diff):
self._add_dim(node_idx, -1)
elif dim_diff < 0:
# dim expand
for i in range(-dim_diff):
self._del_dim(node_idx, -1)
# get new indice
origin_trace = self._find_indice_trace_from_node(origin_node)
self._assign_indice_as_input(node, node_idx, origin_node)
dim_from.reverse()
for i in dim_from:
self._del_dim(node_idx, i)
for i in dim_to:
self._add_dim(node_idx, i)
dim_from.reverse()
# inherit indice from the origin node
if len(dim_from) != 0 and len(dim_to) != 0:
if dim_diff == 1:
if origin_shape[dim_from[0]] == 1:
self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False)
elif origin_shape[dim_from[1]] == 1:
self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
elif dim_diff == -1:
if target_shape[dim_to[0]] == 1:
self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False)
elif target_shape[dim_to[1]] == 1:
self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
# log view, not used now
view_dict = {
"idx_from": [origin_trace[i] for i in dim_from],
"dim_from": dim_from,
"idx_to": [self.indice_trace_list[node_idx]["indice"][i] for i in dim_to],
"dim_to": dim_to,
}
self.indice_view_list[node] = view_dict
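# Example: reshaping (2, 3, 4) -> (6, 4) merges dims 0 and 1, so dim_from
# = [0, 1] and dim_to = [0], while the trailing dim is carried over.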
def _clear_trace(self, node_idx: int) -> None:
"""
clear too far trace to speed up computation
"""
trace_barrier = max(node_idx - 100, 0)
active_nodes = self.active_node_list[trace_barrier]
active_nodes = [self.node_mgr.find_node_idx(i) for i in active_nodes.keys()]
trace = self.indice_trace_list[node_idx]
# clear compute
for dim_compute in trace["compute"]:
for i in range(len(dim_compute) - 1, -1, -1):
if (dim_compute[i] < trace_barrier and dim_compute[i] not in active_nodes):
dim_compute.pop(i)
continue
# clear source
for dim_source in trace["source"]:
for k in list(dim_source.keys()):
if k < trace_barrier and k not in active_nodes:
dim_source.pop(k)
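# Traces more than 100 nodes old are dropped unless the source node is
# still active, keeping the compute/source records small for long graphs.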
def trace_indice(self) -> None:
for idx, node in enumerate(self.node_mgr.get_node_list()):
node_name = get_node_name(node)
if node.op == "placeholder":
self._assign_all_indice(node, idx)
elif node.op == "call_method":
if "transpose" == node_name:
self._assign_transpose_indice(node, idx)
elif "permute" == node_name:
self._assign_permute_indice(node, idx)
elif "view" == node_name or "reshape" == node_name:
self._assign_view_reshape_indice(node, idx)
elif "unsqueeze" == node_name:
self._assign_unsqueeze_indice(node, idx)
elif "split" == node_name:
self._assign_split_indice(node, idx)
elif any(i == node_name for i in ["to", "contiguous", "clone", "type", "float"]):
self._assign_no_change_indice(node, idx)
elif "new_ones" == node_name:
self._assign_all_indice(node, idx)
elif "flatten" == node_name:
self._assign_flatten_indice(node, idx)
elif "expand" == node_name:
self._assign_expand_indice(node, idx)
elif "unbind" == node_name:
self._assign_unbind_indice(node, idx)
elif "softmax" == node_name:
self._assign_softmax_indice(node, idx)
elif any(i == node_name for i in ["size"]):
continue
else:
raise NotImplementedError(node_name, "method not implemented yet!")
elif node.op == "call_function":
if "linear" == node_name:
self._assign_linear_indice(node, idx)
elif "cat" == node_name:
self._assign_cat_indice(node, idx)
elif any(n == node_name for n in ["matmul", "bmm"]):
self._assign_matmul_indice(node, idx)
elif "softmax" == node_name:
self._assign_softmax_indice(node, idx)
elif any(n == node_name for n in [
"mul", "add", "sigmoid", "relu", "sub", "truediv", "pow", "dropout", "where", "tanh", "exp",
"sin", "cos"
]):
self._assign_elementwise_indice(node, idx)
elif "einsum" == node_name:
self._assign_einsum_indice(node, idx)
elif "sum" == node_name:
self._assign_sum_indice(node, idx)
elif "layer_norm" == node_name:
self._assign_layernorm_indice(node, idx)
elif "getitem" == node_name:
self._assign_getitem_indice(node, idx)
elif "addmm" == node_name:
self._assign_addmm_indice(node, idx)
elif "baddbmm" == node_name:
self._assign_baddbmm_indice(node, idx)
elif "interpolate" == node_name:
self._assign_interpolate_indice(node, idx)
elif any(i == node_name for i in ["arange", "ones", "ones_like", "tensor", "empty"]):
self._assign_all_indice(node, idx)
elif any(i == node_name for i in ["getattr", "eq", "_assert_is_none", "_assert", "finfo"]):
continue
else:
raise NotImplementedError(node_name, "function not implemented yet!")
elif node.op == "call_module":
node_name = get_module_node_name(node)
if "layernorm" == node_name:
self._assign_layernorm_indice(node, idx)
elif "groupnorm" == node_name:
self._assign_groupnorm_indice(node, idx)
elif "embedding" == node_name:
self._assign_embedding_indice(node, idx)
elif "linear" == node_name:
self._assign_linear_indice(node, idx)
elif "conv2d" == node_name:
self._assign_conv2d_indice(node, idx)
elif "identity" == node_name:
self._assign_no_change_indice(node, idx)
elif any(n == node_name for n in ["sigmoid", "dropout", "relu", "silu", "gelu"]):
self._assign_elementwise_indice(node, idx)
else:
raise NotImplementedError(node_name, "module not implemented yet!")
elif node.op == "get_attr":
self._assign_all_indice(node, idx) # get param
elif node.op == "output":
continue
else:
raise NotImplementedError(node.op, "op not implemented yet!")
# limit trace range
self._clear_trace(idx)
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
from torch.fx.node import Node
from colossalai.logging import get_dist_logger
NON_COMPUTE_OP = ["placeholder", "get_attr", "output"]
NON_COMPUTE_NAME = ["getattr", "eq", "_assert_is_none", "_assert", "finfo", "size"]
logger = get_dist_logger()
class NodeMgr(object):
def __init__(self, nodes_list: List[Node]) -> None:
self._node_list = nodes_list
self._node_dict = {}
self._set_node_dict()
def _set_node_dict(self) -> None:
"""
create a dict {node_name: node_idx}
"""
self._node_dict.clear()
for idx, node in enumerate(self._node_list):
self._node_dict[node.name] = idx
def find_node_idx(self, node: Node) -> int:
"""
find node's index
"""
return self._node_dict[node.name]
def find_node_idx_by_name(self, node_name: str) -> int:
"""
find node's index
"""
return self._node_dict[node_name]
def get_node_by_idx(self, idx: int) -> Node:
"""
get a node by index
"""
return self._node_list[idx]
def get_node_slice_by_idx(self, start: int, end: int) -> List[Node]:
"""
get a slice of node by index
"""
return self._node_list[start:end]
def get_node_list(self) -> List:
"""
get full node list
"""
return self._node_list
def update_node_list(self, node_list: List) -> None:
"""
update node list, reset node dict
"""
self._node_list = node_list
self._set_node_dict()
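# Illustrative usage (hypothetical): after reordering a graph, call
# node_mgr.update_node_list(new_node_list) so name -> index lookups stay
# consistent with the new order.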
def get_logger() -> Any:
return logger
def flat_list(inputs: Any) -> List:
"""
flat a list by recursion
"""
if not isinstance(inputs, (list, set, tuple)):
return [inputs]
res = []
for i in inputs:
if isinstance(i, (list, set, tuple)):
res.extend(flat_list(i))
elif isinstance(i, dict):
res.extend(flat_list(list(i.keys())))
else:
res.append(i)
return res
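# Example: flat_list([1, (2, [3]), {"a": 4}]) returns [1, 2, 3, "a"]; dict
# values are dropped and only keys are kept.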
def find_first_tensor_arg(node: Node) -> Node:
"""
Find the first input tensor arg for a node
"""
for arg in node.args:
if type(arg) == type(node):
return arg
raise RuntimeError()
def is_non_compute_node(node: Node) -> bool:
if node.op in NON_COMPUTE_OP or get_node_name(node) in NON_COMPUTE_NAME:
return True
if "getitem" in node.name:
if get_node_shape(node) is not None:
return False
node_args = flat_list(node.args[1:])
for node_arg in node_args:
if any(i == str(node_arg) for i in ["None", "Ellipsis"]):
return False
if "slice" in str(node_arg):
return False
return True
return False
def get_node_shape(node: Node) -> Any:
"""
return node data shape
"""
if get_node_name(node) in ["split", "unbind"]:
return node.meta["tensor_meta"][0].shape
if hasattr(node.meta["tensor_meta"], "shape"):
return node.meta["tensor_meta"].shape
return None
def is_non_memory_node(node: Node) -> bool:
if "getitem" in node.name:
return True
if "output" in node.op:
return True
return is_non_compute_node(node)
def is_non_compute_node_except_placeholder(node: Node) -> bool:
if "placeholder" in node.op:
return False
return is_non_compute_node(node)
def is_non_compute_node_except_placeholder_output(node: Node) -> bool:
if "output" in node.op:
return False
return is_non_compute_node_except_placeholder(node)
def delete_free_var_from_last_use(user_to_last_uses: Dict) -> None:
for key, value in user_to_last_uses.items():
for n in value:
if n.op == "placeholder":
user_to_last_uses[key].remove(n)
def find_chunk_all_input_nodes(nodes: List[Node]) -> List:
"""
Find non-compute input and output node names.
input nodes are nodes used in the list
output nodes are nodes will use nodes in the list
"""
input_nodes = []
for node in nodes:
for input_node in node._input_nodes.keys():
if input_node not in nodes and input_node not in input_nodes:
input_nodes.append(input_node)
return input_nodes
def find_chunk_compute_input_and_output_nodes(nodes: List[Node]) -> Tuple[List, List]:
"""
Find the compute input and output nodes of a node list.
Input nodes are nodes outside the list that are used by nodes in the list;
output nodes are nodes in the list whose results are used outside the list.
"""
input_nodes = []
output_nodes = []
# if a node has an input node which is not in the node list
# we treat that input node as the input of the checkpoint function
for node in nodes:
for input_node in node._input_nodes.keys():
if (input_node not in nodes and input_node not in input_nodes
and not is_non_compute_node_except_placeholder(input_node)):
input_nodes.append(input_node)
# if a node has a user node which is not in the node list
# we treat that user node as the node receiving the current node output
for node in nodes:
for output_node in node.users.keys():
if (output_node not in nodes and node not in output_nodes
and not is_non_compute_node_except_placeholder_output(output_node)):
output_nodes.append(node)
return input_nodes, output_nodes
def get_module_node_name(node: Node) -> str:
"""
get module class name
"""
node_targets = node.target.split(".")
module = node.graph.owning_module
for i in node_targets:
module = getattr(module, i)
module_name = str(module.__class__).split(".")[-1][:-2]
module_name = module_name.lower()
return module_name
def get_node_name(node: Node) -> str:
"""
get node name
"""
node_name = node.name
if "_" in node_name:
for i in range(len(node_name) - 1, -1, -1):
if node_name[i] == "_":
node_name = node_name[:i]
break
elif node_name[i] in ["1", "2", "3", "4", "5", "6", "7", "8", "9", "0"]:
continue
else:
break
return node_name
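# Example: "add_1" and "add_12" both yield "add", while "layer_norm" is
# returned unchanged because the character before the final "_" scan is not numeric.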
def find_tensor_node(node_list: List[Node]) -> List[Node]:
"""
find tensor nodes from a node list
"""
out = []
for node in node_list:
if get_node_shape(node) is not None:
out.append(node)
return out
def find_tensor_shape_node(node_list: List[Node]) -> List[Node]:
"""
find tensor and shape nodes from a node list
"""
out = []
for node in node_list:
if get_node_shape(node) is not None:
out.append(node)
elif len(node.meta['fwd_out']) > 0 and isinstance(node.meta['fwd_out'], list) and isinstance(
node.meta['fwd_out'][0], int):
out.append(node)
return out