Commit e532679c authored by oahzxl's avatar oahzxl
Browse files

Merge branch 'main' of https://github.com/oahzxl/ColossalAI into chunk

parents c1492e50 7d5640b9
......@@ -3,16 +3,18 @@
# modified from https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.py
# to support tensor parallel
import torch
from collections import defaultdict, abc
import warnings
from collections import abc, defaultdict
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
from colossalai.context import ParallelMode
import torch
import torch.distributed as dist
from colossalai.core import global_context as gpc
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from packaging import version
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
class _MultiDeviceReplicator(object):
......
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.nn as nn
import torch.cuda.amp as torch_amp
import torch.nn as nn
from torch import Tensor
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from ._grad_scaler import GradScaler
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import clip_grad_norm_fp32
from ._grad_scaler import GradScaler
class TorchAMPOptimizer(ColossalaiOptimizer):
"""A wrapper class which integrate Pytorch AMP with an optimizer
......
from .ckpt_solver_base import CheckpointSolverBase
from .ckpt_solver_chen import CheckpointSolverChen
from .ckpt_solver_rotor import CheckpointSolverRotor
import os
from setuptools import Extension, setup
# Resolve everything relative to this script so the build does not depend on
# the current working directory.
this_dir = os.path.dirname(os.path.abspath(__file__))
_rotor_source = os.path.join(this_dir, 'ckpt_solver_rotor.c')

# A single C extension named ``rotorc`` compiled from the rotor solver source.
ext_modules = [Extension('rotorc', sources=[_rotor_source])]

setup(name='rotor c extension',
      version='0.1',
      description='rotor c extension for faster dp computing',
      ext_modules=ext_modules)
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any, List
import torch
from torch.fx import Graph, Node
from colossalai.auto_parallel.passes.runtime_apply_pass import (
runtime_apply,
runtime_apply_for_iterable_object,
runtime_comm_spec_apply,
)
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
# Public API of this module. The name must be spelled ``__all__``; the previous
# spelling with three trailing underscores was an ordinary variable that the
# import machinery ignores.
__all__ = ['CheckpointSolverBase']
def _copy_output(src: Graph, dst: Graph):
    """Share the ``meta`` of every output node of ``src`` with its counterpart in ``dst``.

    Both graphs are walked in lockstep, pairing nodes by position. The ``meta``
    object is assigned by reference, not copied.
    """
    for src_node, dst_node in zip(src.nodes, dst.nodes):
        if src_node.op != 'output':
            continue
        dst_node.meta = src_node.meta
def _get_param_size(module: torch.nn.Module):
    """Return the total size in bytes of all parameters of ``module``.

    Each parameter contributes ``numel() * element_size()``. A module without
    parameters yields 0.
    """
    # ``Tensor.element_size()`` reports the per-element byte width directly, so
    # no throwaway tensor has to be allocated per parameter to query the dtype.
    return sum(p.numel() * p.element_size() for p in module.parameters())
class CheckpointSolverBase(ABC):
    """Shared setup for activation-checkpoint solvers that work on a ``torch.fx`` graph.

    The constructor takes a private copy of the graph, validates its meta
    information and prepares ``node_list``. Subclasses implement :meth:`solve`.
    """

    def __init__(
        self,
        graph: Graph,
        free_memory: float = -1.0,
        requires_linearize: bool = False,
        cnode: List[str] = None,
        optim_multiplier: float = 1.0,
    ):
        """``CheckpointSolverBase`` class will integrate information provided by the components
        and use an existing solver to find a possible optimal strategies combination for target
        computing graph.

        Existing Solvers:
            Chen's Greedy solver: https://arxiv.org/abs/1604.06174  (CheckpointSolverChen)
            Rotor solver: https://hal.inria.fr/hal-02352969  (CheckpointSolverRotor)

        Args:
            graph (Graph): The computing graph to be optimized.
            free_memory (float): Memory constraint for the solution.
            requires_linearize (bool): Whether the graph needs to be linearized.
            cnode (List[str], optional): Common node List, should be the subset of input. Default to None.
                The list is copied, so the caller's list is never modified.
            optim_multiplier (float, optional): The multiplier of extra weight storage for the
            ``torch.optim.Optimizer``. Default to 1.0.

        Raises:
            RuntimeError: If any node of the graph has empty ``meta``.

        Warnings:
            Meta information of the graph is required for any ``CheckpointSolver``.
        """
        # super-dainiu: this graph is a temporary graph which can refer to
        # the owning module, but we will return another deepcopy of it after
        # the solver is executed.
        self.graph = deepcopy(graph)
        self.graph.owning_module = graph.owning_module
        _copy_output(graph, self.graph)
        self.graph.set_codegen(ActivationCheckpointCodeGen())

        # check if has meta information
        if any(len(node.meta) == 0 for node in self.graph.nodes):
            raise RuntimeError(
                "Nodes meta information hasn't been prepared! Please extract from graph before constructing the solver!"
            )

        # parameter memory = parameter size + optimizer extra weight storage
        self.free_memory = free_memory - _get_param_size(self.graph.owning_module) * (optim_multiplier + 1)

        # ``_linearize_graph`` appends propagated common nodes to ``self.cnode``.
        # Work on a copy so the caller's list is not polluted with names that
        # are not placeholders (which would fail validation if the same list is
        # passed to another solver).
        self.cnode = None if cnode is None else list(cnode)
        self.requires_linearize = requires_linearize
        if self.requires_linearize:
            self.node_list = self._linearize_graph()
        else:
            self.node_list = self.get_node_list()

    @abstractmethod
    def solve(self):
        """Solve the checkpointing problem and return the solution.
        """
        pass

    def get_node_list(self):
        """Get the node list: every graph node wrapped in its own single-element list.
        """
        return [[node] for node in self.graph.nodes]

    def _linearize_graph(self) -> List[List[Node]]:
        """Linearizing the graph

        Returns:
            List[List[Node]]: List of list, each inside list of Node presents
            the actual 'node' in linearized manner.

        Raises:
            ValueError: If a name in ``self.cnode`` is not the name of any graph node.
            AssertionError: If a name in ``self.cnode`` refers to a node that is not a placeholder.

        Remarks:
            Do merge the inplace ops and shape-consistency ops into the previous node.
        """

        # Common nodes are type of nodes that could be seen as attributes and remain
        # unchanged throughout the whole model, it will be used several times by
        # different blocks of model, so that it is hard for us to linearize the graph
        # when we encounter those kinds of nodes. We let users to annotate some of the
        # input as common node, such as attention mask, and the followings are some of
        # the ops that could actually be seen as common nodes. With our common node prop,
        # we could find some of the "real" common nodes (e.g. the real attention mask
        # used in BERT and GPT), the rule is simple, for node who's parents are all common
        # nodes or it's op belongs to the following operations, we view this node as a
        # newly born common node.
        # List of target name that could be seen as common node
        common_ops = ["getattr", "getitem", "size"]

        def _is_cop(target: Any) -> bool:
            """Check if an op could be seen as common node

            Args:
                target (Any): node target

            Returns:
                bool
            """

            if isinstance(target, str):
                return target in common_ops
            else:
                return target.__name__ in common_ops

        def _is_sink() -> bool:
            """Check if we can free all dependencies

            Reads ``deps`` and the current loop node ``n`` from the enclosing scope.

            Returns:
                bool
            """

            def _is_inplace(n: Node):
                """Get the inplace argument from ``torch.fx.Node``
                """
                inplace = False
                if n.op == "call_function":
                    inplace = n.kwargs.get("inplace", False)
                elif n.op == "call_module":
                    inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
                return inplace

            def _is_shape_consistency(n: Node):
                """Check if this node is shape-consistency node (i.e. ``runtime_apply``,
                ``runtime_apply_for_iterable_object`` or ``runtime_comm_spec_apply``)
                """
                return n.target in [runtime_apply, runtime_apply_for_iterable_object, runtime_comm_spec_apply]

            # A region can be closed only when no dependency is outstanding and
            # none of the users is an inplace or shape-consistency op (those are
            # merged into the current region instead).
            return not sum([v for _, v in deps.items()]) and not any(map(_is_inplace, n.users)) and not any(
                map(_is_shape_consistency, n.users))

        # make sure that item in cnode is valid
        if self.cnode:
            for name in self.cnode:
                try:
                    assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \
                        f"Common node {name} is not an input of the model."
                except StopIteration:
                    raise ValueError(f"Common node name {name} not in graph.")

        else:
            self.cnode = []

        deps = {}
        node_list = []
        region = []

        for n in self.graph.nodes:
            if n.op != "placeholder" and n.op != "output":
                for n_par in n.all_input_nodes:
                    if n_par.op != "placeholder" and n_par.name not in self.cnode:
                        deps[n_par] -= 1
                region.append(n)

                # if the node could free all dependencies in graph
                # we could begin a new node
                if _is_sink():
                    node_list.append(region)
                    region = []

                # propagate common node attr if possible
                if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode
                                                 ]) or _is_cop(n.target):
                    self.cnode.append(n.name)
                else:
                    deps[n] = len([user for user in n.users if user.op != "output"])
        return node_list
import math
from copy import deepcopy
from typing import List, Set, Tuple
from torch.fx import Graph, Node
from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp
from .ckpt_solver_base import CheckpointSolverBase
# Names exported by a star-import of this module.
__all__ = ['CheckpointSolverChen']
class CheckpointSolverChen(CheckpointSolverBase):
    """Greedy activation-checkpoint solver following Chen et al. (arXiv:1604.06174)."""

    def __init__(self, graph: Graph, cnode: List[str] = None, num_grids: int = 6):
        """
        This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
        Note that this algorithm targets at memory optimization only, using techniques in appendix A.

        Usage:
            Assume that we have a ``GraphModule``, and we have already done the extractions
            to the graph to retrieve all information needed, then we could use the following
            code to find a solution using ``CheckpointSolverChen``:
            >>> solver = CheckpointSolverChen(gm.graph)
            >>> chen_graph = solver.solve()
            >>> gm.graph = chen_graph    # set the graph to a new graph

        Args:
            graph (Graph): The computing graph to be optimized.
            cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None.
            num_grids (int, optional): Number of grids to search for b. Defaults to 6.
        """
        # Pass by keyword: the base signature is
        # ``(graph, free_memory, requires_linearize, cnode, optim_multiplier)``.
        # The former positional call ``(graph, 0, 0, True, cnode)`` therefore set
        # ``requires_linearize=0``, ``cnode=True`` and ``optim_multiplier=cnode``.
        super().__init__(graph, free_memory=0, requires_linearize=True, cnode=cnode)
        self.num_grids = num_grids

    def solve(self) -> Graph:
        """Solve the checkpointing problem using Algorithm 3.

        Every checkpointable node inside the i-th selected interval gets
        ``meta['activation_checkpoint'] = i``.

        Returns:
            graph (Graph): The optimized graph, should be a copy of the original graph.
        """
        checkpointable_op = ['call_module', 'call_method', 'call_function', 'get_attr']
        ckpt = self.grid_search()
        for i, seg in enumerate(ckpt):
            for idx in range(*seg):
                nodes = self.node_list[idx]
                for n in nodes:
                    if n.op in checkpointable_op:
                        n.meta['activation_checkpoint'] = i
        return deepcopy(self.graph)

    def run_chen_greedy(self, b: int = 0) -> Tuple[List[Tuple[int, int]], int]:
        """
        This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.

        Args:
            b (int, optional): Memory threshold after which a new interval is opened. Defaults to 0.

        Returns:
            Tuple[List[Tuple[int, int]], int]: The half-open ``(start, end)`` index intervals into
            ``self.node_list`` and ``floor(sqrt(x * y))`` computed from the accumulated sizes.
        """
        ckpt_intv = []
        temp = 0
        x = 0
        y = 0
        prev_idx = 2
        for idx, nodes in enumerate(self.node_list):
            for n in nodes:
                n: Node
                temp += calculate_fwd_in(n) + calculate_fwd_tmp(n)
                y = max(y, temp)
            if temp > b and idx > prev_idx:
                x += calculate_fwd_in(nodes[0])
                temp = 0
                ckpt_intv.append((prev_idx, idx + 1))
                prev_idx = idx + 1
        return ckpt_intv, math.floor(math.sqrt(x * y))

    def grid_search(self) -> List[Tuple[int, int]]:
        """
        Search ckpt strategy with b = 0, then run the allocation algorithm again with b = √xy.
        Grid search over [√2/2 b, √2 b] for ``ckpt_opt`` over ``num_grids`` as in appendix A.

        Returns:
            List[Tuple[int, int]]: The interval list with the smallest estimated budget. If the
            search interval is empty, the result of the ``b = 0`` run is returned.
        """
        # The b = 0 result doubles as the fallback when the grid below is empty.
        ckpt_opt, b_approx = self.run_chen_greedy(0)
        b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))
        # A range narrower than ``num_grids`` would give a step of 0, which range() rejects.
        step = max(1, (b_max - b_min) // self.num_grids)
        b_opt = math.inf
        for b in range(b_min, b_max, step):
            ckpt_intv, b_approx = self.run_chen_greedy(b)
            if b_approx < b_opt:
                b_opt = b_approx
                ckpt_opt = ckpt_intv
        return ckpt_opt
#define PY_SSIZE_T_CLEAN
#include <Python.h>
/*
 * Copy a Python sequence of integers into a newly allocated C array.
 *
 * The array holds len + 1 entries; the extra trailing entry is 0.
 * The caller owns the result and must release it with free().
 *
 * Returns NULL on failure. A Python exception is set unless `pylist` itself
 * is NULL, in which case whatever the caller already raised is left untouched.
 */
long* PySequenceToLongArray(PyObject* pylist) {
  if (!pylist) return NULL;
  if (!PySequence_Check(pylist)) {
    PyErr_SetString(PyExc_TypeError, "expected a sequence of integers");
    return NULL;
  }
  Py_ssize_t len = PySequence_Size(pylist);
  if (len < 0) return NULL; /* exception set by PySequence_Size */
  long* result = (long*)calloc((size_t)len + 1, sizeof(long));
  if (!result) {
    PyErr_NoMemory();
    return NULL;
  }
  for (Py_ssize_t i = 0; i < len; ++i) {
    PyObject* item = PySequence_GetItem(pylist, i);
    if (!item) {
      free(result);
      return NULL;
    }
    result[i] = PyLong_AsLong(item);
    Py_DECREF(item);
    /* -1 is a legal value, so the error indicator must be consulted too. */
    if (result[i] == -1 && PyErr_Occurred()) {
      free(result);
      return NULL;
    }
  }
  result[len] = 0;
  return result;
}
/*
 * Copy a Python sequence of numbers into a newly allocated C array of doubles.
 *
 * The array holds len + 1 entries; the extra trailing entry is 0.
 * The caller owns the result and must release it with free().
 *
 * Returns NULL on failure. A Python exception is set unless `pylist` itself
 * is NULL, in which case whatever the caller already raised is left untouched.
 */
double* PySequenceToDoubleArray(PyObject* pylist) {
  if (!pylist) return NULL;
  if (!PySequence_Check(pylist)) {
    PyErr_SetString(PyExc_TypeError, "expected a sequence of numbers");
    return NULL;
  }
  Py_ssize_t len = PySequence_Size(pylist);
  if (len < 0) return NULL; /* exception set by PySequence_Size */
  double* result = (double*)calloc((size_t)len + 1, sizeof(double));
  if (!result) {
    PyErr_NoMemory();
    return NULL;
  }
  for (Py_ssize_t i = 0; i < len; ++i) {
    PyObject* item = PySequence_GetItem(pylist, i);
    if (!item) {
      free(result);
      return NULL;
    }
    result[i] = PyFloat_AsDouble(item);
    Py_DECREF(item);
    /* -1.0 is a legal value, so the error indicator must be consulted too. */
    if (result[i] == -1.0 && PyErr_Occurred()) {
      free(result);
      return NULL;
    }
  }
  result[len] = 0;
  return result;
}
/*
 * Read attribute `attributeName` of `container` and convert it with
 * PySequenceToLongArray. Returns NULL if the attribute lookup or the
 * conversion fails; the caller frees a non-NULL result with free().
 */
long* getLongArray(PyObject* container, const char* attributeName) {
  PyObject* sequence = PyObject_GetAttrString(container, attributeName);
  /* A failed lookup returns NULL with an exception set; Py_DECREF(NULL)
   * would crash, so bail out before touching it. */
  if (!sequence) return NULL;
  long* result = PySequenceToLongArray(sequence);
  Py_DECREF(sequence);
  return result;
}
/*
 * Read attribute `attributeName` of `container` and convert it with
 * PySequenceToDoubleArray. Returns NULL if the attribute lookup or the
 * conversion fails; the caller frees a non-NULL result with free().
 */
double* getDoubleArray(PyObject* container, const char* attributeName) {
  PyObject* sequence = PyObject_GetAttrString(container, attributeName);
  /* A failed lookup returns NULL with an exception set; Py_DECREF(NULL)
   * would crash, so bail out before touching it. */
  if (!sequence) return NULL;
  double* result = PySequenceToDoubleArray(sequence);
  Py_DECREF(sequence);
  return result;
}
/*
 * compute_table(chain, mmax) -> (cost_table, back_ptr)
 *
 * Dynamic-programming core of the rotor solver.
 *
 * `chain` must expose the sequence attributes ftime, btime, x, xbar, ftmp and
 * btmp, and len(chain) is taken as the chain length. The attribute lengths
 * are not re-validated here: x, xbar and btime are assumed to have
 * len(chain) + 1 entries and ftime, ftmp len(chain) entries.
 * NOTE(review): confirm every caller guarantees this.
 *
 * Result layout:
 *   cost_table[m][i] is a dict {l: cost}   for l in i..len(chain)
 *   back_ptr[m][i]   is a dict {l: (True,) | (False, j)}
 * with m in 0..mmax. A negative mmax yields two empty lists.
 *
 * Returns NULL with a Python exception set on any failure.
 */
static PyObject* computeTable(PyObject* self, PyObject* args) {
  PyObject* chainParam;
  int mmax;
  double *ftime = NULL, *btime = NULL, *costTable = NULL;
  long *x = NULL, *xbar = NULL, *ftmp = NULL, *btmp = NULL, *backPtr = NULL;
  PyObject *pyCostTable = NULL, *pyBackPtr = NULL, *result = NULL;
  Py_ssize_t rawLength;
  long chainLength = 0;
  long numBudgets;
  size_t stride, tableSize;

  if (!PyArg_ParseTuple(args, "Oi", &chainParam, &mmax)) return NULL;
  numBudgets = (mmax < 0) ? 0 : (long)mmax + 1;

  /* Every failure below jumps to `cleanup`, which releases whatever has been
   * acquired so far. */
  if (!(ftime = getDoubleArray(chainParam, "ftime"))) goto cleanup;
  if (!(btime = getDoubleArray(chainParam, "btime"))) goto cleanup;
  if (!(x = getLongArray(chainParam, "x"))) goto cleanup;
  if (!(xbar = getLongArray(chainParam, "xbar"))) goto cleanup;
  /* The forward temporaries come from "ftmp"; this previously read "btmp". */
  if (!(ftmp = getLongArray(chainParam, "ftmp"))) goto cleanup;
  if (!(btmp = getLongArray(chainParam, "btmp"))) goto cleanup;

  rawLength = PyObject_Length(chainParam);
  if (rawLength < 0) goto cleanup;
  if (rawLength == 0) {
    PyErr_SetString(PyExc_ValueError, "chain must contain at least one element");
    goto cleanup;
  }
  chainLength = (long)rawLength;

  /* The tables are indexed with m - x[j] and m - xbar[i + 1]; a negative
   * size would land outside the allocated budget range. */
  for (long i = 0; i <= chainLength; ++i) {
    if (x[i] < 0 || xbar[i] < 0) {
      PyErr_SetString(PyExc_ValueError, "x and xbar must be non-negative");
      goto cleanup;
    }
  }

  stride = (size_t)chainLength + 1;
  tableSize = (size_t)numBudgets * stride * stride;

#define COST_TABLE(m, i, l)                                 \
  costTable[(m) * (chainLength + 1) * (chainLength + 1) +   \
            (i) * (chainLength + 1) + (l)]
  /* calloc(0, ...) may legally return NULL, so always request one element. */
  costTable = (double*)calloc(tableSize ? tableSize : 1, sizeof(double));

#define BACK_PTR(m, i, l)                                 \
  backPtr[(m) * (chainLength + 1) * (chainLength + 1) +   \
          (i) * (chainLength + 1) + (l)]
  backPtr = (long*)calloc(tableSize ? tableSize : 1, sizeof(long));
  if (!costTable || !backPtr) {
    PyErr_NoMemory();
    goto cleanup;
  }

  /* Border of the table: intervals of length zero. */
  for (long m = 0; m < numBudgets; ++m)
    for (long i = 0; i <= chainLength; ++i)
      if ((m >= x[i + 1] + xbar[i + 1] + btmp[i]) &&
          (m >= x[i + 1] + xbar[i + 1] + ftmp[i]))
        COST_TABLE(m, i, i) = ftime[i] + btime[i];
      else
        COST_TABLE(m, i, i) = INFINITY;

  /* Fill the remaining entries by increasing interval length d. */
  for (long m = 0; m < numBudgets; ++m)
    for (long d = 1; d <= chainLength; ++d) {
      for (long i = 0; i <= chainLength - d; ++i) {
        long idx = i + d;
        long mmin = x[idx + 1] + x[i + 1] + ftmp[i];
        if (idx > i + 1) {
          long maxCostFWD = 0;
          for (long j = i + 1; j < idx; j++) {
            long stepCost = x[j] + x[j + 1] + ftmp[j];
            if (stepCost > maxCostFWD) maxCostFWD = stepCost;
          }
          if (x[idx + 1] + maxCostFWD > mmin) mmin = x[idx + 1] + maxCostFWD;
        }
        if ((m >= mmin)) {
          long bestLeaf = -1;
          double sumFw = 0;
          double bestLeafCost = INFINITY;
          for (long j = i + 1; j <= idx; ++j) {
            sumFw += ftime[j - 1];
            if (m >= x[j]) {
              double cost = sumFw + COST_TABLE(m - x[j], j, idx) +
                            COST_TABLE(m, i, j - 1);
              if (cost < bestLeafCost) {
                bestLeafCost = cost;
                bestLeaf = j;
              }
            }
          }
          double chainCost = INFINITY;
          if (m >= xbar[i + 1])
            chainCost =
                COST_TABLE(m, i, i) + COST_TABLE(m - xbar[i + 1], i + 1, idx);
          if (bestLeafCost <= chainCost) {
            COST_TABLE(m, i, idx) = bestLeafCost;
            BACK_PTR(m, i, idx) = bestLeaf;
          } else {
            COST_TABLE(m, i, idx) = chainCost;
            BACK_PTR(m, i, idx) = -1;
          }
        } else
          COST_TABLE(m, i, idx) = INFINITY;
      }
    }

  pyCostTable = PyList_New(numBudgets);
  pyBackPtr = PyList_New(numBudgets);
  if (!pyCostTable || !pyBackPtr) goto cleanup;

  // Convert the result into Python world
  for (long m = 0; m < numBudgets; ++m) {
    /* PyList_SET_ITEM steals the reference, so the outer lists own each
     * child as soon as it is stored. */
    PyObject* pyCostTable_m = PyList_New(chainLength + 1);
    if (!pyCostTable_m) goto cleanup;
    PyList_SET_ITEM(pyCostTable, m, pyCostTable_m);
    PyObject* pyBackPtr_m = PyList_New(chainLength + 1);
    if (!pyBackPtr_m) goto cleanup;
    PyList_SET_ITEM(pyBackPtr, m, pyBackPtr_m);
    for (long i = 0; i <= chainLength; ++i) {
      PyObject* pyCostTable_m_i = PyDict_New();
      if (!pyCostTable_m_i) goto cleanup;
      PyList_SET_ITEM(pyCostTable_m, i, pyCostTable_m_i);
      PyObject* pyBackPtr_m_i = PyDict_New();
      if (!pyBackPtr_m_i) goto cleanup;
      PyList_SET_ITEM(pyBackPtr_m, i, pyBackPtr_m_i);
      for (long l = i; l <= chainLength; ++l) {
        PyObject* pyVar_l = PyLong_FromLong(l);
        PyObject* pyCostTable_m_i_l = PyFloat_FromDouble(COST_TABLE(m, i, l));
        PyObject* pyBackPtr_m_i_l;
        if (BACK_PTR(m, i, l) < 0)
          pyBackPtr_m_i_l = Py_BuildValue("(O)", Py_True);
        else
          pyBackPtr_m_i_l = Py_BuildValue("(Ol)", Py_False, BACK_PTR(m, i, l));
        int failed =
            !pyVar_l || !pyCostTable_m_i_l || !pyBackPtr_m_i_l ||
            PyDict_SetItem(pyCostTable_m_i, pyVar_l, pyCostTable_m_i_l) < 0 ||
            PyDict_SetItem(pyBackPtr_m_i, pyVar_l, pyBackPtr_m_i_l) < 0;
        Py_XDECREF(pyCostTable_m_i_l);
        Py_XDECREF(pyBackPtr_m_i_l);
        Py_XDECREF(pyVar_l);
        if (failed) goto cleanup;
      }
    }
  }

  result = PyTuple_Pack(2, pyCostTable, pyBackPtr);

cleanup:
  free(ftime);
  free(btime);
  free(x);
  free(xbar);
  free(ftmp);
  free(btmp);
  free(costTable);
  free(backPtr);
  Py_XDECREF(pyCostTable);
  Py_XDECREF(pyBackPtr);
  /* Never hand NULL back to the interpreter without an exception. */
  if (!result && !PyErr_Occurred())
    PyErr_SetString(PyExc_TypeError,
                    "chain attributes must be sequences of numbers");
  return result;
}
static PyMethodDef rotorMethods[] = {
{"compute_table", computeTable, METH_VARARGS,
"Compute the optimal table with the rotor algorithm."},
{NULL, NULL, 0, NULL} /* Sentinel */
};
static struct PyModuleDef rotorModule = {
PyModuleDef_HEAD_INIT, "rotorc", /* name of module */
"A simple implementation of dynamic programming algorithm rotor with C in "
"https://hal.inria.fr/hal-02352969. Some code are adapted from "
"https://gitlab.inria.fr/hiepacs/rotor.", /* module documentation, may be
NULL */
-1, /* size of per-interpreter state of the module,
or -1 if the module keeps state in global variables. */
rotorMethods};
PyMODINIT_FUNC PyInit_rotorc(void) { return PyModule_Create(&rotorModule); }
from copy import deepcopy
from typing import Any, Dict, List, Tuple
from torch import Tensor
from torch.fx import Graph, Node
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
from colossalai.fx.profiler import (
activation_size,
calculate_bwd_time,
calculate_fwd_out,
calculate_fwd_time,
calculate_fwd_tmp,
)
from colossalai.logging import get_dist_logger
from .ckpt_solver_base import CheckpointSolverBase
from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Loss, Sequence
# Names exported by a star-import of this module.
__all__ = ['CheckpointSolverRotor']
class CheckpointSolverRotor(CheckpointSolverBase):
    """Dynamic-programming activation-checkpoint solver (rotor, https://hal.inria.fr/hal-02352969)."""

    def __init__(self,
                 graph: Graph,
                 free_memory: float = -1,
                 cnode: List[str] = None,
                 memory_slots: int = 500,
                 optim_multiplier: float = 1.0):
        """This is the simple implementation of dynamic programming algorithm rotor
        in https://hal.inria.fr/hal-02352969. Some code are adapted from
        https://gitlab.inria.fr/hiepacs/rotor.

        Usage:
            Assume that we have a ``GraphModule``, and we have already done the extractions
            to the graph to retrieve all information needed, then we could use the following
            code to find a solution using ``CheckpointSolverRotor``:
            >>> solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info(device=0)[0])
            >>> rotor_graph = solver.solve(force_python=True)   # otherwise use C solver
            >>> gm.graph = rotor_graph    # set the graph to a new graph

        Args:
            graph (Graph): The computing graph to be optimized.
            free_memory (float, optional): Memory constraint for the solution, unit is byte.
                Use ``torch.cuda.mem_get_info(device=0)[0]`` to estimate the free_memory. Defaults to -1.
            cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None.
            memory_slots (int, optional): Number of slots for discretizing memory budget. Defaults to 500.
            optim_multiplier (float, optional): The multiplier of extra weight storage for the
            ``torch.optim.Optimizer``. Default to 1.0.
        """
        super().__init__(graph, free_memory, True, cnode, optim_multiplier)
        self.memory_slots = memory_slots

        # construct chain
        # ``unit`` is the number of bytes represented by one memory slot.
        unit = self.free_memory // self.memory_slots
        self.chain = self._construct_chain(self.graph, self.node_list)
        self.chain.discretize_all(unit)

        # Filled in by ``solve``.
        self.cost_table = None
        self.back_ptr = None
        self.sequence = None

    def solve(self, force_python: bool = False, verbose: bool = False) -> Graph:
        """Solve the checkpointing problem using rotor algorithm.

        Args:
            force_python (bool, optional): Use Python version of solver, else use C version. Defaults to False.
            verbose (bool, optional): Print verbose information. Defaults to False.

        Raises:
            ValueError: If backtracking or annotation fails. The original error is attached as the cause.

        Returns:
            graph (Graph): The optimized graph, should be a copy of the original graph.
        """
        chain = self.chain

        # compute cost table
        if force_python:
            self.cost_table, self.back_ptr = self._compute_table(chain, self.memory_slots)
        else:
            self.cost_table, self.back_ptr = self._compute_table_c(chain, self.memory_slots)

        if verbose:
            self.print_chain()

        # backtrack
        try:
            self.sequence = self._backtrack(chain, 0, len(chain), self.memory_slots - chain.x[0], self.cost_table,
                                            self.back_ptr)
            self._annotate_from_sequence(self.sequence, self.node_list)
        except ValueError as e:
            # log the failure, then re-raise with the reason and the original
            # exception attached instead of a bare, message-less ValueError
            logger = get_dist_logger()
            logger.warning(f'Checkpoint solver failed: {e}')
            raise ValueError(f'Checkpoint solver failed: {e}') from e

        if verbose:
            self.print_sequence()

        return deepcopy(self.graph)

    def print_chain(self):
        """Print the chain inputs, one line per linearized node, followed by the chain itself."""
        print('[input]', self.chain.x[0], self.chain.xbar[0], self.chain.ftmp[0], self.chain.btmp[0])
        for idx in range(len(self.node_list) - 1):
            print(self.node_list[idx], self.chain.x[idx + 1], self.chain.xbar[idx + 1], self.chain.ftmp[idx],
                  self.chain.btmp[idx])
        print(f'Chain = {self.chain}')

    def print_sequence(self):
        """Print the operation sequence found by ``solve``."""
        print(f'Sequence = {self.sequence}')

    @classmethod
    def _construct_chain(cls, graph: Graph, node_list: List[List[Node]]) -> Chain:
        """Build the ``Chain`` for ``node_list``.

        Entry 0 of ``x`` and ``xbar`` is the activation size of the graph inputs;
        each linearized node then contributes one entry to every list.
        """
        input_tensors = cls._extract_input(graph)
        ftime, btime, ftmp, btmp = list(), list(), list(), list()
        xbar, x = [activation_size(input_tensors)], [activation_size(input_tensors)]

        for node in node_list:
            node_info = cls._extract_node_info(node)
            ftime.append(node_info[0])
            btime.append(node_info[1])
            x.append(node_info[2])
            xbar.append(node_info[3])
            ftmp.append(node_info[4])
            btmp.append(node_info[5])

        # currently we view loss backward temp as zero
        btime.append(0)
        btmp.append(0)

        return Chain(ftime, btime, x, xbar, ftmp, btmp)

    @classmethod
    def _extract_node_info(cls, node: List[Node]) -> Tuple[int, ...]:
        """Extract node info from a list of nodes

        Returns:
            Tuple: ``(ftime, btime, x, xbar, ftmp, btmp)`` for the merged region.
        """
        xbar = 0
        ftime = 0
        btime = 0
        fwd_mem_peak = 0
        for n in node:
            assert isinstance(n, Node), f'{n} is not a Node'
            if n.target == runtime_apply or n.target == runtime_comm_spec_apply:
                # in this case we need to calculate memory usage directly based on the statics that hooked in node.meta
                xbar += n.meta['fwd_mem_out']
                fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'])
            else:
                xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
                fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n))

            # minimum flop count is required
            ftime += max(calculate_fwd_time(n), 1.0)
            btime += max(calculate_bwd_time(n), 1.0)

        x = calculate_fwd_out(node[-1])
        xbar = max(x, xbar)
        ftmp = fwd_mem_peak - xbar
        btmp = cls._extract_btmp(node)
        return ftime, btime, x, xbar, ftmp, btmp

    @staticmethod
    def _extract_input(graph: Graph) -> Tuple[Tensor, ...]:
        """Extract input tensors from a Graph"""
        input_tensors = []
        for node in graph.nodes:
            if node.op == 'placeholder':
                input_tensors.append(node.meta['fwd_out'])
        return input_tensors

    @staticmethod
    def _extract_unused_output(node: Node) -> int:
        """Extract unused output from `torch.fx.Node`"""
        return activation_size(node.meta['fwd_out']) - calculate_fwd_out(node)

    @staticmethod
    def _extract_btmp(node: List[Node]) -> int:
        """Extract btmp from a list of nodes"""

        def _extract_deps_size():
            """Sum the sizes tracked in ``deps``: live entries add, freed entries subtract."""
            deps_size = 0
            for k, v in deps.items():
                k: Node
                if v > 0:
                    deps_size += k.meta['bwd_mem_out']
                if v == float('-inf'):
                    deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k)

            return deps_size

        btmp = 0
        deps = {}
        # Walk the region in reverse (backward order) and track the peak.
        for n in reversed(node):
            deps[n] = len(n.all_input_nodes)
            btmp = max(btmp, _extract_deps_size() + n.meta['bwd_mem_tmp'])
            for child in n.users:
                if child in deps:
                    deps[child] -= 1
                    if deps[child] <= 0:
                        deps[child] = float('-inf')    # free
        return btmp

    @staticmethod
    def _compute_table(chain: Chain, mmax: int) -> Tuple:
        """Compute the table using dynamic programming. Returns the cost table and the backtracking pointer.

        Args:
            chain (Chain): A basic linearized structure for solving the dynamic programming problem.
            mmax (int): Maximum number of memory slots.

        Returns:
            cost_table (List): cost_table[m][lhs][rhs] with lhs = 0...chain.length
                                    and rhs = lhs...chain.length (lhs is not included) and m = 0...mmax
            back_ptr (List): back_ptr[m][lhs][rhs] is (True,) if the optimal choice
                                 is a chain checkpoint (False, j) if the optimal choice is a leaf checkpoint
                                 of length j
        """

        # Pad each list with a trailing zero so that ``[i + 1]`` lookups at the
        # last index stay in range.
        ftime = chain.ftime + [0.0]
        btime = chain.btime
        x = chain.x + [0]
        xbar = chain.xbar + [0]
        ftmp = chain.ftmp + [0]
        btmp = chain.btmp + [0]

        # Build table
        cost_table = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)]
        back_ptr = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)]
        # Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation

        # Initialize borders of the tables for lmax-lmin = 0
        for m in range(mmax + 1):
            for i in range(len(chain) + 1):
                limit = max(x[i + 1] + xbar[i + 1] + ftmp[i], x[i + 1] + xbar[i + 1] + btmp[i])
                if m >= limit:    # Equation (1)
                    cost_table[m][i][i] = ftime[i] + btime[i]
                else:
                    cost_table[m][i][i] = float("inf")

        # Compute everything
        for m in range(mmax + 1):
            for d in range(1, len(chain) + 1):
                for i in range(len(chain) + 1 - d):
                    idx = i + d
                    mmin = x[idx + 1] + x[i + 1] + ftmp[i]
                    if idx > i + 1:
                        mmin = max(mmin, x[idx + 1] + max(x[j] + x[j + 1] + ftmp[j] for j in range(i + 1, idx)))
                    if m < mmin:
                        cost_table[m][i][idx] = float("inf")
                    else:
                        leaf_checkpoints = [(j,
                                             sum(ftime[i:j]) + cost_table[m - x[j]][j][idx] + cost_table[m][i][j - 1])
                                            for j in range(i + 1, idx + 1)
                                            if m >= x[j]]
                        if leaf_checkpoints:
                            best_leaf = min(leaf_checkpoints, key=lambda t: t[1])
                        else:
                            best_leaf = None
                        if m >= xbar[i + 1]:
                            chain_checkpoint = cost_table[m][i][i] + cost_table[m - xbar[i + 1]][i + 1][idx]
                        else:
                            chain_checkpoint = float("inf")
                        if best_leaf and best_leaf[1] <= chain_checkpoint:
                            cost_table[m][i][idx] = best_leaf[1]
                            back_ptr[m][i][idx] = (False, best_leaf[0])
                        else:
                            cost_table[m][i][idx] = chain_checkpoint
                            back_ptr[m][i][idx] = (True,)
        return cost_table, back_ptr

    @staticmethod
    def _compute_table_c(chain: Chain, mmax: int) -> Tuple:
        """Compute the tables with the ``rotorc`` C extension.

        If the extension cannot be imported it is built in place with
        ``build_c_ext.py``; if that build fails, the Python implementation
        ``_compute_table`` is used instead.

        Args:
            chain (Chain): A basic linearized structure for solving the dynamic programming problem.
            mmax (int): Maximum number of memory slots.

        Returns:
            Tuple: ``(cost_table, back_ptr)``, see ``._compute_table()`` for definitions.
        """
        try:
            from .rotorc import compute_table

        # build module if module not found
        except ModuleNotFoundError:
            import os
            import subprocess
            import sys
            logger = get_dist_logger()
            logger.info("rotorc hasn't been built! Building library...", ranks=[0])
            this_dir = os.path.dirname(os.path.abspath(__file__))
            # ``subprocess.run`` drains both pipes while it waits. The previous
            # ``Popen(...).wait()`` with PIPE could block forever once the
            # compiler filled an OS pipe buffer.
            result = subprocess.run(
                [
                    f"{sys.executable}", f"{os.path.join(this_dir, 'build_c_ext.py')}", "build_ext",
                    f"--build-lib={this_dir}"
                ],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
            if result.returncode == 0:
                logger.info("rotorc has been built!", ranks=[0])
                from .rotorc import compute_table
            else:
                logger.warning("rotorc built failed! Using python version!", ranks=[0])
                return CheckpointSolverRotor._compute_table(chain, mmax)
        return compute_table(chain, mmax)

    @staticmethod
    def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any],
                   back_ptr: List[Any]) -> "Sequence":
        """Backtrack the cost table and retrieve the optimal checkpointing strategy.

        Args:
            chain (Chain): A basic linearized structure for solving the dynamic programming problem.
            lhs (int): The left index of the interval to backtrack.
            rhs (int): The right index of the interval to backtrack.
            budget (int): The memory budget for processing this interval.
            cost_table (List[Any]): See ``._compute_table()`` for definitions
            back_ptr (List[Any]): See ``._compute_table()`` for definitions

        Raises:
            ValueError: Can not process the chain.

        Returns:
            sequence (Sequence): The sequence of executing nodes with checkpoints.
        """
        if budget <= 0:
            raise ValueError(f"Can not process a chain with negative memory {budget}")
        elif cost_table[budget][lhs][rhs] == float("inf"):
            raise ValueError(f"Can not process this chain from index {lhs} to {rhs} with memory {budget}")

        sequence = Sequence()
        # Interval of length zero: either the loss itself or one forward/backward pair.
        if rhs == lhs:
            if lhs == len(chain):
                sequence += [Loss()]
            else:
                sequence += [ForwardEnable(lhs), Backward(lhs)]
            return sequence

        if back_ptr[budget][lhs][rhs][0]:
            # chain checkpoint
            sequence += [
                ForwardEnable(lhs),
                CheckpointSolverRotor._backtrack(chain, lhs + 1, rhs, budget - chain.xbar[lhs + 1], cost_table,
                                                 back_ptr),
                Backward(lhs),
            ]
        else:
            # leaf checkpoint at ``best_leaf``
            best_leaf = back_ptr[budget][lhs][rhs][1]
            sequence += [ForwardCheck(lhs)]
            sequence += [ForwardNograd(k) for k in range(lhs + 1, best_leaf)]
            sequence += [
                CheckpointSolverRotor._backtrack(chain, best_leaf, rhs, budget - chain.x[best_leaf], cost_table,
                                                 back_ptr),
                CheckpointSolverRotor._backtrack(chain, lhs, best_leaf - 1, budget, cost_table, back_ptr),
            ]
        return sequence

    @staticmethod
    def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
        """Annotate the nodes in the ``node_list`` with activation checkpoint from the sequence.

        Args:
            sequence (Sequence): The sequence of executing nodes with activation checkpoint annotations.
            node_list (List[List[Node]]): The list of nodes to annotate.
        """
        op_list = sequence.list_operations()
        # Everything before the loss is the forward pass, everything after it the backward pass.
        loss_op = next(op for op in op_list if isinstance(op, Loss))
        fwd_list = op_list[:op_list.index(loss_op)]
        bwd_list = op_list[op_list.index(loss_op) + 1:]
        ckpt_idx = 0
        in_ckpt = False
        ckpt_region = []

        # forward annotation
        for idx, op in enumerate(fwd_list, 0):
            if in_ckpt:
                if isinstance(op, ForwardNograd):
                    ckpt_region.append(idx)

                elif isinstance(op, ForwardEnable):
                    in_ckpt = False
                    for node_idx in ckpt_region:
                        for n in node_list[node_idx]:
                            n.meta['activation_checkpoint'] = [ckpt_idx]

                    ckpt_idx += 1
                    ckpt_region = []

                elif isinstance(op, ForwardCheck):
                    for node_idx in ckpt_region:
                        for n in node_list[node_idx]:
                            n.meta['activation_checkpoint'] = [ckpt_idx]

                    ckpt_idx += 1
                    ckpt_region = [idx]

            else:
                if isinstance(op, ForwardCheck):
                    in_ckpt = True
                    ckpt_region.append(idx)

        # annotate the backward if there is any nested activation checkpoint
        in_recompute = False
        for op in bwd_list:
            if in_recompute:
                if isinstance(op, ForwardNograd):
                    ckpt_region.append(op.index)

                elif isinstance(op, ForwardEnable):
                    for node_idx in ckpt_region:
                        for n in node_list[node_idx]:
                            n.meta['activation_checkpoint'].append(ckpt_idx)

                    ckpt_idx += 1
                    ckpt_region = []

                elif isinstance(op, ForwardCheck):
                    for node_idx in ckpt_region:
                        for n in node_list[node_idx]:
                            n.meta['activation_checkpoint'].append(ckpt_idx)

                    ckpt_idx += 1
                    ckpt_region = [op.index]

                elif isinstance(op, Backward):
                    for node_idx in ckpt_region:
                        for n in node_list[node_idx]:
                            n.meta['activation_checkpoint'].append(ckpt_idx)

                    in_recompute = False

            else:
                if not isinstance(op, Backward):
                    in_recompute = True
                    ckpt_idx = 0
                    ckpt_region = []
                    if isinstance(op, ForwardCheck):
                        ckpt_region.append(op.index)

        # postprocess, make sure every activation checkpoint label in the
        # same activation checkpoint region (level = 0) has the same length
        op_list = []
        for node in node_list:
            op_list += node
        ckpt_regions = _find_nested_ckpt_regions(op_list)
        for (start_idx, end_idx) in ckpt_regions:
            nested_length = max(
                len(op_list[idx].meta['activation_checkpoint']) for idx in range(start_idx, end_idx + 1))
            for idx in range(start_idx, end_idx + 1):
                op_list[idx].meta['activation_checkpoint'] += [None] * (nested_length -
                                                                        len(op_list[idx].meta['activation_checkpoint']))
import math
from abc import ABC
from typing import Any, Iterable, List
from torch.utils._pytree import tree_map
class Chain:
    """Linearized per-stage time and memory figures consumed by the rotor solver."""

    def __init__(self,
                 ftime: List[float],
                 btime: List[float],
                 x: List[int],
                 xbar: List[int],
                 ftmp: List[int],
                 btmp: List[int],
                 check_consistency: bool = True):
        """The chain is a basic linearized structure for solving the dynamic programming problem for activation checkpoint.
        See paper https://hal.inria.fr/hal-02352969 for details.

        Args:
            ftime (List[float]): The forward time of each node.
            btime (List[float]): The backward time of each node.
            x (List[int]): The forward memory of each node (if save_output). Same as `a` in the paper.
            xbar (List[int]): The forward memory of each node (if save_all). Same as `a_bar` in the paper.
            ftmp (List[int]): The temporary forward memory of each node.
            btmp (List[int]): The temporary backward memory of each node, can be used to control memory budget.
            check_consistency (bool, optional): Check the lengths consistency for the `Chain`. Defaults to True.

        Raises:
            AttributeError: If ``check_consistency`` is set and the list lengths do not match.
        """
        self.ftime, self.btime = ftime, btime
        self.x, self.xbar = x, xbar
        self.ftmp, self.btmp = ftmp, btmp
        if check_consistency and not self.check_lengths():
            raise AttributeError("In Chain, input lists do not have consistent lengths")

    def check_lengths(self):
        """Return True when ``ftime``/``ftmp`` have ``len(self)`` entries and the other lists one more."""
        n = len(self)
        return ((len(self.ftime) == n) and (len(self.btime) == n + 1) and (len(self.x) == n + 1)
                and (len(self.ftmp) == n) and (len(self.btmp) == n + 1) and (len(self.xbar) == n + 1))

    def __repr__(self):
        last = len(self)
        rows = [(self.ftime[i], self.btime[i], self.x[i], self.xbar[i], self.ftmp[i], self.btmp[i])
                for i in range(last)]
        # The final row has no forward entries, hence the ``None`` placeholders.
        rows.append((None, self.btime[last], self.x[last], self.xbar[last], None, self.btmp[last]))
        return repr(rows)

    def __len__(self):
        return len(self.ftime)

    def discretize_all(self, unit: int):
        """Replace every memory figure (``x``, ``xbar``, ``ftmp``, ``btmp``) by ``ceil(value / unit)``."""

        def to_slots(val):
            return math.ceil(val / unit)

        for attr in ("x", "xbar", "ftmp", "btmp"):
            setattr(self, attr, tree_map(to_slots, getattr(self, attr)))
class Operation(ABC):
    """Base class for the entries of a checkpointing schedule.

    The class itself does not define ``index``; ``__repr__`` and ``shift`` read it from the
    instance, so subclasses are expected to set it (an int, or a tuple of ints for ranges).
    """
    name = "Op"

    def __repr__(self) -> str:
        return f"{self.name}_{self.index}"

    def shift(self, value):
        """Offset ``index`` by ``value`` in place.

        A tuple index (including tuple subclasses) is shifted element-wise and stored back as a
        plain tuple; any other index is shifted with ``+=``.
        """
        # isinstance (rather than an exact type check) so tuple subclasses are shifted
        # element-wise instead of failing on ``tuple += int``.
        if isinstance(self.index, tuple):
            self.index = tuple(x + value for x in self.index)
        else:
            self.index += value
class Forward(Operation):
    """Forward computation of the chain node at ``index``."""
    name = "F"

    def __init__(self, index):
        self.index = index

    def cost(self, chain: Chain):
        """Return the node's forward time from ``chain``, or 1 when no chain is given."""
        if chain is None:
            return 1
        return chain.ftime[self.index]
class ForwardEnable(Forward):
    """``Forward`` variant printed with the ``Fe`` tag.

    NOTE(review): presumably a forward pass that keeps everything needed for backward -- confirm
    against the solver that emits it.
    """
    name = "Fe"
class ForwardNograd(Forward):
    """``Forward`` variant printed with the ``Fn`` tag.

    NOTE(review): presumably a forward pass executed without gradient tracking -- confirm against
    the solver that emits it.
    """
    name = "Fn"
class ForwardCheck(Forward):
    """``Forward`` variant printed with the ``CF`` tag.

    NOTE(review): presumably a forward pass that checkpoints its input -- confirm against the
    solver that emits it.
    """
    name = "CF"
class Forwards(Operation):
    """Forward computation over the inclusive node range ``start..end``."""

    def __init__(self, start, end):
        self.index = (start, end)

    def __repr__(self):
        first, last = self.index[0], self.index[1]
        return "F_{i}->{j}".format(i=first, j=last)

    def cost(self, chain: Chain):
        """Return the summed forward time of the range, or its length when no chain is given."""
        first, last = self.index[0], self.index[1]
        if chain is None:
            return last - first + 1
        return sum(chain.ftime[first:last + 1])
def isForward(op):
    """Return True only when ``op`` is exactly a ``Forward`` or a ``Forwards``.

    The comparison is on the exact type, so subclasses such as ``ForwardEnable`` yield False.
    """
    return type(op) in (Forward, Forwards)
class Backward(Operation):
    """Backward computation of the chain node at ``index``."""
    name = "B"

    def __init__(self, index):
        self.index = index

    def cost(self, chain: Chain):
        """Return the node's backward time from ``chain``, or 1 when no chain is given."""
        if chain is None:
            return 1
        return chain.btime[self.index]
class Loss(Operation):
    """Schedule entry for the loss; it is printed as ``L`` and always costs 0."""

    def __init__(self):
        """Create the entry; unlike the other operations, no ``index`` attribute is set."""

    def __repr__(self):
        return "L"

    def cost(self, chain):
        """Return 0 regardless of ``chain``."""
        return 0
class MemoryAccess(Operation):
    """Base class for memory-related schedule entries; these always cost 0."""
    name = "MA"

    def __init__(self, index):
        self.index = index

    def cost(self, chain: Chain):
        """Return 0 regardless of ``chain``."""
        return 0
class WriteMemory(MemoryAccess):
    """``MemoryAccess`` entry printed with the ``WM`` tag."""
    name = "WM"
class ReadMemory(MemoryAccess):
    """``MemoryAccess`` entry printed with the ``RM`` tag."""
    name = "RM"
class DiscardMemory(MemoryAccess):
    """``MemoryAccess`` entry printed with the ``DM`` tag."""
    name = "DM"
class Sequence(list):
    """A list whose items are ``Operation`` objects or nested ``Sequence`` objects."""

    def __init__(self):
        super().__init__()

    def __repr__(self):
        return repr(self.list_operations())

    def list_operations(self):
        """Return the operations of this sequence, with nested sequences flattened depth-first."""
        flat = []
        for item in self:
            if not isinstance(item, Operation):
                # Anything that is not an operation must be a nested sequence.
                assert isinstance(item, Sequence)
                flat.extend(item.list_operations())
                continue
            flat.append(item)
        return flat
from .meta_registry import *
from .metainfo import *
from .registry import meta_register
import operator
import torch
import torch.nn as nn
from ..tensor_shard.constants import *
# Module classes regarded as in-place modules.
INPLACE_MODULE = [nn.ReLU]
# Functions regarded as in-place operations.
INPLACE_OPS = [torch.flatten]
# Operations that do not save their forward activations.
NO_SAVE_ACTIVATION = [torch.add, torch.sub, operator.add, operator.sub]
from .activation import *
from .binary_elementwise_ops import *
from .conv import *
from .linear import *
from .norm import *
from .pooling import *
from typing import List, Tuple
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping
from ..registry import meta_register
__all__ = ["relu_meta_info"]
@meta_register.register(torch.nn.ReLU)
def relu_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
    """Meta information generator for ``torch.nn.ReLU``.

    In the traced aten graph the nodes with compute cost are ``aten.relu.default`` (forward) and
    ``aten.threshold_backward.default`` (backward); the remaining nodes are ``zeros_like`` and
    ``detach``.

    Args:
        *args: Operation data items. ``args[0]`` provides the input tensor and the first item whose
            ``type`` is ``OperationDataType.OUTPUT`` provides the output tensor.
        **kwargs: ``inplace`` (default ``False``) selects the in-place forward memory estimate.

    Returns:
        The 5-tuple ``(compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out)``.
        NOTE(review): the return annotation only lists three items although five are returned.
    """
    in_tensor = args[0].data
    out_tensor = next(item for item in args if item.type == OperationDataType.OUTPUT).data
    inplace = kwargs.get("inplace", False)

    # compute cost: relu.default in forward, threshold_backward in backward
    fwd_flops = flop_mapping[torch.ops.aten.relu.default]([in_tensor], (out_tensor,))
    bwd_flops = flop_mapping[torch.ops.aten.threshold_backward.default]([out_tensor], (in_tensor,))
    compute_cost = TrainCycleItem(fwd=fwd_flops, bwd=bwd_flops, total=fwd_flops + bwd_flops)

    # memory cost
    # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward,
    # so the non-inplace estimate counts both the output and the input
    if inplace:
        fwd_activation = activation_size(in_tensor)
    else:
        fwd_activation = activation_size([out_tensor, in_tensor])
    fwd_mem = MemoryCost(activation=fwd_activation, parameter=0, temp=0, buffer=0)
    bwd_mem = MemoryCost(activation=activation_size(in_tensor), parameter=0, temp=0, buffer=0)

    # the total only aggregates the activation and parameter fields
    total_mem = MemoryCost(activation=fwd_mem.activation + bwd_mem.activation,
                           parameter=fwd_mem.parameter + bwd_mem.parameter)
    memory_cost = TrainCycleItem(fwd=fwd_mem, bwd=bwd_mem, total=total_mem)

    # fwd_in, fwd_buffer, fwd_out
    # NOTE: this layout is kept aligned with the older version of MetaInfoProp
    fwd_in = []
    fwd_buffer = [torch.zeros_like(out_tensor, device='meta')]
    fwd_out = [torch.zeros_like(out_tensor, device='meta')]

    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
from typing import List, Tuple
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping
from ..constants import BCAST_FUNC_OP, NO_SAVE_ACTIVATION
from ..registry import meta_register
__all__ = ['binary_elementwise_meta_info']
@meta_register.register(BCAST_FUNC_OP)
def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
    """Meta information generator for binary elementwise operations.

    NOTE: Some binary elementwise operations discard their input activations right after the
    computation because they are not needed for back propagation (e.g. the two operands of
    `torch.add`). A simple API in the `MetaInfo` class identifies this behavior, which is critical
    for better memory estimation.

    Args:
        *args: Operation data items. Every item whose ``type`` is not ``OperationDataType.OUTPUT``
            is treated as an input; the first ``OUTPUT`` item is the output.

    Returns:
        The 5-tuple ``(compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out)``.
        NOTE(review): the return annotation only lists three items although five are returned.
    """
    operands = [item for item in args if item.type != OperationDataType.OUTPUT]
    result = next(item for item in args if item.type == OperationDataType.OUTPUT)

    operand_tensors = [operand.data for operand in operands]
    result_tensors = [result.data]

    # compute cost: the flop count of aten.add.Tensor is used for the forward pass,
    # and the backward pass is set to twice the forward cost
    fwd_flops = flop_mapping[torch.ops.aten.add.Tensor](operand_tensors, result_tensors)
    bwd_flops = fwd_flops * 2
    compute_cost = TrainCycleItem(fwd=fwd_flops, bwd=bwd_flops, total=fwd_flops + bwd_flops)

    # memory cost: only operands of type PARAM count towards the parameter size
    param_size = activation_size([operand.data for operand in operands if operand.type == OperationDataType.PARAM])
    fwd_mem = MemoryCost(activation=activation_size(result.data), parameter=param_size)
    bwd_mem = MemoryCost(activation=activation_size(operand_tensors), parameter=param_size)
    total_mem = MemoryCost(
        activation=fwd_mem.activation + bwd_mem.activation,
        parameter=fwd_mem.parameter + bwd_mem.parameter,
    )
    memory_cost = TrainCycleItem(fwd=fwd_mem, bwd=bwd_mem, total=total_mem)

    # fwd_in, fwd_buffer, fwd_out
    fwd_in = []
    fwd_buffer = []
    fwd_out = [torch.zeros_like(result.data, device='meta')]

    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
from typing import Callable, Dict, List, Tuple, Union
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping
from colossalai.tensor.sharding_spec import ShardingSpec
from ..registry import meta_register
__all__ = ['convnd_meta_info']
@meta_register.register(torch.nn.Conv1d)
@meta_register.register(torch.nn.Conv2d)
@meta_register.register(torch.nn.Conv3d)
@meta_register.register(torch.nn.functional.conv1d)
@meta_register.register(torch.nn.functional.conv2d)
@meta_register.register(torch.nn.functional.conv3d)
def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
    """Meta information generator for 1d/2d/3d convolution modules and functions.

    In the traced aten graph (with or without bias) the nodes with compute cost are
    ``aten.convolution.default`` (forward, 9 positional arguments) and
    ``aten.convolution_backward.default`` (backward, 11 positional arguments, the last one being
    the output mask); the remaining nodes are ``zeros_like`` and ``detach``.

    Args:
        *args: Operation data items. ``args[0]`` is the input, ``args[1]`` a parameter, the first
            item of type ``OperationDataType.OUTPUT`` is the output and, when four items are
            given, ``args[3]`` is the second parameter. Weight and bias are told apart by rank:
            the one-dimensional tensor is the bias.

    Returns:
        The 5-tuple ``(compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out)``.
        NOTE(review): the return annotation only lists three items although five are returned.
    """
    input_tensor = args[0].data
    output_tensor = next(item for item in args if item.type == OperationDataType.OUTPUT).data
    param_tensors = [args[1].data, args[3].data] if len(args) == 4 else [args[1].data]

    # two parameters mean the convolution has a bias
    has_bias = len(param_tensors) > 1
    bias_tensor = None
    if not has_bias:
        weight_tensor = param_tensors[0]
    elif len(param_tensors[0].shape) == 1:
        # bias tensor's shape only has one dimension
        bias_tensor, weight_tensor = param_tensors
    else:
        weight_tensor, bias_tensor = param_tensors

    # positional arguments of convolution.default; slot 6 is the transpose indicator (False)
    fwd_args = [input_tensor, weight_tensor, bias_tensor, None, None, None, False, None, None]

    # positional arguments of convolution_backward.default; the last slot is the output mask
    bwd_args = [output_tensor, input_tensor, weight_tensor] + [None] * 7 + [[True, True, has_bias]]
    if has_bias:
        grad_outputs = (input_tensor, weight_tensor, bias_tensor)
    else:
        grad_outputs = (input_tensor, weight_tensor)

    # compute cost
    fwd_compute_cost = flop_mapping[torch.ops.aten.convolution.default](fwd_args, (output_tensor,))
    bwd_compute_cost = flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, grad_outputs)
    compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)

    # memory cost
    # TODO: use profiler to check conv temp memory
    # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
    if has_bias:
        param_list = [weight_tensor, bias_tensor]
        param_size = activation_size(param_list)
    else:
        param_list = [weight_tensor]
        param_size = activation_size(weight_tensor)
    fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]),
                                 parameter=param_size,
                                 temp=0,
                                 buffer=0)
    bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor] + param_list),
                                 parameter=param_size,
                                 temp=0,
                                 buffer=0)

    # the total only aggregates the activation and parameter fields
    total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
                            parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
    memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)

    # fwd_in, fwd_buffer, fwd_out
    fwd_in = [torch.zeros_like(input_tensor, device='meta')]
    fwd_buffer = []
    fwd_out = [torch.zeros_like(output_tensor, device='meta')]

    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
from typing import Callable, Dict, List, Tuple, Union
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping
from colossalai.tensor.sharding_spec import ShardingSpec
from ..registry import meta_register
__all__ = ['linear_meta_info']
@meta_register.register(torch.nn.functional.linear)
@meta_register.register(torch.nn.Linear)
def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
    """Meta information generator for ``torch.nn.Linear`` and ``torch.nn.functional.linear``.

    NOTE: currently the bias part is separated from the biased linear ops and its memory consumption
    is considered in the add metainfo generator, but the bias mechanism is kept here for future use.

    In the traced aten graph the nodes with compute cost are:
        * with bias: ``aten.addmm.default`` in forward; two ``aten.mm.default`` and one
          ``aten.sum.dim_IntList`` in backward.
        * without bias: ``aten.mm.default`` in forward; two ``aten.mm.default`` in backward.
    The remaining nodes are ``zeros_like``, ``t``, ``view`` and ``detach``.

    Args:
        *args: Operation data items. ``args[0]`` is the input, ``args[1]`` a parameter,
            ``args[2]`` the output and, when four items are given, ``args[3]`` the second
            parameter. Weight and bias are told apart by rank: the two-dimensional tensor is
            the weight.

    Returns:
        The 5-tuple ``(compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out)``.
        NOTE(review): the return annotation only lists three items although five are returned.
    """
    input_tensor = args[0].data
    output_tensor = args[2].data
    param_tensors = [args[1].data, args[3].data] if len(args) == 4 else [args[1].data]

    # collapse all leading dimensions so that input and output are 2D matrices
    if len(input_tensor.shape) > 2:
        input_tensor = input_tensor.view(-1, input_tensor.shape[-1])
    if len(output_tensor.shape) > 2:
        output_tensor = output_tensor.view(-1, output_tensor.shape[-1])

    # two parameters mean the linear op has a bias
    has_bias = len(param_tensors) > 1
    if not has_bias:
        weight_tensor = param_tensors[0]
    elif len(param_tensors[0].shape) == 2:
        weight_tensor, bias_tensor = param_tensors
    else:
        bias_tensor, weight_tensor = param_tensors

    # compute cost
    # forward is addmm (with bias) or mm (without bias)
    mm_flops = flop_mapping[torch.ops.aten.mm.default]
    weight_transposed = torch.transpose(weight_tensor, 0, 1)
    if has_bias:
        fwd_compute_cost = flop_mapping[torch.ops.aten.addmm.default](
            [bias_tensor, input_tensor, weight_transposed], (output_tensor,))
    else:
        fwd_compute_cost = mm_flops([input_tensor, weight_transposed], (output_tensor,))

    # backward is two mm, plus sum.dim_IntList for the bias gradient when there is a bias
    bwd_compute_cost = mm_flops([output_tensor, weight_tensor], (input_tensor,)) + \
        mm_flops([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,))
    if has_bias:
        bwd_compute_cost = bwd_compute_cost + \
            flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,))

    compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
                                  bwd=bwd_compute_cost,
                                  total=fwd_compute_cost + bwd_compute_cost)

    # memory cost
    # NOTE: Linear doesn't have buffer and temp in forward and backward phase
    # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
    if has_bias:
        param_list = [weight_tensor, bias_tensor]
        param_size = activation_size(param_list)
    else:
        param_list = [weight_tensor]
        param_size = activation_size(weight_tensor)

    # forward activation covers input and output; the parameter cost covers weight (and bias)
    fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]),
                                 parameter=param_size,
                                 temp=0,
                                 buffer=0)

    # backward activation covers input, weight (and bias); the parameter cost is the same as in forward
    bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor] + param_list),
                                 parameter=param_size,
                                 temp=0,
                                 buffer=0)

    # the total only aggregates the activation and parameter fields
    total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
                            parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
    memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)

    # fwd_in, fwd_buffer, fwd_out (built from the 2D views of input and output)
    fwd_in = [torch.zeros_like(input_tensor, device='meta')]
    fwd_buffer = []
    fwd_out = [torch.zeros_like(output_tensor, device='meta')]

    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
from typing import Callable, Dict, List, Tuple, Union
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping
from colossalai.tensor.sharding_spec import ShardingSpec
from ..registry import meta_register
__all__ = ['batchnormnd_meta_info']
@meta_register.register(torch.nn.BatchNorm1d)
@meta_register.register(torch.nn.BatchNorm2d)
@meta_register.register(torch.nn.BatchNorm3d)
def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
    """Meta information generator for ``BatchNorm1d``, ``BatchNorm2d`` and ``BatchNorm3d``.

    In the traced aten graph of BatchNorm2d the nodes with compute cost are
    ``aten.cudnn_batch_norm.default`` (forward) and ``aten.cudnn_batch_norm_backward.default``
    (backward); the remaining nodes are ``zeros_like`` and ``detach``.

    Args:
        *args: Operation data items. ``args[0]`` is the input and the first item of type
            ``OperationDataType.OUTPUT`` is the output; the items named ``weight``, ``bias``,
            ``running_mean``, ``running_var`` and ``num_batches_tracked`` are looked up by ``name``.

    Returns:
        The 5-tuple ``(compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out)``.
        NOTE(review): the return annotation only lists three items although five are returned.
    """

    def data_named(name):
        # data of the first operation data item carrying the given name
        return next(item for item in args if item.name == name).data

    input_tensor = args[0].data
    output_tensor = next(item for item in args if item.type == OperationDataType.OUTPUT).data
    weight_tensor = data_named("weight")
    bias_tensor = data_named("bias")
    mean_tensor = data_named("running_mean")
    var_tensor = data_named("running_var")
    num_batch = data_named("num_batches_tracked")

    # forward: inputs are input, weight, bias, running_mean, running_var plus module status args;
    # outputs are output, saved mean, saved inv std and num batches tracked
    # NOTE(review): True, 0.1 and 1e-5 presumably stand for training flag, momentum and eps -- confirm
    fwd_in_args = [input_tensor, weight_tensor, bias_tensor, mean_tensor, var_tensor, True, 0.1, 1e-5]
    fwd_out_args = [output_tensor, mean_tensor, var_tensor, num_batch]

    # backward: inputs are upstream grad, input, weight, running_mean, running_var, saved mean,
    # saved inv std plus module status args; outputs are input grad, weight grad and bias grad
    bwd_in_args = [
        output_tensor, output_tensor, weight_tensor, mean_tensor, var_tensor, mean_tensor, var_tensor, 1e-5, num_batch
    ]
    bwd_out_args = [input_tensor, weight_tensor, bias_tensor]

    # compute cost
    fwd_compute_cost = flop_mapping[torch.ops.aten.cudnn_batch_norm.default](fwd_in_args, fwd_out_args)
    bwd_compute_cost = flop_mapping[torch.ops.aten.cudnn_batch_norm_backward.default](bwd_in_args, bwd_out_args)
    compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)

    # memory cost
    param_size = activation_size([weight_tensor, bias_tensor])
    stats_size = activation_size([mean_tensor, var_tensor])

    # forward activation is input and output plus saved mean and saved inv std
    # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
    fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor, mean_tensor, var_tensor]),
                                 parameter=param_size,
                                 temp=0,
                                 buffer=stats_size)

    # the bwd memory cost is quite tricky here, BatchNorm will remove saved mean
    # and saved inv std during backward phase
    bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor]),
                                 parameter=param_size,
                                 temp=stats_size,
                                 buffer=stats_size)

    # the total only aggregates the activation and parameter fields
    # NOTE(review): temp and buffer are not carried into the total -- confirm this is intended
    total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
                            parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
    memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)

    # fwd_in, fwd_buffer, fwd_out
    fwd_in = [torch.zeros_like(input_tensor, device='meta')]
    fwd_buffer = [torch.zeros_like(mean_tensor, device='meta'), torch.zeros_like(var_tensor, device='meta')]
    fwd_out = [torch.zeros_like(output_tensor, device='meta')]

    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
from typing import List, Tuple
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping
from ..registry import meta_register
__all__ = ["avgpool_meta_info", "maxpool_meta_info"]
@meta_register.register(torch.nn.AdaptiveAvgPool1d)
@meta_register.register(torch.nn.AdaptiveAvgPool2d)
@meta_register.register(torch.nn.AdaptiveAvgPool3d)
@meta_register.register(torch.flatten)
def avgpool_meta_info(
    *args, **kwargs
) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
    """Meta info for AdaptiveAvgPool

    The aten graph of AdaptiveAvgPool is
    graph():
        %input_2 : [#users=2] = placeholder[target=placeholder](default=)
        %_adaptive_avg_pool2d_default : [#users=1] = call_function[target=torch.ops.aten._adaptive_avg_pool2d.default](args = (%input_2, [None, None]), kwargs = {})
        %zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%_adaptive_avg_pool2d_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
        %detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {})
        %_adaptive_avg_pool2d_backward_default : [#users=1] = call_function[target=torch.ops.aten._adaptive_avg_pool2d_backward.default](args = (%zeros_like_default, %detach_default), kwargs = {})
        %detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%_adaptive_avg_pool2d_backward_default,), kwargs = {})
        %detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {})

    Args:
        *args: operation data items. The ``.data`` of the first item is taken as the input
            tensor; the ``.data`` of the first item whose ``.type`` is
            ``OperationDataType.OUTPUT`` is taken as the output tensor.
        **kwargs: only ``inplace`` is read (defaults to ``False``). When it is truthy, both the
            forward and the backward memory cost are a default-constructed ``MemoryCost()``.

    Returns:
        Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
        compute cost, memory cost, forward inputs (always empty here), forward buffers
        (always empty here) and forward outputs (a meta-device tensor shaped like the output).
    """

    input_tensor = args[0].data
    output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
    is_inplace = kwargs.get("inplace", False)

    # construct forward args for flop mapping
    fwd_in_args = [input_tensor]
    fwd_out_args = [output_tensor]

    # construct backward args for flop mapping
    bwd_in_args = [output_tensor]
    bwd_out_args = [input_tensor]

    # calculate cost
    # the fwd op with compute cost is _adaptive_avg_pool2d.default
    # the bwd op with compute cost is _adaptive_avg_pool2d_backward.default
    # NOTE: the 2d aten ops are used for the flop estimate of every target registered above

    # calculate compute cost
    fwd_compute_cost = flop_mapping[torch.ops.aten._adaptive_avg_pool2d.default](fwd_in_args, fwd_out_args)
    bwd_compute_cost = flop_mapping[torch.ops.aten._adaptive_avg_pool2d_backward.default](bwd_in_args, bwd_out_args)
    compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)

    # calculate memory cost
    # forward is charged the output size and backward the input size, unless the op is inplace
    fwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=activation_size(output_tensor))
    bwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=activation_size(input_tensor))

    # total cost
    total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation)

    mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)

    # store fwd_in, fwd_buffer, fwd_out
    fwd_in = []
    fwd_buffer = []
    fwd_out = [torch.zeros_like(output_tensor, device='meta')]

    return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out
@meta_register.register(torch.nn.MaxPool1d)
@meta_register.register(torch.nn.MaxPool2d)
@meta_register.register(torch.nn.MaxPool3d)
def maxpool_meta_info(
    *args, **kwargs
) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
    """Meta info for MaxPool

    The aten graph of MaxPool is
    graph():
        %input_2 : [#users=2] = placeholder[target=placeholder](default=)
        %max_pool2d_with_indices_default : [#users=2] = call_function[target=torch.ops.aten.max_pool2d_with_indices.default](args = (%input_2, [None, None], [None, None]), kwargs = {})
        %zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%max_pool2d_with_indices_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
        %detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {})
        %detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%max_pool2d_with_indices_default,), kwargs = {})
        %max_pool2d_with_indices_backward_default : [#users=1] = call_function[target=torch.ops.aten.max_pool2d_with_indices_backward.default](args = (%zeros_like_default, %detach_default, [None, None], [None, None], [None, None], [None, None], None, %detach_default_1), kwargs = {})
        %detach_default_2 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%max_pool2d_with_indices_backward_default,), kwargs = {})
        %detach_default_3 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_2,), kwargs = {})

    Args:
        *args: operation data items. The ``.data`` of the first item whose ``.type`` is
            ``OperationDataType.ARG`` is taken as the input tensor; the ``.data`` of the first
            item whose ``.type`` is ``OperationDataType.OUTPUT`` is taken as the output tensor.
        **kwargs: accepted for a uniform meta-function signature; not read by this function.

    Returns:
        Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
        compute cost, memory cost, forward inputs (meta tensor shaped like the input), forward
        buffers (the int64 index matrix, shaped like the output) and forward outputs (meta
        tensor shaped like the output).
    """

    input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
    output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data

    # construct forward args for flop mapping
    fwd_in_args = [input_tensor]
    fwd_out_args = [output_tensor]

    # construct backward args for flop mapping
    bwd_in_args = [output_tensor]
    bwd_out_args = [input_tensor]

    # construct index matrix
    # it has the shape of the output and int64 dtype, and lives on the meta device
    index_matrix = torch.zeros_like(output_tensor, device="meta", dtype=torch.int64)

    # calculate cost
    # the fwd op with compute cost is max_pool2d_with_indices.default
    # the bwd op with compute cost is max_pool2d_with_indices_backward.default
    # NOTE: the 2d aten ops are used for the flop estimate of every target registered above

    # calculate compute cost
    fwd_compute_cost = flop_mapping[torch.ops.aten.max_pool2d_with_indices.default](fwd_in_args, fwd_out_args)
    bwd_compute_cost = flop_mapping[torch.ops.aten.max_pool2d_with_indices_backward.default](bwd_in_args, bwd_out_args)
    compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)

    # calculate memory cost
    # NOTE: the index matrix will be discarded in backward phase
    # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
    fwd_mem_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor, index_matrix]))

    # temp memory for backward is the index matrix to be discarded
    # the same amount is subtracted from the backward activation
    bwd_mem_cost = MemoryCost(activation=activation_size(input_tensor) - activation_size(index_matrix),
                              temp=activation_size(index_matrix))

    # total cost
    total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, temp=bwd_mem_cost.temp)

    mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)

    # store fwd_in, fwd_buffer, fwd_out
    fwd_in = [torch.zeros_like(input_tensor, device='meta')]
    fwd_buffer = [torch.zeros_like(index_matrix, device='meta')]
    fwd_out = [torch.zeros_like(output_tensor, device='meta')]

    return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out
from typing import Callable, List
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.tensor.sharding_spec import ShardingSpec
from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION
from .registry import meta_register
__all__ = ['MetaInfo']
class MetaInfo:
    """MetaInfo class

    This class is used to store meta info based on sharding strategy and the given
    target function.

    The meta info (``compute_cost``, ``memory_cost``, ``fwd_in``, ``fwd_buffer``, ``fwd_out``)
    is computed as soon as both a strategy and a target are available, and is recomputed
    whenever either of them is reassigned through its property setter.
    """

    def __init__(self, strategy: ShardingStrategy = None, target: Callable = None) -> None:
        # NOTE: the five attributes below are only declared here; they are assigned by
        # compute_metainfo(), so they do not exist until strategy and target are both set.

        # compute cost of forward and backward computation
        self.compute_cost: TrainCycleItem

        # compute memory cost of forward and backward phase
        self.memory_cost: TrainCycleItem

        # list of input tensors
        self.fwd_in: List[torch.Tensor]

        # list of buffer tensors
        self.fwd_buffer: List[torch.Tensor]

        # list of output tensors
        self.fwd_out: List[torch.Tensor]

        # sharding strategy
        self._strategy = strategy

        # target function
        self._target = target

        # compute metainfo if possible
        self._compute_metainfo_if_ready()

    def _compute_metainfo_if_ready(self) -> None:
        """Run compute_metainfo() once both the strategy and the target are set."""
        if self._strategy is not None and self._target is not None:
            self.compute_metainfo()

    @property
    def strategy(self) -> ShardingStrategy:
        return self._strategy

    @strategy.setter
    def strategy(self, strategy: ShardingStrategy) -> None:
        self._strategy = strategy
        self._compute_metainfo_if_ready()

    @property
    def target(self) -> Callable:
        return self._target

    @target.setter
    def target(self, target: Callable) -> None:
        self._target = target
        self._compute_metainfo_if_ready()

    def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> OperationData:
        """
        Compute sharded opdata based on the given data and sharding spec.

        Args:
            operation_data (OperationData): the original operation data; its name, type and
                logical shape are copied over unchanged.
            sharding_spec (ShardingSpec): the spec whose per-device sharded shape determines
                the shape of the new data tensor.

        Returns:
            OperationData: a new operation data whose ``data`` is a zero tensor on the meta
            device with the per-device sharded shape.
        """
        return OperationData(name=operation_data.name,
                             data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
                             type=operation_data.type,
                             logical_shape=operation_data.logical_shape)

    def compute_metainfo(self):
        """
        Compute meta info based on sharding strategy and the given target function.

        The meta function is looked up in ``meta_register``, first by the class of the target
        (module case) and otherwise by the target itself (function case). Its five results are
        stored on ``compute_cost``, ``memory_cost``, ``fwd_in``, ``fwd_buffer`` and ``fwd_out``.

        Raises:
            AssertionError: if neither the target's class nor the target is registered.
        """
        assert meta_register.has(self._target.__class__) or meta_register.has(self._target), \
            f"Meta info for {self._target} is not registered."

        if meta_register.has(self._target.__class__):
            # module
            meta_func = meta_register.get(self._target.__class__)

            # check whether the target in the list that we don't need to save activation
            save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION
        else:
            # function
            meta_func = meta_register.get(self._target)

            # check whether the target in the list that we don't need to save activation
            # a function is registered under itself, so it has to be looked up as itself here;
            # its __class__ is just the generic function type and would never match an entry
            save_fwd_in = self._target not in NO_SAVE_ACTIVATION

        # construct args for meta_func
        args = [self.compute_sharded_opdata(k, v) for k, v in self._strategy.sharding_specs.items()]

        # construct kwargs
        # NOTE(review): membership is tested with the target object itself; if INPLACE_MODULE
        # holds module classes rather than instances the first branch cannot match - confirm
        # against the constants module.
        if self.target in INPLACE_MODULE:
            kwargs = {'inplace': self.target.inplace}
        elif self.target in INPLACE_OPS:
            kwargs = {'inplace': True}
        else:
            kwargs = {'inplace': False}

        # compute metainfo with meta_func
        self.compute_cost, self.memory_cost, self.fwd_in, self.fwd_buffer, self.fwd_out = meta_func(*args, **kwargs)

        # process corner case for NO_SAVE_ACTIVATION
        if not save_fwd_in:
            self.fwd_in = []
__all__ = ['Registry']
class Registry:
    """A named lookup table that maps arbitrary hashable keys to functions.

    Entries are added with the ``register`` decorator factory and read back with
    ``get`` / ``has``.
    """

    def __init__(self, name):
        # human-readable name, only used in the error message of get()
        self.name = name
        # key -> registered function
        self.store = {}

    def register(self, source):
        """Return a decorator that stores the decorated function under ``source``.

        If ``source`` is a list or a tuple, the function is stored under every element of
        it; otherwise it is stored under ``source`` itself. The decorated function is
        returned unchanged, so several ``register`` decorators can be stacked.
        """

        def wrapper(func):
            # treat a single key as a one-element collection so both cases share one loop
            keys = source if isinstance(source, (list, tuple)) else (source,)
            for key in keys:
                self.store[key] = func
            return func

        return wrapper

    def get(self, source):
        """Return the function registered under ``source``; assert that it exists."""
        assert source in self.store, f'{source} not found in the {self.name} registry'
        return self.store[source]

    def has(self, source):
        """Tell whether something has been registered under ``source``."""
        return source in self.store
# Module-level Registry instance named 'meta'. Functions are added to it through the
# `meta_register.register(<key>)` decorator and retrieved with `meta_register.get(<key>)`.
meta_register = Registry('meta')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment