Commit 08f2920e authored by zhuwenwen
init colossalai, support dtk2304
parent da3f0934
#define PY_SSIZE_T_CLEAN
#include <Python.h>
long* PySequenceToLongArray(PyObject* pylist) {
if (!(pylist && PySequence_Check(pylist))) return NULL;
Py_ssize_t len = PySequence_Size(pylist);
long* result = (long*)calloc(len + 1, sizeof(long));
for (Py_ssize_t i = 0; i < len; ++i) {
PyObject* item = PySequence_GetItem(pylist, i);
result[i] = PyLong_AsLong(item);
Py_DECREF(item);
}
result[len] = 0;
return result;
}
double* PySequenceToDoubleArray(PyObject* pylist) {
if (!(pylist && PySequence_Check(pylist))) return NULL;
Py_ssize_t len = PySequence_Size(pylist);
double* result = (double*)calloc(len + 1, sizeof(double));
for (Py_ssize_t i = 0; i < len; ++i) {
PyObject* item = PySequence_GetItem(pylist, i);
result[i] = PyFloat_AsDouble(item);
Py_DECREF(item);
}
result[len] = 0;
return result;
}
long* getLongArray(PyObject* container, const char* attributeName) {
PyObject* sequence = PyObject_GetAttrString(container, attributeName);
long* result = PySequenceToLongArray(sequence);
Py_DECREF(sequence);
return result;
}
double* getDoubleArray(PyObject* container, const char* attributeName) {
PyObject* sequence = PyObject_GetAttrString(container, attributeName);
double* result = PySequenceToDoubleArray(sequence);
Py_DECREF(sequence);
return result;
}
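/*
 * Dynamic program behind the "persistent" algorithm below. For a memory budget
 * m and a sub-chain [i, l]:
 *   OPT(m, i, l) = min( min_{i < j <= l} sum(fw[i..j-1]) + OPT(m - cw[j], j, l) + OPT(m, i, j - 1),
 *                       OPT(m, i, i) + OPT(m - cbw[i+1], i + 1, l) )
 * WHAT(m, i, l) stores the j chosen in the first case, or -1 when the second
 * (chain) case is cheaper.
 */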
static PyObject* persistent_compute_table(PyObject* self, PyObject* args) {
PyObject* chain_param;
int mmax;
if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax)) return NULL;
double* fw = getDoubleArray(chain_param, "fweight");
if (!fw) return NULL;
double* bw = getDoubleArray(chain_param, "bweight");
if (!bw) return NULL;
long* cw = getLongArray(chain_param, "cweight");
if (!cw) return NULL;
long* cbw = getLongArray(chain_param, "cbweight");
if (!cbw) return NULL;
long* fwd_tmp = getLongArray(chain_param, "fwd_mem_tmp");
  if (!fwd_tmp) return NULL;
long* bwd_tmp = getLongArray(chain_param, "bwd_mem_tmp");
  if (!bwd_tmp) return NULL;
PyObject* chain_length_param = PyObject_GetAttrString(chain_param, "length");
if (!chain_length_param) return NULL;
long chain_length = PyLong_AsLong(chain_length_param);
Py_DECREF(chain_length_param);
// TODO: Can be optimized by only allocating memory for l >= i
// TODO: float / int instead of double / long ?
#define OPT(m, i, l) \
opt[(m) * (chain_length + 1) * (chain_length + 1) + \
(i) * (chain_length + 1) + (l)]
double* opt = (double*)calloc(
(mmax + 1) * (chain_length + 1) * (chain_length + 1), sizeof(double));
#define WHAT(m, i, l) \
what[(m) * (chain_length + 1) * (chain_length + 1) + \
(i) * (chain_length + 1) + (l)]
long* what = (long*)calloc(
(mmax + 1) * (chain_length + 1) * (chain_length + 1), sizeof(long));
for (long m = 0; m <= mmax; ++m)
for (long i = 0; i <= chain_length; ++i)
// TODO: Can be optimized to remove the IF by reordering loops
if ((m >= cw[i + 1] + cbw[i + 1] + bwd_tmp[i]) &&
(m >= cw[i + 1] + cbw[i + 1] + fwd_tmp[i]))
OPT(m, i, i) = fw[i] + bw[i];
else
OPT(m, i, i) = INFINITY;
for (long m = 0; m <= mmax; ++m)
for (long d = 1; d <= chain_length; ++d) {
for (long i = 0; i <= chain_length - d; ++i) {
long idx = i + d;
long mmin = cw[idx + 1] + cw[i + 1] + fwd_tmp[i];
if (idx > i + 1) {
long maxCostFWD = 0;
for (long j = i + 1; j < idx; j++) {
maxCostFWD = fmaxl(maxCostFWD, cw[j] + cw[j + 1] + fwd_tmp[j]);
}
mmin = fmaxl(mmin, cw[idx + 1] + maxCostFWD);
}
if ((m >= mmin)) {
long bestLeaf = -1;
double sumFw = 0;
double bestLeafCost = INFINITY;
/// sumFw + OPT(m-cw[i+1], i+1, l) + OPT(m, i, i); // Value for j =
/// i+1
for (long j = i + 1; j <= idx; ++j) {
sumFw += fw[j - 1];
if (m >= cw[j]) {
double cost = sumFw + OPT(m - cw[j], j, idx) + OPT(m, i, j - 1);
if (cost < bestLeafCost) {
bestLeafCost = cost;
bestLeaf = j;
}
}
}
double chainCost = INFINITY;
if (m >= cbw[i + 1])
chainCost = OPT(m, i, i) + OPT(m - cbw[i + 1], i + 1, idx);
if (bestLeafCost <= chainCost) {
OPT(m, i, idx) = bestLeafCost;
WHAT(m, i, idx) = bestLeaf;
} else {
OPT(m, i, idx) = chainCost;
WHAT(m, i, idx) = -1;
}
} else
OPT(m, i, idx) = INFINITY;
}
}
free(fw);
free(bw);
free(cw);
free(cbw);
free(fwd_tmp);
free(bwd_tmp);
PyObject* res_opt = PyList_New(mmax + 1);
PyObject* res_what = PyList_New(mmax + 1);
// Convert the result into Python world
for (long m = 0; m <= mmax; ++m) {
PyObject* res_opt_m = PyList_New(chain_length + 1);
PyList_SET_ITEM(res_opt, m, res_opt_m);
PyObject* res_what_m = PyList_New(chain_length + 1);
PyList_SET_ITEM(res_what, m, res_what_m);
for (long i = 0; i <= chain_length; ++i) {
PyObject* res_opt_m_i = PyDict_New();
PyList_SET_ITEM(res_opt_m, i, res_opt_m_i);
PyObject* res_what_m_i = PyDict_New();
PyList_SET_ITEM(res_what_m, i, res_what_m_i);
for (long l = i; l <= chain_length; ++l) {
PyObject* res_l = PyLong_FromLong(l);
PyObject* res_opt_m_i_l = PyFloat_FromDouble(OPT(m, i, l));
PyDict_SetItem(res_opt_m_i, res_l, res_opt_m_i_l);
Py_DECREF(res_opt_m_i_l);
PyObject* res_what_m_i_l;
long what_m_i_l = WHAT(m, i, l);
if (what_m_i_l < 0)
res_what_m_i_l = Py_BuildValue("(O)", Py_True);
else
res_what_m_i_l = Py_BuildValue("(Ol)", Py_False, what_m_i_l);
PyDict_SetItem(res_what_m_i, res_l, res_what_m_i_l);
Py_DECREF(res_what_m_i_l);
Py_DECREF(res_l);
}
}
}
free(opt);
free(what);
PyObject* result = PyTuple_Pack(2, res_opt, res_what);
Py_DECREF(res_opt);
Py_DECREF(res_what);
return result;
}
// long i = L - s, j = t - s, k = l - t
inline long floating_index_in_array(long m_factor, long m, long i, long j,
long k) {
return m * m_factor + (i * (i + 1) * (2 * i + 4)) / 12 + (i + 1) * j -
(j * (j - 1)) / 2 + k;
}
typedef struct {
long sp;
long r;
long tp;
} index_t;
static PyObject* floating_compute_table(PyObject* self, PyObject* args) {
PyObject* chain_param;
int mmax;
if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax)) return NULL;
double* fw = getDoubleArray(chain_param, "fweigth");
if (!fw) return NULL;
double* bw = getDoubleArray(chain_param, "bweigth");
if (!bw) return NULL;
long* cw = getLongArray(chain_param, "cweigth");
if (!cw) return NULL;
long* cbw = getLongArray(chain_param, "cbweigth");
if (!cbw) return NULL;
long* fwd_tmp = getLongArray(chain_param, "fwd_tmp");
if (!fwd_tmp) return NULL;
long* bwd_tmp = getLongArray(chain_param, "bwd_tmp");
if (!bwd_tmp) return NULL;
PyObject* chain_length_param = PyObject_GetAttrString(chain_param, "length");
if (!chain_length_param) return NULL;
long chain_length = PyLong_AsLong(chain_length_param);
Py_DECREF(chain_length_param);
const long m_factor =
(chain_length + 1) * (chain_length + 2) * (2 * chain_length + 6) / 12;
// Defined for 0 <= s <= t <= l <= chain_length, for all m
#undef OPT
#define OPT(m, s, t, l) \
opt[floating_index_in_array(m_factor, (m), chain_length - (s), (t) - (s), \
(l) - (t))]
double* opt = (double*)calloc((mmax + 1) * m_factor, sizeof(double));
#undef WHAT
#define WHAT(m, s, t, l) \
what[floating_index_in_array(m_factor, (m), chain_length - (s), (t) - (s), \
(l) - (t))]
index_t* what = (index_t*)calloc((mmax + 1) * m_factor, sizeof(index_t));
double* partialSumsFW = (double*)calloc(chain_length + 1, sizeof(double));
double total = 0;
for (long i = 0; i < chain_length; ++i) {
partialSumsFW[i] = total;
total += fw[i];
}
partialSumsFW[chain_length] = total;
for (long m = 0; m <= mmax; ++m)
for (long i = 0; i <= chain_length; ++i) {
// TODO: Can be optimized to remove the IF by reordering loops
if ((m >= cw[i] + cw[i + 1] + cbw[i + 1] + bwd_tmp[i]) &&
(m >= cw[i + 1] + cbw[i + 1] + fwd_tmp[i]))
OPT(m, i, i, i) = fw[i] + bw[i];
else
OPT(m, i, i, i) = INFINITY;
}
for (long m = 0; m <= mmax; ++m)
for (long d = 1; d <= chain_length; ++d) { // d = l - s
for (long s = 0; s <= chain_length - d; ++s) {
long l = s + d;
long memNullFirst = cw[l + 1] + cw[s + 1] + fwd_tmp[s];
long memNullSecond = 0;
for (long j = s + 1; j < l; ++j) {
long val = cw[j] + cw[j + 1] + fwd_tmp[j];
if (val > memNullSecond) memNullSecond = val;
}
for (long t = s; t <= l; ++t) {
double chainCost = INFINITY;
if ((s == t) && (m >= cw[l + 1] + cbw[s + 1] + fwd_tmp[s]) &&
(m >= cw[s] + cw[s + 1] + cbw[s + 1] + bwd_tmp[s])) {
chainCost = OPT(m, s, s, s) + OPT(m - cbw[s + 1], s + 1, s + 1, l);
}
double bestLeafCost = INFINITY;
index_t bestLeaf = {.sp = -1, .r = -1, .tp = -1};
if (m >= memNullFirst && m >= cw[l + 1] + memNullSecond) {
for (long r = s; r <= t; ++r)
if (cw[s] <= cw[r])
for (long tp = t + 1; tp <= l; ++tp)
for (long sp = r + 1; sp <= tp; ++sp) {
long mp = m - cw[r] + cw[s];
assert(mp >= 0);
if (mp >= cw[sp]) {
double value = partialSumsFW[sp] - partialSumsFW[s] +
OPT(mp - cw[sp], sp, tp, l) +
OPT(mp, r, t, tp - 1);
if (value < bestLeafCost) {
bestLeafCost = value;
bestLeaf.sp = sp;
bestLeaf.r = r;
bestLeaf.tp = tp;
}
}
}
}
if (bestLeaf.sp >= 0 && bestLeafCost <= chainCost) {
OPT(m, s, t, l) = bestLeafCost;
WHAT(m, s, t, l).sp = bestLeaf.sp;
WHAT(m, s, t, l).r = bestLeaf.r;
WHAT(m, s, t, l).tp = bestLeaf.tp;
} else {
OPT(m, s, t, l) = chainCost;
WHAT(m, s, t, l).sp = -1;
}
}
}
}
free(fw);
free(bw);
free(cw);
free(cbw);
free(fwd_tmp);
free(bwd_tmp);
PyObject* res_opt = PyList_New(mmax + 1);
PyObject* res_what = PyList_New(mmax + 1);
// Convert the result into Python world
PyObject* true_tuple = Py_BuildValue("(O)", Py_True);
for (long m = 0; m <= mmax; ++m) {
PyObject* res_opt_m = PyDict_New();
PyList_SET_ITEM(res_opt, m, res_opt_m);
PyObject* res_what_m = PyDict_New();
PyList_SET_ITEM(res_what, m, res_what_m);
for (long s = 0; s <= chain_length; ++s)
for (long t = s; t <= chain_length; ++t)
for (long l = t; l <= chain_length; ++l) {
PyObject* key = Py_BuildValue("(lll)", s, t, l);
PyObject* value_opt = PyFloat_FromDouble(OPT(m, s, t, l));
PyDict_SetItem(res_opt_m, key, value_opt);
PyObject* value_what = true_tuple;
index_t* idx_what = &WHAT(m, s, t, l);
if (idx_what->sp >= 0)
value_what = Py_BuildValue("(O(lll))", Py_False, idx_what->sp,
idx_what->r, idx_what->tp);
PyDict_SetItem(res_what_m, key, value_what);
if (value_what != true_tuple) Py_DECREF(value_what);
Py_DECREF(key);
Py_DECREF(value_opt);
}
}
Py_DECREF(true_tuple);
free(opt);
free(what);
PyObject* result = PyTuple_Pack(2, res_opt, res_what);
Py_DECREF(res_opt);
Py_DECREF(res_what);
return result;
}
static PyObject* griewank_heterogeneous_compute_table(PyObject* self,
PyObject* args) {
PyObject* chain_param;
int mmax;
if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax)) return NULL;
double* fw = getDoubleArray(chain_param, "fweigth");
if (!fw) return NULL;
double* bw = getDoubleArray(chain_param, "bweigth");
if (!bw) return NULL;
long* cw = getLongArray(chain_param, "cweigth");
if (!cw) return NULL;
long* cbw = getLongArray(chain_param, "cbweigth");
if (!cbw) return NULL;
PyObject* chain_length_param = PyObject_GetAttrString(chain_param, "length");
if (!chain_length_param) return NULL;
long chain_length = PyLong_AsLong(chain_length_param);
Py_DECREF(chain_length_param);
// TODO: Can be optimized by only allocating memory for l >= i
// TODO: float / int instead of double / long ?
#undef OPT
#define OPT(m, i, l) \
opt[(m) * (chain_length + 1) * (chain_length + 1) + \
(i) * (chain_length + 1) + (l)]
double* opt = (double*)calloc(
(mmax + 1) * (chain_length + 1) * (chain_length + 1), sizeof(double));
// Compute partial sums
double* sumfw = (double*)calloc(chain_length, sizeof(double));
double* sumbw = (double*)calloc(chain_length + 1, sizeof(double));
double* sumsumfw = (double*)calloc(chain_length, sizeof(double));
double total = 0;
for (long i = 0; i < chain_length; ++i) {
total += fw[i];
sumfw[i] = total;
}
total = 0;
for (long i = 0; i < chain_length + 1; ++i) {
total += bw[i];
sumbw[i] = total;
}
total = 0;
for (long i = 0; i < chain_length; ++i) {
total += sumfw[i];
sumsumfw[i] = total;
}
for (long m = 0; m <= mmax; ++m)
for (long i = 0; i <= chain_length; ++i) {
// TODO: Can be optimized to remove the IF by reordering loops
if ((m >= cbw[i]) && (m >= cw[i] + cbw[i + 1]))
OPT(m, i, i) = bw[i];
else
OPT(m, i, i) = INFINITY;
if (i < chain_length) {
long maxC = fmaxl(cw[i], cw[i + 1]);
long maxCB = fmaxl(cbw[i + 1], cbw[i + 2] + maxC);
if ((m >= cbw[i]) && (m >= cw[i] + maxCB))
OPT(m, i, i + 1) = fw[i] + bw[i] + bw[i + 1];
else
OPT(m, i, i + 1) = INFINITY;
}
}
for (long m = 0; m <= mmax; ++m)
for (long i = 0; i + 2 <= chain_length; ++i) {
long mminCst = fmaxl(cbw[i], cbw[i + 1] + cw[i]);
long maxCW_il = fmax(fmax(cw[i], cw[i + 1]), cw[i + 2]);
long maxCostFWD = cw[i] + cbw[i + 2] + maxCW_il;
for (long l = i + 2; l <= chain_length; ++l) {
maxCW_il = fmax(maxCW_il, cw[l + 1]);
maxCostFWD = fmaxl(maxCostFWD, cw[i] + cw[l + 1] + maxCW_il);
long mmin = fmaxl(mminCst, maxCostFWD);
if ((m >= mmin)) {
double noCheckpointCost = sumbw[l] - (i > 0 ? sumbw[i - 1] : 0);
noCheckpointCost +=
sumsumfw[l - 1] -
(i > 0 ? sumsumfw[i - 1] + (l - i) * sumfw[i - 1] : 0);
double valueCost = INFINITY;
if (m >= cw[i]) {
double sumFwds = 0;
for (long j = i + 1; j < l; ++j) {
sumFwds += fw[j - 1];
valueCost = fmin(
valueCost, sumFwds + OPT(m - cw[i], j, l) + OPT(m, i, j - 1));
}
}
OPT(m, i, l) = fmin(noCheckpointCost, valueCost);
} else
OPT(m, i, l) = INFINITY;
}
}
free(sumfw);
free(sumbw);
free(sumsumfw);
free(fw);
free(bw);
free(cw);
free(cbw);
PyObject* res_opt = PyList_New(mmax + 1);
// Convert the result into Python world
for (long m = 0; m <= mmax; ++m) {
PyObject* res_opt_m = PyList_New(chain_length + 1);
PyList_SET_ITEM(res_opt, m, res_opt_m);
for (long i = 0; i <= chain_length; ++i) {
PyObject* res_opt_m_i = PyDict_New();
PyList_SET_ITEM(res_opt_m, i, res_opt_m_i);
for (long l = i; l <= chain_length; ++l) {
PyObject* res_l = PyLong_FromLong(l - i);
PyObject* res_opt_m_i_l = PyFloat_FromDouble(OPT(m, i, l));
PyDict_SetItem(res_opt_m_i, res_l, res_opt_m_i_l);
Py_DECREF(res_opt_m_i_l);
Py_DECREF(res_l);
}
}
}
free(opt);
return res_opt;
}
static PyMethodDef dynamic_programs_methods[] = {
{"persistent_compute_table", persistent_compute_table, METH_VARARGS,
"Compute the optimal table with the persistent algorithm."},
{"floating_compute_table", floating_compute_table, METH_VARARGS,
"Compute the optimal table with the floating algorithm."},
{"griewank_heterogeneous_compute_table",
griewank_heterogeneous_compute_table, METH_VARARGS,
"Compute the optimal table for the Griewank Heterogeneous Model."},
{NULL, NULL, 0, NULL} /* Sentinel */
};
static struct PyModuleDef dynamic_programs_module = {
PyModuleDef_HEAD_INIT, "dynamic_programs_C_version", /* name of module */
NULL, /* module documentation, may be NULL */
-1, /* size of per-interpreter state of the module,
or -1 if the module keeps state in global variables. */
dynamic_programs_methods};
PyMODINIT_FUNC PyInit_dynamic_programs_C_version(void) {
return PyModule_Create(&dynamic_programs_module);
}
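# ------------------------------------------------------------------------------
# Usage sketch for the extension above (assumes it has been built and is
# importable as `dynamic_programs_C_version`; the toy numbers are illustrative).
# The attribute names below are exactly the ones persistent_compute_table reads.
import dynamic_programs_C_version as dp

class ToyChain:
    def __init__(self):
        self.fweight = [1.0, 1.0]          # length L
        self.bweight = [1.0, 1.0, 1.0]     # length L + 1
        self.cweight = [1, 1, 1]           # length L + 1
        self.cbweight = [1, 1, 1]          # length L + 1
        self.fwd_mem_tmp = [0, 0]          # length L
        self.bwd_mem_tmp = [0, 0, 0]       # length L + 1
        self.length = len(self.fweight)

opt, what = dp.persistent_compute_table(ToyChain(), 10)    # mmax = 10
# opt[m][i][l]: optimal cost of segment [i, l] under memory budget m.
# what[m][i][l]: (True,) when the chain decision was taken, (False, j) when a
# checkpoint is placed at index j (see the conversion loop above).
print(opt[10][0][2], what[10][0][2])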
from typing import List, Any
from torch.fx import GraphModule, Node
from colossalai.fx.profiler import is_inplace
# Common nodes are the kind of nodes that can be seen as attributes and remain
# unchanged throughout the whole model. They are used several times by different
# blocks of the model, which makes it hard to linearize the graph when we
# encounter them. We let users annotate some of the inputs as common nodes, such
# as the attention mask, and the following are some of the ops that can also be
# treated as common nodes. With common node propagation we can find the "real"
# common nodes (e.g. the actual attention mask used in BERT and GPT). The rule
# is simple: if all of a node's parents are common nodes, or its op belongs to
# the operations below, we treat that node as a newly born common node.
# List of target names that can be treated as common nodes
COPS = ["getattr", "getitem", "size"]
def _is_cop(target: Any) -> bool:
"""Check if an op could be seen as common node
Args:
target (Any): node target
Returns:
bool
"""
if isinstance(target, str):
return target in COPS
else:
return target.__name__ in COPS
def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
"""Linearizing the graph
Args:
gm (GraphModule): GraphModule derived by tracing
cnode (List[str], optional): common node List, should be the subset of input. Default to None.
Returns:
List[List[Node]]: List of list, each inside list of Node presents
the actual 'node' in linearized manner.
Remarks:
We merge the inplace ops into the previous node.
"""
def _is_sink() -> bool:
"""Check if we can free all dependencies
Returns:
bool
"""
return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users))
# make sure that item in cnode is valid
if cnode:
for name in cnode:
try:
assert next(node for node in gm.graph.nodes if node.name == name).op == "placeholder", \
f"common node {name} is not an input of the model"
except StopIteration:
raise ValueError(f"common node name {name} not in graph")
else:
cnode = []
deps = {}
linearized_nodes = []
region = []
for n in gm.graph.nodes:
if n.op != "placeholder" and n.op != "output":
for n_par in n._input_nodes:
if n_par.op != "placeholder" and n_par.name not in cnode:
deps[n_par] -= 1
region.append(n)
# if the node could free all dependencies in graph
# we could begin a new node
if _is_sink():
linearized_nodes.append(region)
region = []
# propagate common node attr if possible
if len(n._input_nodes) == len([node for node in n._input_nodes if node.name in cnode]) or _is_cop(n.target):
cnode.append(n.name)
else:
deps[n] = len([user for user in n.users if user.op != "output"])
return linearized_nodes
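# Usage sketch (illustrative): linearize a small traced model into regions.
# Each returned region is a list of fx Nodes that together form one 'node' of
# the linearized graph.
if __name__ == "__main__":
    import torch
    from torch.fx import symbolic_trace

    demo = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 4))
    demo_gm = symbolic_trace(demo)
    for idx, region in enumerate(linearize(demo_gm)):
        print(idx, [node.name for node in region])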
import math
def _discretize(mem_unit, values):
return [math.ceil(value / mem_unit) for value in values]
class Chain:
def __init__(self, fw, bw, cw, cbw, ftmp, btmp, check=True):
self.fweight = fw
self.bweight = bw
self.cweight = cw
self.cbweight = cbw
self.fwd_mem_tmp = ftmp
self.bwd_mem_tmp = btmp
self.length = len(fw)
if check and not self.check_lengths():
raise AttributeError("In Chain, input lists do not have consistent lengths")
def check_lengths(self):
return ((len(self.fweight) == self.length) and (len(self.bweight) == self.length + 1)
and (len(self.cweight) == self.length + 1) and (len(self.fwd_mem_tmp) == self.length)
and (len(self.bwd_mem_tmp) == self.length + 1) and (len(self.cbweight) == self.length + 1))
def __repr__(self):
chain_list = []
for i in range(self.length):
chain_list.append((self.fweight[i], self.bweight[i], self.cweight[i], self.cbweight[i], self.fwd_mem_tmp[i],
self.bwd_mem_tmp[i]))
i = self.length
chain_list.append((None, self.bweight[i], self.cweight[i], self.cbweight[i], None, self.bwd_mem_tmp[i]))
return chain_list.__repr__()
def _discretize(self, mem_unit):
self.cweight = _discretize(mem_unit, self.cweight)
self.cbweight = _discretize(mem_unit, self.cbweight)
self.fwd_mem_tmp = _discretize(mem_unit, self.fwd_mem_tmp)
self.bwd_mem_tmp = _discretize(mem_unit, self.bwd_mem_tmp)
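# Sketch: constructing a Chain by hand (illustrative numbers only). fw and ftmp
# have `length` entries, while bw / cw / cbw / btmp have `length + 1` entries,
# which is exactly what check_lengths() above enforces.
def _example_chain() -> Chain:
    toy = Chain(fw=[1.0, 1.0, 1.0],
                bw=[2.0, 2.0, 2.0, 2.0],
                cw=[1024, 2048, 2048, 1024],
                cbw=[1024, 1024, 1024, 1024],
                ftmp=[0, 0, 0],
                btmp=[0, 0, 0, 0])
    toy._discretize(mem_unit=1024)    # express the memory columns in 1 KB units
    return toy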
class Operation:
def shift(self, value):
if type(self.index) is tuple:
self.index = tuple(x + value for x in self.index)
else:
self.index += value
class Offload(Operation):
def __init__(self, index, has_bar=False) -> None:
super().__init__()
self.index = index
self.name = "Off"
self.has_bar = has_bar
if self.has_bar:
self.name += "wBar"
def __repr__(self):
return f"{self.name}_{self.index}"
class Prefetch(Operation):
def __init__(self, index, has_bar=False) -> None:
super().__init__()
self.index = index
self.name = "Pre"
self.has_bar = has_bar
if self.has_bar:
self.name += "wBar"
def __repr__(self):
return f"{self.name}_{self.index}"
class Forward(Operation):
def __init__(self, index):
self.index = index
self.name = "F"
def __repr__(self):
return "{n}_{i}".format(n=self.name, i=self.index)
def cost(self, chain: Chain):
if chain is not None:
return chain.fweight[self.index]
else:
return 1
class ForwardEnable(Forward):
def __init__(self, index):
super().__init__(index)
self.name = "Fe"
class ForwardNograd(Forward):
def __init__(self, index):
super().__init__(index)
self.name = "Fn"
class ForwardCheck(Forward):
def __init__(self, index):
super().__init__(index)
self.name = "CF"
class Forwards(Operation):
def __init__(self, start, end):
self.index = (start, end)
def __repr__(self):
return "F_{i}->{j}".format(i=self.index[0], j=self.index[1])
def cost(self, chain: Chain):
if chain is not None:
return sum(chain.fweight[self.index[0]:self.index[1] + 1])
else:
return (self.index[1] - self.index[0] + 1)
def isForward(op):
return type(op) is Forward or type(op) is Forwards
class Backward(Operation):
def __init__(self, index):
self.index = index
def __repr__(self):
return "B_{i}".format(i=self.index)
def cost(self, chain: Chain):
if chain is not None:
return chain.bweight[self.index]
else:
return 1
class Loss(Operation):
def __init__(self):
pass
def __repr__(self):
return "L"
def cost(self, chain):
return 0
class MemoryAccess(Operation):
def __init__(self, index):
self.index = index
def __repr__(self):
return "{n}_{i}".format(n=self.name, i=self.index)
def cost(self, chain: Chain):
return 0
class WriteMemory(MemoryAccess):
def __init__(self, index):
super().__init__(index)
self.name = "WM"
class ReadMemory(MemoryAccess):
def __init__(self, index):
super().__init__(index)
self.name = "RM"
class DiscardMemory(MemoryAccess):
def __init__(self, index):
super().__init__(index)
self.name = "DM"
class Function:
def __init__(self, name, *args):
self.name = name
self.args = args
self.str_args = ','.join(str(v) for v in self.args)
def __repr__(self):
return "{n}({args})".format(n=self.name, args=self.str_args)
class Sequence:
def __init__(self, function):
        self.sequence = []    # list of Operation and Sequence
        self.function = function    # Description of the function (name and parameters)
def __repr__(self):
return repr(self.list_operations())
def list_operations(self):
op_list = []
for x in self.sequence:
if isinstance(x, Operation):
op_list.append(x)
else:
assert isinstance(x, Sequence)
op_list += x.list_operations()
return op_list
def insert(self, operation):
self.sequence.append(operation)
def remove(self, operation_index):
del self.sequence[operation_index]
def insert_sequence(self, sequence):
self.sequence.append(sequence)
def shift(self, value):
for x in self.sequence:
x.shift(value)
return self
def remove_useless_write(self):
if self.sequence:
if isinstance(self.sequence[0], WriteMemory):
self.remove(0)
return self
def get_makespan(self, chain):
return sum(op.cost(chain) for op in self.list_operations())
def without_suffix(self):
ops = self.list_operations()
end_of_first_phase = [i for i in range(len(ops)) if type(ops[i]) is Loss][0]
try:
last_idx = max(i for i in range(end_of_first_phase) if not type(ops[i]) is ForwardEnable)
except ValueError:
last_idx = -1
if last_idx == end_of_first_phase - 1:
return (self, None)
        chain_length = ops[end_of_first_phase - 1].index    # Assumes the first phase finishes with Forward_L
        start_of_fwd_enable_chain = ops[last_idx + 1].index    # and the second phase starts with B_L; fine in practice
result = Sequence(Function("Strip", self.function.name, *self.function.args, start_of_fwd_enable_chain))
for i in range(last_idx + 1):
result.insert(ops[i])
result.insert(Loss())
for i in range(chain_length, start_of_fwd_enable_chain - 1, -1):
position = end_of_first_phase + 1 + (chain_length - i)
assert type(ops[position]) is Backward
assert ops[position].index == i
for i in range(end_of_first_phase + 1 + 1 + chain_length - start_of_fwd_enable_chain, len(ops)):
result.insert(ops[i])
return (result, start_of_fwd_enable_chain)
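# Sketch: building a small schedule by hand and measuring its makespan. With
# chain=None every Forward/Backward costs 1 and Loss costs 0, so the sequence
# below has makespan 7; passing a Chain (e.g. the _example_chain sketch above)
# would weight each step by its measured cost instead.
def _example_sequence() -> Sequence:
    seq = Sequence(Function("Demo", 3))
    seq.insert(ForwardEnable(0))
    seq.insert(Forward(1))
    seq.insert(Forward(2))
    seq.insert(Loss())
    seq.insert(Backward(3))
    seq.insert(Backward(2))
    seq.insert(Backward(1))
    seq.insert(Backward(0))
    assert seq.get_makespan(None) == 7
    return seq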
from dataclasses import asdict
from typing import Any, Dict, List, NamedTuple, Optional, Tuple
import torch
import torch.fx
from torch.fx.node import Argument, Node, Target
from torch.utils._pytree import tree_flatten
from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import GraphInfo, profile_function, profile_method, profile_module
@compatibility(is_backward_compatible=True)
class ConcreteInfoProp(torch.fx.Interpreter):
"""
Execute an FX graph Node-by-Node with concrete tensor and record the memory
usage, execution time of forward and backward, and type of the result into
the corresponding node.
Usage:
BATCH_SIZE = 2
DIM_IN = 4
DIM_HIDDEN = 16
DIM_OUT = 16
model = torch.nn.Sequential(
torch.nn.Linear(DIM_IN, DIM_HIDDEN),
torch.nn.Linear(DIM_HIDDEN, DIM_OUT),
).cuda()
input_sample = torch.rand(BATCH_SIZE, DIM_IN, device="cuda")
gm = symbolic_trace(model)
interp = ConcreteInfoProp(gm)
interp.run(input_sample)
print(interp.summary(unit='kb'))
output of above code is
Op type Op Forward time Backward time SAVE_FWD_IN FWD_OUT FWD_TMP BWD_OUT BWD_TMP
----------- ------- ----------------------- ------------------------ ------------- --------- --------- --------- ---------
placeholder input_1 0.0 s 0.0 s False 0.00 KB 0.00 KB 0.00 KB 0.00 KB
call_module _0 0.0003993511199951172 s 0.00706791877746582 s False 0.50 KB 0.00 KB 0.03 KB 0.66 KB
call_module _1 6.29425048828125e-05 s 0.00018286705017089844 s False 0.50 KB 0.00 KB 0.12 KB 0.81 KB
output output 0.0 s 0.0 s True 0.00 KB 0.00 KB 0.00 KB 0.00 KB
Args:
module (GraphModule): The module to be executed
"""
_is_proped: bool = False
def run(self, *args, initial_env: Optional[Dict[Node, Any]] = None, enable_io_processing: bool = True) -> Any:
"""Customized run for ConcreteInfoProp
We need to store the device in self.device
Args:
*args: The arguments to the Module to run, in positional order
initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution.
This is a dict mapping `Node` to any value. This can be used, for example, to
pre-populate results for certain `Nodes` so as to do only partial evaluation within
the interpreter.
enable_io_processing (bool): If true, we process the inputs and outputs with graph's process_inputs and
process_outputs function first before using them.
Returns:
Any: The value returned from executing the Module
"""
flatten_args, _ = tree_flatten(args)
self.device = next(item for item in flatten_args if hasattr(item, "device")).device
        return super().run(*args, initial_env=initial_env, enable_io_processing=enable_io_processing)
@compatibility(is_backward_compatible=True)
def run_node(self, n: Node) -> Any:
"""
Run a specific node ``n`` and return the result.
Calls into placeholder, get_attr, call_function,
call_method, call_module, or output depending
on ``node.op``
Args:
n (Node): The Node to execute
Returns:
Any: The result of executing ``n``
"""
self._is_proped = True
result, meta_info = super().run_node(n)
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
# TODO: the attribute node_size should be removed in the future
setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
n.meta['type'] = type(result)
# retain the autograd graph
for param in self.module.parameters():
param.grad = None
return result
# Main Node running APIs
@compatibility(is_backward_compatible=True)
def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``placeholder`` node. Note that this is stateful:
``Interpreter`` maintains an internal iterator over
arguments passed to ``run`` and this method returns
next() on that iterator.
Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation
Returns:
result (Any): The argument value that was retrieved
meta_info (MetaInfo): The memory cost and forward & backward time.
"""
return super().placeholder(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True)
def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``get_attr`` node. Will retrieve an attribute
value from the ``Module`` hierarchy of ``self.module``.
Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation
Return:
result (Any): The argument value that was retrieved
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
return super().get_attr(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True)
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_function`` node with meta tensor and return the result and its meta profile.
Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation
Return
result (Any): The argument value that was retrieved
meta_info (MetaInfo): The memory cost and forward & backward time.
"""
assert not isinstance(target, str)
return profile_function(target, self.device)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node with meta tensor and return the result and its meta profile.
Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation
Return
result (Any): The argument value that was retrieved
meta_info (MetaInfo): The memory cost and forward & backward time.
"""
return profile_method(target, self.device)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_module`` node with meta tensor and return the result and its meta profile.
Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation
Return
result (Any): The argument value that was retrieved
meta_info (MetaInfo): The memory cost and forward & backward time.
"""
# Retrieve executed args and kwargs values from the environment
# Execute the method and return the result
assert isinstance(target, str)
submod = self.fetch_attr(target)
return profile_module(submod, self.device)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute an ``output`` node. This really just retrieves
the value referenced by the ``output`` node and returns it.
Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation
Return:
result (Any): The argument value that was retrieved
meta_info (MetaInfo): The memory cost and forward & backward time.
"""
return args[0], GraphInfo(save_fwd_in=True)
def propagate(self, *args):
"""
Run `module` via interpretation and return the result and
record the shape and type of each node.
Args:
*args (Tensor): the sample input.
Returns:
Any: The value returned from executing the Module
"""
return super().run(*args)
def summary(self, unit: str = 'MB') -> str:
"""
        Summarizes the memory and timing statistics of the `GraphModule` in
tabular format. Note that this API requires the ``tabulate`` module
to be installed.
"""
# https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py
try:
from tabulate import tabulate
except ImportError:
print("`summary` relies on the library `tabulate`, "
"which could not be found on this machine. Run `pip "
"install tabulate` to install the library.")
assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`."
# Build up a list of summary information for each node
node_summaries: List[List[Any]] = []
def mem_repr(mem: int) -> str:
unit_divisor_map = {
'kb': 1024,
'mb': 1024**2,
'gb': 1024**3,
'tb': 1024**4,
}
return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}"
def time_repr(time: float):
return f"{time:,} s"
for node in self.module.graph.nodes:
node: Node
node_summaries.append([
node.op,
str(node),
time_repr(node.meta['fwd_time']),
time_repr(node.meta['bwd_time']),
node.meta['save_fwd_in'],
mem_repr(node.meta['fwd_mem_out']),
mem_repr(node.meta['fwd_mem_tmp']),
mem_repr(node.meta['bwd_mem_out']),
mem_repr(node.meta['bwd_mem_tmp']),
])
# Use the ``tabulate`` library to create a well-formatted table
# presenting our summary information
headers: List[str] = [
'Op type',
'Op',
'Forward time',
'Backward time',
'SAVE_FWD_IN',
'FWD_OUT',
'FWD_TMP',
'BWD_OUT',
'BWD_TMP',
]
return tabulate(node_summaries, headers=headers, stralign='right')
import torch
from typing import List
from torch.fx import symbolic_trace
from torch.fx.node import Node
from colossalai.fx.passes.split_module import split_module
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
import builtins
import operator
from copy import deepcopy
def apply(*args, **kwargs):
shape_consistency_manager = ShapeConsistencyManager()
return shape_consistency_manager.apply(*args, **kwargs)
def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh):
mod_graph = gm.graph
nodes = tuple(mod_graph.nodes)
# the dict to get origin sharding spec of node
origin_node_sharding_spec_dict = {}
for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)):
strategies_vector = node.strategies_vector
setattr(node, 'best_strategy', strategies_vector[strategy_index])
setattr(node, 'sharding_spec', strategies_vector[strategy_index].output_sharding_spec)
origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].output_sharding_spec
# apply the sharding spec of parameters
for node in nodes:
if node.op == 'call_module':
target_module = node.graph.owning_module.get_submodule(node.target)
origin_sharding_spec = ShardingSpec(device_mesh, target_module.weight.shape, {})
setattr(target_module.weight, 'sharding_spec', origin_sharding_spec)
target_weight_sharding_spec = node.best_strategy.input_shardings[1]
target_module.weight.data = target_module.weight.data.permute((1, 0, 2, 3))
apply(target_module.weight, target_weight_sharding_spec)
target_module.weight.data = target_module.weight.data.permute((1, 0, 2, 3))
# the dict to get input sharding specs of user node
sharding_spec_convert_dict = {}
for index, node in enumerate(nodes):
target_sharding_specs = []
for user_node in node.strategies_vector.successor_nodes:
node_index = user_node.strategies_vector.predecessor_nodes.index(node)
target_sharding_spec = user_node.best_strategy.input_shardings[node_index]
target_sharding_specs.append(target_sharding_spec)
sharding_spec_convert_dict[index] = target_sharding_specs
# add above dicts into graph
for node in nodes:
if node.op != 'placeholder':
with mod_graph.inserting_before(node):
input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict')
origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict')
break
return sharding_spec_convert_dict, origin_node_sharding_spec_dict
def shape_consistency_pass(gm: torch.fx.GraphModule):
mod_graph = gm.graph
nodes = tuple(mod_graph.nodes)
input_dict_node = None
origin_dict_node = None
# mapping the node into the origin graph index
node_to_index_dict = {}
index = 0
for node in nodes:
if node.target == 'sharding_spec_convert_dict':
input_dict_node = node
continue
if node.target == 'origin_node_sharding_spec_dict':
origin_dict_node = node
continue
if not hasattr(node, 'best_strategy'):
continue
node_to_index_dict[node] = index
index += 1
assert input_dict_node is not None
# add shape consistency apply function into graph
for node in nodes:
if not hasattr(node, 'best_strategy'):
continue
with mod_graph.inserting_after(node):
origin_spec_node = mod_graph.create_node('call_function',
operator.getitem,
args=(origin_dict_node, node_to_index_dict[node]))
with mod_graph.inserting_after(origin_spec_node):
set_sharding_spec_node = mod_graph.create_node('call_function',
builtins.setattr,
args=(node, 'sharding_spec', origin_spec_node))
for user_node in node.strategies_vector.successor_nodes:
node_index = user_node.strategies_vector.predecessor_nodes.index(node)
with mod_graph.inserting_before(user_node):
input_specs_node = mod_graph.create_node('call_function',
operator.getitem,
args=(input_dict_node, node_to_index_dict[node]))
with mod_graph.inserting_before(user_node):
sharding_spec_node = mod_graph.create_node('call_function',
operator.getitem,
args=(input_specs_node, node_index))
with mod_graph.inserting_before(user_node):
shape_consistency_node = mod_graph.create_node('call_function', apply, args=(node, sharding_spec_node))
return gm
from dataclasses import asdict
from typing import Any, Dict, List, NamedTuple, Tuple
import torch
import torch.fx
from torch.fx.node import Argument, Node, Target
from torch.utils._pytree import tree_map
from colossalai.fx._compatibility import compatibility, is_compatible_with_meta
from colossalai.fx.profiler import (
GraphInfo,
activation_size,
calculate_fwd_in,
calculate_fwd_out,
calculate_fwd_tmp,
profile_function,
profile_method,
profile_module,
)
@compatibility(is_backward_compatible=True)
class TensorMetadata(NamedTuple):
# TensorMetadata is a structure containing pertinent information
# about a tensor within a PyTorch program.
shape: torch.Size
dtype: torch.dtype
requires_grad: bool
stride: Tuple[int]
numel: int
is_tensor: bool
# TODO: we can add a list of sharding spec here, and record the sharding
# behaviour by appending sharding spec into list.
def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata:
"""
Extract a TensorMetadata NamedTuple describing `result`.
"""
shape = result.shape
dtype = result.dtype
requires_grad = result.requires_grad
stride = result.stride()
numel = result.numel()
is_tensor = True
return TensorMetadata(shape, dtype, requires_grad, stride, numel, is_tensor)
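# Quick check (sketch): for t = torch.rand(2, 3), _extract_tensor_metadata(t) returns
# TensorMetadata(shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=False,
#                stride=(3, 1), numel=6, is_tensor=True).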
@compatibility(is_backward_compatible=True)
class MetaInfoProp(torch.fx.Interpreter):
"""
Execute an FX graph Node-by-Node with meta tensor and
record the memory usage, FLOPs, and type of the result
into the corresponding node.
Usage:
BATCH_SIZE = 2
DIM_IN = 4
DIM_HIDDEN = 16
DIM_OUT = 16
model = torch.nn.Sequential(
torch.nn.Linear(DIM_IN, DIM_HIDDEN),
torch.nn.Linear(DIM_HIDDEN, DIM_OUT),
)
input_sample = torch.rand(BATCH_SIZE, DIM_IN)
gm = symbolic_trace(model)
interp = MetaInfoProp(gm)
interp.run(input_sample)
        print(interp.summary(unit='kb')) # don't panic if some statistics are 0.00 MB
# output of above code is
Op type Op Forward FLOPs Backward FLOPs FWD_OUT FWD_TMP BWD_OUT BWD_TMP
----------- ------- --------------- ---------------- --------- --------- --------- ---------
placeholder input_1 0 FLOPs 0 FLOPs 0.00 KB 0.00 KB 0.00 KB 0.00 KB
call_module _0 128 FLOPs 288 FLOPs 0.12 KB 0.00 KB 0.34 KB 0.00 KB
call_module _1 512 FLOPs 1,056 FLOPs 0.12 KB 0.00 KB 1.19 KB 0.00 KB
output output 0 FLOPs 0 FLOPs 0.00 KB 0.00 KB 0.00 KB 0.00 KB
Args:
module (GraphModule): The module to be executed
"""
_is_proped: bool = False
@compatibility(is_backward_compatible=True)
def run_node(self, n: Node) -> Any:
"""
Run a specific node ``n`` and return the result.
Calls into placeholder, get_attr, call_function,
call_method, call_module, or output depending
on ``node.op``
Args:
n (Node): The Node to execute
Returns:
Any: The result of executing ``n``
"""
self._is_proped = True
result, meta_info = super().run_node(n)
def extract_tensor_meta(obj):
if isinstance(obj, torch.Tensor):
return _extract_tensor_metadata(obj)
else:
return TensorMetadata(None, None, False, None, 0, False)
tensor_meta = tree_map(extract_tensor_meta, result)
n.meta['tensor_meta'] = tensor_meta
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
# TODO: the attribute node_size should be removed in the future
setattr(n, 'node_size', activation_size(n.meta.get('fwd_in', 0)) + activation_size(n.meta.get('fwd_tmp', 0)))
n.meta['type'] = type(result)
# retain the autograd graph
for param in self.module.parameters():
param.grad = None
return result
# Main Node running APIs
@compatibility(is_backward_compatible=True)
def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``placeholder`` node. Note that this is stateful:
``Interpreter`` maintains an internal iterator over
arguments passed to ``run`` and this method returns
next() on that iterator.
Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation
Returns:
result (Any): The argument value that was retrieved
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
return super().placeholder(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True)
def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``get_attr`` node. Will retrieve an attribute
value from the ``Module`` hierarchy of ``self.module``.
Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation
Return:
result (Any): The argument value that was retrieved
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
return super().get_attr(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True)
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_function`` node with meta tensor and return the result and its meta profile.
Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation
Return
result (Any): The argument value that was retrieved
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
assert not isinstance(target, str)
return profile_function(target)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node with meta tensor and return the result and its meta profile.
Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation
Return
result (Any): The argument value that was retrieved
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
return profile_method(target)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_module`` node with meta tensor and return the result and its meta profile.
Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation
Return
result (Any): The argument value that was retrieved
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
# Retrieve executed args and kwargs values from the environment
# Execute the method and return the result
assert isinstance(target, str)
submod = self.fetch_attr(target)
return profile_module(submod)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute an ``output`` node. This really just retrieves
the value referenced by the ``output`` node and returns it.
Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation
Return:
result (Any): The argument value that was retrieved
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
if hasattr(args[0], '_tensor'):
return args[0], GraphInfo(fwd_in=[args[0]._tensor])
return args[0], GraphInfo(save_fwd_in=True)
def propagate(self, *args):
"""
Run `module` via interpretation and return the result and
record the shape and type of each node.
Args:
*args (Tensor): the sample input.
Returns:
Any: The value returned from executing the Module
"""
return super().run(*args)
def summary(self, unit: str = 'MB') -> str:
"""
Summarizes the memory and FLOPs statistics of the `GraphModule` in
tabular format. Note that this API requires the ``tabulate`` module
to be installed.
"""
# https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py
try:
from tabulate import tabulate
except ImportError:
print("`summary` relies on the library `tabulate`, "
"which could not be found on this machine. Run `pip "
"install tabulate` to install the library.")
assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`."
# Build up a list of summary information for each node
node_summaries: List[List[Any]] = []
def mem_repr(mem: int) -> str:
unit_divisor_map = {
'kb': 1024,
'mb': 1024**2,
'gb': 1024**3,
'tb': 1024**4,
}
return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}"
def flops_repr(flop: int) -> str:
return f"{flop:,} FLOPs"
for node in self.module.graph.nodes:
node: Node
node_summaries.append([
node.op,
str(node),
flops_repr(node.meta['fwd_flop']),
flops_repr(node.meta['bwd_flop']),
mem_repr(calculate_fwd_in(node)),
mem_repr(calculate_fwd_out(node)),
mem_repr(calculate_fwd_tmp(node)),
mem_repr(node.meta['bwd_mem_out']),
mem_repr(node.meta['bwd_mem_tmp']),
])
# Use the ``tabulate`` library to create a well-formatted table
# presenting our summary information
headers: List[str] = [
'Op type',
'Op',
'Forward FLOPs',
'Backward FLOPs',
'FWD_IN',
'FWD_OUT',
'FWD_TMP',
'BWD_OUT',
'BWD_TMP',
]
return tabulate(node_summaries, headers=headers, stralign='right')
def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit: str = "MB", **kwargs) -> torch.fx.GraphModule:
"""
MetaInfo tracing API
Given a ``GraphModule`` and a sample input, this API will trace the MetaInfo of a single training cycle,
and annotate them on ``gm.graph``.
Uses:
>>> model = ...
>>> gm = symbolic_trace(model)
>>> args = ... # sample input to the ``GraphModule``
>>> metainfo_trace(gm, *args)
Args:
gm (torch.fx.GraphModule): The ``GraphModule`` to be annotated with MetaInfo.
        verbose (bool, optional): Whether to print ``MetaInfoProp.summary()``. Defaults to False.
unit (str, optional): The unit of memory. Defaults to "MB".
Returns:
torch.fx.GraphModule: The ``GraphModule`` annotated with MetaInfo.
"""
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
interp = MetaInfoProp(gm.to(device))
if is_compatible_with_meta():
from colossalai.fx.profiler import MetaTensor
args = tree_map(lambda x: MetaTensor(x, fake_device=device), args)
kwargs = tree_map(lambda x: MetaTensor(x, fake_device=device), kwargs)
interp.propagate(*args, **kwargs)
    if verbose:
        print(interp.summary(unit))
gm.to('cpu')
del interp
return gm
import torch
from torch.fx.graph_module import GraphModule
from typing import Callable, List, Dict, Any, Optional
from torch.fx._compatibility import compatibility
from packaging import version
from colossalai.fx.passes.meta_info_prop import TensorMetadata
import inspect
from typing import List
from colossalai.fx.passes.split_module import Partition
from colossalai.fx.passes.adding_split_node_pass import pipe_split, balanced_split_pass
from torch.fx.node import Node
def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, partition_list: List[int]):
    '''
    This pass is only used for the gpt2 performance test; it may be moved into adding_split_node_pass.py and will be deprecated in the future.
    '''
mod_graph = gm.graph
valid_children_size = 0
valid_children = []
for node in mod_graph.nodes:
if node.op == "call_module":
valid_children_size += 1
valid_children.append(node.target)
if valid_children_size < pp_size:
# If valid children is not enough to shard, we will use balanced policy instead of uniform policy.
return balanced_split_pass(gm, pp_size)
accumulate_layer_amount = 0
list_of_part = partition_list
part_index = 0
for node in mod_graph.nodes:
if pp_size <= 1:
break
if node.op == "call_module":
if node.target in valid_children:
accumulate_layer_amount += 1
if accumulate_layer_amount == list_of_part[part_index]:
part_index += 1
pp_size -= 1
with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split)
gm.recompile()
return gm
def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule):
    '''
    This pass is used in the gpt2 test; only part of its changes may be merged into
    split_with_split_nodes_pass, and it will be deprecated in the future.
    '''
part_idx = 0
def eliminate_unused_placeholders(gm):
for node in gm.graph.nodes:
if node.op == 'placeholder':
if not len(node.users):
gm.graph.erase_node(node)
gm.recompile()
return gm
def refill_outputs_and_placeholders(gm, next_partition_placeholders):
        '''
        This method eliminates outputs of the previous partition that are unused in the next partition.
        The split module pass treats partitions as a DAG, but pipeline parallelism needs to treat them as a
        singly linked list. The difference: if an output of partition 0 is an input argument of partition 3,
        the DAG does not pass it through partition 1 and partition 2, whereas the linked-list view must.
        '''
output_type = None
output_args = []
non_output_list = []
new_placeholder_list = []
for node in gm.graph.nodes:
if node.op == 'output':
if isinstance(node.args[0], (tuple, list)):
output_type = node.args[0].__class__
output_args.extend([n.name for n in node.args[0]])
else:
output_args.append(node.args[0].name)
rm_list = []
for name in output_args:
if next_partition_placeholders and name not in next_partition_placeholders:
rm_list.append(name)
for name in rm_list:
output_args.remove(name)
gm.graph.erase_node(node)
else:
non_output_list.append(node.name)
for name in next_partition_placeholders:
if name not in output_args:
output_args.append(name)
for name in output_args:
if name not in non_output_list:
gm.graph.placeholder(name)
# convert name to node for output_args
for index, name in enumerate(output_args):
for n in gm.graph.nodes:
if n.name == name:
output_args[index] = n
continue
# reorder the output args to make sure
# output args has same order as next partition placeholder
reorder_output_args = []
        if next_partition_placeholders:
            for name in next_partition_placeholders:
                for node in output_args:
                    if node.name == name:
                        reorder_output_args.append(node)
                        break
for node in gm.graph.nodes:
if node.op == 'placeholder':
new_placeholder_list.append(node.name)
if output_type is not None:
gm.graph.output(output_type(output_args))
else:
gm.graph.output(output_args)
gm.recompile()
return gm, new_placeholder_list
def split_callback(n: torch.fx.Node):
nonlocal part_idx
if (n.op, n.target) == ('call_function', pipe_split):
part_idx += 1
return part_idx
split_mod = split_module_for_gpt2_test(annotated_gm, None, split_callback)
split_submodules = []
for name, submodule in split_mod.named_modules():
if isinstance(submodule, torch.fx.GraphModule):
for node in submodule.graph.nodes:
if (node.op, node.target) == ('call_function', pipe_split):
submodule.graph.erase_node(node)
submodule.recompile()
split_submodules.append(submodule)
submodules = list(split_mod.children())
placeholder_dict = {}
for submodule in submodules:
submodule = eliminate_unused_placeholders(submodule)
placeholder_dict[submodule] = []
submodules.reverse()
for index, submodule in enumerate(submodules):
if index == 0:
placeholder_list = []
else:
placeholder_list = placeholder_dict[submodules[index - 1]]
submodule, placeholder_dict[submodule] = refill_outputs_and_placeholders(submodule, placeholder_list)
submodule.recompile()
split_mod.recompile()
return split_mod, split_submodules
@compatibility(is_backward_compatible=True)
def split_module_for_gpt2_test(
m: GraphModule,
root_m: torch.nn.Module,
split_callback: Callable[[torch.fx.node.Node], int],
):
"""
This pass will be used in gpt2 pp performance test, only a part of changes may be added into
split_module, and it will be deprecated in future.
"""
partitions: Dict[str, Partition] = {}
orig_nodes: Dict[str, torch.fx.node.Node] = {}
def _node_with_all_tensor_element(node_metadata: Any) -> int:
"""
        Return whether every element in the node metadata is a tensor.
"""
all_tensor_node = True
if isinstance(node_metadata, TensorMetadata):
all_tensor_node = node_metadata.is_tensor and all_tensor_node
elif isinstance(node_metadata, dict):
value_list = [v for _, v in node_metadata.items()]
all_tensor_node += _node_with_all_tensor_element(value_list)
else:
for element in node_metadata:
all_tensor_node += _node_with_all_tensor_element(element)
return all_tensor_node
def _move_all_ancestors_into_partition(node, partition_name):
all_ancestors = set()
def _gen_all_ancestors_set(node):
all_ancestors.add(node)
for n in node.all_input_nodes:
if n in all_ancestors:
continue
_gen_all_ancestors_set(n)
_gen_all_ancestors_set(node)
for n in list(all_ancestors):
if n.op != 'placeholder' and n._fx_partition > partition_name:
n._fx_partition = partition_name
def record_cross_partition_use(def_node: torch.fx.node.Node,
use_node: Optional[torch.fx.node.Node]): # noqa: B950
def_partition_name = getattr(def_node, '_fx_partition', None)
use_partition_name = getattr(use_node, '_fx_partition', None)
if def_partition_name != use_partition_name:
# if 'tensor_meta' in def_node.meta:
# if not _node_with_all_tensor_element(def_node.meta['tensor_meta']):
# _move_all_ancestors_into_partition(use_node, def_partition_name)
# node_process_list.extend(use_node.all_input_nodes)
# node_process_list.extend(list(use_node.users))
# node_process_list.append(use_node)
# return
if def_partition_name is not None:
def_partition = partitions[def_partition_name]
def_partition.outputs.setdefault(def_node.name)
if use_partition_name is not None:
def_partition.partition_dependents.setdefault(use_partition_name)
if use_partition_name is not None:
use_partition = partitions[use_partition_name]
use_partition.inputs.setdefault(def_node.name)
if def_partition_name is not None:
use_partition.partitions_dependent_on.setdefault(def_partition_name)
node_process_list = list(m.graph.nodes)
    # split nodes into partitions
while node_process_list:
node = node_process_list.pop(0)
orig_nodes[node.name] = node
if node.op in ["placeholder"]:
continue
if node.op == 'output':
# partition_name = str(split_callback(node))
# def _set_output_args_partition(n, partition_name):
# n._fx_partition = partition_name
# torch.fx.graph.map_arg(node.args[0], lambda n: _set_output_args_partition(n, partition_name))
torch.fx.graph.map_arg(node.args[0], lambda n: record_cross_partition_use(n, None))
continue
partition_name = str(split_callback(node))
# add node to partitions
partition = partitions.get(partition_name)
if partition is None:
partitions[partition_name] = partition = Partition(partition_name)
partition.node_names.append(node.name)
origin_partition_name = getattr(node, '_fx_partition', None)
if origin_partition_name is None:
node._fx_partition = partition_name
torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node))
torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950
# find partitions with no dependencies
root_partitions: List[str] = []
for partition_name, partition in partitions.items():
if not len(partition.partitions_dependent_on):
root_partitions.append(partition_name)
# check partitions for circular dependencies and create topological partition ordering
sorted_partitions: List[str] = []
while root_partitions:
root_partition = root_partitions.pop()
sorted_partitions.append(root_partition)
for dependent in partitions[root_partition].partition_dependents:
partitions[dependent].partitions_dependent_on.pop(root_partition)
if not partitions[dependent].partitions_dependent_on:
root_partitions.append(dependent)
if len(sorted_partitions) != len(partitions):
raise RuntimeError("cycle exists between partitions!")
# add placeholders to partitions
for partition_name in sorted_partitions:
partition = partitions[partition_name]
for input in partition.inputs:
placeholder = partition.graph.placeholder(input)
placeholder.meta = orig_nodes[input].meta.copy()
partition.environment[orig_nodes[input]] = placeholder
# Transform nodes and collect targets for partition's submodule
for node in m.graph.nodes:
if hasattr(node, '_fx_partition'):
partition = partitions[node._fx_partition]
# swap out old graph nodes in kw/args with references to new nodes in this submodule
environment = partition.environment
gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n])
if node.op not in ['call_module', 'get_attr']:
target = node.target
else:
target_atoms = node.target.split('.')
target_attr = m
for atom in target_atoms:
if not hasattr(target_attr, atom):
raise RuntimeError(f'Operator target {node.target} not found!')
target_attr = getattr(target_attr, atom)
# target = target_atoms[-1]
target = '_'.join(target_atoms)
partition.targets[target] = target_attr
assert isinstance(gathered_args, tuple)
assert isinstance(gathered_kwargs, dict)
new_node = partition.graph.create_node(op=node.op,
target=target,
args=gathered_args,
kwargs=gathered_kwargs,
name=node.name)
new_node.meta = node.meta.copy()
partition.environment[node] = new_node
# Set up values to construct base module
base_mod_env: Dict[str, torch.fx.node.Node] = {}
base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
for node in m.graph.nodes:
if node.op == 'placeholder':
if version.parse(torch.__version__) < version.parse('1.11.0'):
base_mod_env[node.name] = base_mod_graph.placeholder(node.name, type_expr=node.type)
else:
default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
base_mod_env[node.name] = base_mod_graph.placeholder(node.name,
type_expr=node.type,
default_value=default_value)
base_mod_env[node.name].meta = node.meta.copy()
# Do some things iterating over the partitions in topological order again:
# 1) Finish off submodule Graphs by setting corresponding outputs
# 2) Construct GraphModules for each submodule
# 3) Construct the base graph by emitting calls to those submodules in
# topological order
for partition_name in sorted_partitions:
partition = partitions[partition_name]
# Set correct output values
output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs)
output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
partition.graph.output(output_vals)
# Construct GraphModule for this partition
submod_name = f'submod_{partition_name}'
base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets,
partition.graph) # noqa: B950
# Emit call in base graph to this submodule
output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs))
if len(partition.outputs) > 1:
# Unpack multiple return values from submodule
output_val_proxy = torch.fx.proxy.Proxy(output_val)
for i, output_name in enumerate(partition.outputs):
base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
else:
if not partition.outputs:
continue
base_mod_env[list(partition.outputs)[0]] = output_val
for node in m.graph.nodes:
if node.op == 'output':
base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
import torch
import torch.nn as nn
import operator
from colossalai.tensor import ProcessGroup
from colossalai.tensor.distspec import ShardSpec
from colossalai.tensor.compute_spec import ComputePattern, ComputeSpec
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
ELEMENTWISE_FUNC_OP = [
torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv,
operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout
]
def weight_split(weight: torch.nn.parameter.Parameter, dim: int, col_normal: bool) -> torch.nn.parameter.Parameter:
"""weight_split
annotate an nn.Parameter with a tensor-parallel sharding attribute (the actual split is applied by later passes)
Args:
weight (torch.nn.parameter.Parameter): a torch Parameter instance
dim (int): the dimension to shard along
col_normal (bool): whether to use normal column sharding (gather the outputs) or column sharding that keeps many outputs
Returns:
torch.nn.parameter.Parameter: the input parameter with the sharding annotation (``fx_attr``) attached
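Example (a minimal sketch; the linear layer below is illustrative -- ``weight_split``
only attaches the ``fx_attr`` annotation and does not move any data):
>>> linear = torch.nn.Linear(4, 8)
>>> w = weight_split(linear.weight, dim=0, col_normal=False)
>>> w.fx_attr
(0, 'SHARD', 'TP', 'col_needs_many_outputs')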
"""
if col_normal:
setattr(weight, "fx_attr", (dim, "SHARD", "TP", "col_normal"))
else:
setattr(weight, "fx_attr", (dim, "SHARD", "TP", "col_needs_many_outputs"))
return weight
def column_shard_linear_pass(gm: torch.fx.GraphModule):
# Split all the linear modules with column sharding. Currently for testing only.
mod_graph = gm.graph
for node in mod_graph.nodes:
if node.op == "call_module":
target_module = node.graph.owning_module.get_submodule(node.target)
if isinstance(target_module, torch.nn.Linear):
target_module.weight = weight_split(target_module.weight, dim=0, col_normal=False)
if target_module.bias is not None:
target_module.bias.data = weight_split(target_module.bias.data, dim=0, col_normal=False)
gm.recompile()
return gm
def row_shard_linear_pass(gm: torch.fx.GraphModule):
# Split all the linear modules with row sharding. Currently for testing only.
mod_graph = gm.graph
for node in mod_graph.nodes:
if node.op == "call_module":
target_module = node.graph.owning_module.get_submodule(node.target)
if isinstance(target_module, torch.nn.Linear):
target_module.weight = weight_split(target_module.weight, dim=-1, col_normal=False)
gm.recompile()
return gm
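# A minimal usage sketch for the two passes above (the model below is illustrative):
# trace a model with torch.fx, then annotate every nn.Linear for column or row sharding.
#   gm = torch.fx.symbolic_trace(torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()))
#   gm = column_shard_linear_pass(gm)    # or: row_shard_linear_pass(gm)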
def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: ProcessGroup):
"""
This IR pass detects transformer-MLP-like structures and annotates the linear layers with column and row sharding.
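A minimal usage sketch (``model`` and ``pg`` are assumed to be a traceable module
and an already-initialized tensor-parallel ``ProcessGroup``, respectively):
>>> gm = torch.fx.symbolic_trace(model)
>>> gm = transformer_mlp_pass(gm, process_group=pg)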
"""
#TODO: Needs to handle special cases, like x = linear(x) + linear(x)
graph = graph_module.graph
world_size = process_group.world_size()
def _traverse_and_annotate(node, start_tracking, annotation_record, world_size):
# traverse the graph to look for consecutive linear layers
is_linear_module = False
if node.op == 'call_module':
# look for the linear layer
module = node.graph.owning_module.get_submodule(node.target)
if isinstance(module, nn.Linear):
is_linear_module = True
if start_tracking:
# when start_tracking = True
# it means the first linear has been found and the current module
# is the second linear
# set the current linear module to be row-sharded
annotation_record['row'] = module
for shard_type, module in annotation_record.items():
# add row sharding spec
if shard_type == 'row':
dist_spec = ShardSpec(dims=[-1], num_partitions=[world_size])
comp_spec = ComputeSpec(ComputePattern.TP1D)
setattr(module.weight, 'pg', process_group)
setattr(module.weight, 'dist_spec', dist_spec)
setattr(module.weight, 'comp_spec', comp_spec)
elif shard_type == 'col':
weight_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])
weight_comp_spec = ComputeSpec(ComputePattern.TP1D)
weight_comp_spec.output_replicate = False
setattr(module.weight, 'pg', process_group)
setattr(module.weight, 'dist_spec', weight_dist_spec)
setattr(module.weight, 'comp_spec', weight_comp_spec)
if module.bias is not None:
bias_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])
bias_comp_spec = ComputeSpec(ComputePattern.TP1D)
bias_comp_spec.output_replicate = False
setattr(module.bias, 'pg', process_group)
setattr(module.bias, 'dist_spec', bias_dist_spec)
setattr(module.bias, 'comp_spec', bias_comp_spec)
start_tracking = False
annotation_record.clear()
else:
# when start tracking = False
# it means the current layer is the first linear
# set the linear layer to be col-sharded
start_tracking = True
annotation_record['col'] = module
if start_tracking and not is_linear_module:
# check against the whitelist
# if a non-elementwise op is found, we reset the tracking
if node.op == 'call_module':
module = node.graph.owning_module.get_submodule(node.target)
if module.__class__ not in ELEMENTWISE_MODULE_OP:
start_tracking = False
elif node.op == 'call_function' or node.op == 'call_method':
if node.target not in ELEMENTWISE_FUNC_OP:
start_tracking = False
elif len(node.users.keys()) > 1:
start_tracking = False
if not start_tracking:
annotation_record.clear()
# stop tracking consecutive linears when a branch is found
# e.g.
# out1 = self.linear1(x)
# out2 = self.linear2(x)
# return out1+out2
next_nodes = list(node.users.keys())
if len(next_nodes) > 1:
start_tracking = False
annotation_record.clear()
# traverse
for node in next_nodes:
_traverse_and_annotate(node, start_tracking, annotation_record, world_size)
placeholder_node = list(graph.nodes)[0]
annotate_record = {}
_traverse_and_annotate(placeholder_node, False, annotate_record, world_size)
return graph_module
import torch
from torch.fx.graph_module import GraphModule
from typing import Callable, List, Dict, Any, Optional
from torch.fx._compatibility import compatibility
from packaging import version
import inspect
@compatibility(is_backward_compatible=True)
class Partition:
"""
Adapted from https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/split_module.py
"""
def __init__(self, name: str):
self.name: str = name
self.node_names: List[str] = []
self.inputs: Dict[str, None] = {}
self.outputs: Dict[str, None] = {}
self.partitions_dependent_on: Dict[str, None] = {}
self.partition_dependents: Dict[str, None] = {}
self.graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
self.environment: Dict[torch.fx.node.Node, torch.fx.node.Node] = {}
self.targets: Dict[str, Any] = {}
def __repr__(self) -> str:
return f"name: {self.name},\n" \
f" nodes: {self.node_names},\n" \
f" inputs: {self.inputs},\n" \
f" outputs: {self.outputs},\n" \
f" partitions depenent on: {self.partitions_dependent_on},\n" \
f" parition dependents: {self.partition_dependents}"
# Creates subgraphs out of main graph
@compatibility(is_backward_compatible=True)
def split_module(
m: GraphModule,
root_m: torch.nn.Module,
split_callback: Callable[[torch.fx.node.Node], int],
merge_output: bool = False,
):
"""
Adapted from https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/split_module.py
Creates subgraphs out of main graph
Args:
m (GraphModule): Graph module to split
root_m (torch.nn.Module): root nn module. Not currently used. Included
because the root nn module is usually transformed via
torch.fx._symbolic_trace.symbolic_trace (see example below)
split_callback (Callable[[torch.fx.node.Node], int]): Callable function
that maps a given Node instance to a numeric partition identifier.
split_module will use this function as the policy for which operations
appear in which partitions in the output Module.
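merge_output (bool): if True, the values fed to the graph's ``output`` node are also
recorded as outputs of the partition containing the node that immediately precedes
``output`` (see ``record_output``). Defaults to False.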
Returns:
GraphModule: the module after split.
Example:
This is a sample setup:
import torch
from torch.fx.symbolic_trace import symbolic_trace
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from colossalai.fx.passes.split_module import split_module
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x, y):
z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
w = self.linear(y).clamp(min=0.0, max=1.0)
return z + w
# symbolically trace model
my_module = MyModule()
my_module_traced = symbolic_trace(my_module)
# random mod partitioning
partition_counter = 0
NPARTITIONS = 3
def mod_partition(node: Node):
global partition_counter
partition = partition_counter % NPARTITIONS
partition_counter = (partition_counter + 1) % NPARTITIONS
return partition
# split module in module with submodules
module_with_submodules = split_module(
my_module_traced, my_module, mod_partition
)
Output looks like this. Original graph is broken into partitions
> print(module_with_submodules)
GraphModule(
(submod_0): GraphModule(
(linear): Linear(in_features=4, out_features=5, bias=True)
)
(submod_1): GraphModule(
(linear): Linear(in_features=4, out_features=5, bias=True)
)
(submod_2): GraphModule()
)
def forward(self, x, y):
param = self.param
submod_0 = self.submod_0(x, param, y); x = param = y = None
getitem = submod_0[0]
getitem_1 = submod_0[1]; submod_0 = None
submod_1 = self.submod_1(getitem, getitem_1); getitem = getitem_1 = None
getitem_2 = submod_1[0]
getitem_3 = submod_1[1]; submod_1 = None
submod_2 = self.submod_2(getitem_2, getitem_3); getitem_2 = getitem_3 = None
return submod_2
Output of split module is the same as output of input traced module.
This is an example within a test setting:
> orig_out = my_module_traced(x, y)
> submodules_out = module_with_submodules(x, y)
> self.assertEqual(orig_out, submodules_out)
True
"""
partitions: Dict[str, Partition] = {}
orig_nodes: Dict[str, torch.fx.node.Node] = {}
def record_cross_partition_use(def_node: torch.fx.node.Node,
use_node: Optional[torch.fx.node.Node]): # noqa: B950
def_partition_name = getattr(def_node, '_fx_partition', None)
use_partition_name = getattr(use_node, '_fx_partition', None)
if def_partition_name != use_partition_name:
if def_partition_name is not None:
def_partition = partitions[def_partition_name]
def_partition.outputs.setdefault(def_node.name)
if use_partition_name is not None:
def_partition.partition_dependents.setdefault(use_partition_name)
if use_partition_name is not None:
use_partition = partitions[use_partition_name]
use_partition.inputs.setdefault(def_node.name)
if def_partition_name is not None:
use_partition.partitions_dependent_on.setdefault(def_partition_name)
def record_output(
def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]
): # noqa: B950
def_partition_name = getattr(def_node, "_fx_partition", None)
use_partition_name = getattr(use_node, "_fx_partition", None)
if def_partition_name != use_partition_name:
if def_partition_name is not None:
def_partition = partitions[def_partition_name]
def_partition.outputs.setdefault(def_node.name)
if use_partition_name is not None:
def_partition.partition_dependents.setdefault(use_partition_name)
if use_partition_name is not None:
use_partition = partitions[use_partition_name]
use_partition.inputs.setdefault(def_node.name)
if def_partition_name is not None:
use_partition.partitions_dependent_on.setdefault(def_partition_name)
use_partition.outputs.setdefault(def_node.name)
else:
if use_partition_name is not None:
use_partition = partitions[use_partition_name]
use_partition.outputs.setdefault(def_node.name)
# split nodes into partitions
for node in m.graph.nodes:
orig_nodes[node.name] = node
if node.op in ["placeholder"]:
continue
if node.op == 'output':
if merge_output:
torch.fx.graph.map_arg(node.args[0], lambda n: record_output(n, node.prev))
else:
torch.fx.graph.map_arg(node.args[0], lambda n: record_cross_partition_use(n, None))
continue
partition_name = str(split_callback(node))
# add node to partitions
partition = partitions.get(partition_name)
if partition is None:
partitions[partition_name] = partition = Partition(partition_name)
partition.node_names.append(node.name)
node._fx_partition = partition_name
torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node))
torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950
# find partitions with no dependencies
root_partitions: List[str] = []
for partition_name, partition in partitions.items():
if not len(partition.partitions_dependent_on):
root_partitions.append(partition_name)
# check partitions for circular dependencies and create topological partition ordering
sorted_partitions: List[str] = []
while root_partitions:
root_partition = root_partitions.pop()
sorted_partitions.append(root_partition)
for dependent in partitions[root_partition].partition_dependents:
partitions[dependent].partitions_dependent_on.pop(root_partition)
if not partitions[dependent].partitions_dependent_on:
root_partitions.append(dependent)
if len(sorted_partitions) != len(partitions):
raise RuntimeError("cycle exists between partitions!")
# add placeholders to partitions
for partition_name in sorted_partitions:
partition = partitions[partition_name]
for input in partition.inputs:
placeholder = partition.graph.placeholder(input)
placeholder.meta = orig_nodes[input].meta.copy()
partition.environment[orig_nodes[input]] = placeholder
# Transform nodes and collect targets for partition's submodule
for node in m.graph.nodes:
if hasattr(node, '_fx_partition'):
partition = partitions[node._fx_partition]
# swap out old graph nodes in kw/args with references to new nodes in this submodule
environment = partition.environment
gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n])
if node.op not in ['call_module', 'get_attr']:
target = node.target
else:
target_atoms = node.target.split('.')
target_attr = m
for atom in target_atoms:
if not hasattr(target_attr, atom):
raise RuntimeError(f'Operator target {node.target} not found!')
target_attr = getattr(target_attr, atom)
# target = target_atoms[-1]
target = '_'.join(target_atoms)
partition.targets[target] = target_attr
assert isinstance(gathered_args, tuple)
assert isinstance(gathered_kwargs, dict)
new_node = partition.graph.create_node(op=node.op,
target=target,
args=gathered_args,
kwargs=gathered_kwargs)
new_node.meta = node.meta.copy()
partition.environment[node] = new_node
# Set up values to construct base module
base_mod_env: Dict[str, torch.fx.node.Node] = {}
base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
for node in m.graph.nodes:
if node.op == 'placeholder':
if version.parse(torch.__version__) < version.parse('1.11.0'):
base_mod_env[node.name] = base_mod_graph.placeholder(node.target, type_expr=node.type)
else:
default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
base_mod_env[node.name] = base_mod_graph.placeholder(node.target,
type_expr=node.type,
default_value=default_value)
base_mod_env[node.name].meta = node.meta.copy()
# Do some things iterating over the partitions in topological order again:
# 1) Finish off submodule Graphs by setting corresponding outputs
# 2) Construct GraphModules for each submodule
# 3) Construct the base graph by emitting calls to those submodules in
# topological order
for partition_name in sorted_partitions:
partition = partitions[partition_name]
# Set correct output values
output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs)
output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
partition.graph.output(output_vals)
# Construct GraphModule for this partition
submod_name = f'submod_{partition_name}'
base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets,
partition.graph) # noqa: B950
# Emit call in base graph to this submodule
output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs))
if len(partition.outputs) > 1:
# Unpack multiple return values from submodule
output_val_proxy = torch.fx.proxy.Proxy(output_val)
for i, output_name in enumerate(partition.outputs):
base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
else:
if not partition.outputs:
continue
base_mod_env[list(partition.outputs)[0]] = output_val
for node in m.graph.nodes:
if node.op == 'output':
base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
for partition_name in sorted_partitions:
partition = partitions[partition_name]
new_gm = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
return new_gm
import torch
from typing import Dict
from torch.fx.node import Node, map_arg
from torch.fx.graph import Graph
def get_comm_size(prev_partition, next_partition):
"""
Given two partitions (parent and child),
calculate the communication size between the two.
"""
# Keep tracking the communication size between parent and child
comm_size = 0
# Keep tracking all the counted node
visited_nodes = set()
# Go through all nodes in the child partition
# If a node has input nodes from the parent partition,
# the output size of those input nodes will be counted
# and added to comm_size
parent_node_names = [n.name for n in prev_partition.graph.nodes]
for node in next_partition.graph.nodes:
input_nodes: Dict[Node, None] = {}
map_arg(node.args, lambda n: input_nodes.setdefault(n))
map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
for n in input_nodes:
if n.name in parent_node_names and n not in visited_nodes:
comm_size += n.meta['tensor_meta'].numel
visited_nodes.add(n)
return comm_size
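# A usage sketch (assumes two partition-like objects that expose a ``graph`` whose nodes
# carry ``tensor_meta``, e.g. submodules produced by ``split_module`` after shape propagation):
#   comm_elems = get_comm_size(parent_partition, child_partition)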
def get_leaf(graph: Graph):
"""
Given a graph, return leaf nodes of this graph.
Note: If we remove ``root`` nodes, ``placeholder`` nodes, and ``output`` nodes from fx graph,
we will get a normal DAG. Leaf nodes in this context mean the leaf nodes of that DAG.
"""
input_nodes: Dict[Node, None] = {}
for node in graph.nodes:
if node.op == 'output':
map_arg(node.args, lambda n: input_nodes.setdefault(n))
map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
placeholder_nodes = []
for node in input_nodes.keys():
if node.op == 'placeholder':
placeholder_nodes.append(node)
for node in placeholder_nodes:
input_nodes.pop(node)
return list(input_nodes.keys())
def is_leaf(graph: Graph, node: Node):
return node in get_leaf(graph)
def get_top(graph: Graph):
"""
Given a graph, return top nodes of this graph.
Note: If we remove ``root`` nodes, ``placeholder`` nodes, and ``output`` nodes from fx graph,
we will get a normal DAG. Top nodes in this context mean nodes with BFS level 0 in that DAG.
"""
top_node_list = set()
for node in graph.nodes:
if node.op == 'output':
continue
is_top = False
def _get_top(node):
nonlocal is_top
if node.op == 'placeholder':
is_top = True
map_arg(node.args, lambda n: _get_top(n))
map_arg(node.kwargs, lambda n: _get_top(n))
if is_top:
top_node_list.add(node)
return list(top_node_list)
def is_top(graph: Graph, node: Node):
return node in get_top(graph)
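# A quick sketch for get_top/get_leaf (illustrative two-layer model):
#   gm = torch.fx.symbolic_trace(torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 4)))
#   get_top(gm.graph)     # the call_module node of the first linear (fed by the placeholder)
#   get_leaf(gm.graph)    # the call_module node of the second linear (feeds the output)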
def get_all_consumers(graph: Graph, node: Node):
"""
Given a graph and a node of this graph, return all consumers of the node.
Returns:
List of ``Node`` objects whose ``args`` or ``kwargs`` contain the given node.
"""
consumer_list = []
for n in graph.nodes:
if node in n.all_input_nodes:
consumer_list.append(n)
return consumer_list
def assign_bfs_level_to_nodes(graph: Graph):
"""
Given a graph, assign a BFS level to each node of this graph, excluding ``placeholder`` and ``output`` nodes.
Example:
class MLP(torch.nn.Module):
def __init__(self, dim: int):
super().__init__()
self.linear1 = torch.nn.Linear(dim, dim)
self.linear2 = torch.nn.Linear(dim, dim)
self.linear3 = torch.nn.Linear(dim, dim)
self.linear4 = torch.nn.Linear(dim, dim)
self.linear5 = torch.nn.Linear(dim, dim)
def forward(self, x):
l1 = self.linear1(x)
l2 = self.linear2(x)
l3 = self.linear3(l1)
l4 = self.linear4(l2)
l5 = self.linear5(l3)
return l4, l5
model = MLP(4)
gm = symbolic_trace(model)
print(gm.graph)
assign_bfs_level_to_nodes(gm.graph)
for node in gm.graph.nodes:
if hasattr(node, 'bfs_level'):
print(node.name, node.bfs_level)
Output:
graph():
%x : [#users=2] = placeholder[target=x]
%linear1 : [#users=1] = call_module[target=linear1](args = (%x,), kwargs = {})
%linear2 : [#users=1] = call_module[target=linear2](args = (%x,), kwargs = {})
%linear3 : [#users=1] = call_module[target=linear3](args = (%linear1,), kwargs = {})
%linear4 : [#users=1] = call_module[target=linear4](args = (%linear2,), kwargs = {})
%linear5 : [#users=1] = call_module[target=linear5](args = (%linear3,), kwargs = {})
return (linear4, linear5)
linear1 0
linear2 0
linear3 1
linear4 1
linear5 2
"""
current_level = 0
nodes_to_process = []
top_nodes = get_top(graph)
for node in top_nodes:
node.bfs_level = current_level
nodes_to_process.extend(get_all_consumers(graph, node))
current_level += 1
while nodes_to_process:
new_process_list = []
for node in nodes_to_process:
if node.op == 'output':
continue
node.bfs_level = current_level
new_process_list.extend(get_all_consumers(graph, node))
nodes_to_process = new_process_list
current_level += 1
def get_node_module(node) -> torch.nn.Module:
"""
Find the module associated with the given node.
Args:
node (torch.fx.Node): a torch.fx.Node object in the fx computation graph
Returns:
torch.nn.Module: the module associated with the given node
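Example (a minimal sketch with an illustrative model):
>>> gm = torch.fx.symbolic_trace(torch.nn.Sequential(torch.nn.Linear(4, 4)))
>>> [get_node_module(n) for n in gm.graph.nodes if n.op == 'call_module']
[Linear(in_features=4, out_features=4, bias=True)]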
"""
assert node.graph.owning_module is not None, 'Cannot find the owning_module for node.graph, please make sure the graph is associated with a GraphModule object'
assert node.op == 'call_module', f'Expected node.op to be call_module, but found {node.op}'
module = node.graph.owning_module.get_submodule(node.target)
return module
from .._compatibility import is_compatible_with_meta
if is_compatible_with_meta():
from .opcount import flop_mapping
from .profiler import profile_function, profile_method, profile_module
from .shard_utils import (
calculate_bwd_time,
calculate_fwd_in,
calculate_fwd_out,
calculate_fwd_time,
calculate_fwd_tmp,
)
from .tensor import MetaTensor
else:
from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
from .dataflow import GraphInfo
from .memory_utils import activation_size, is_inplace, parameter_size
import torch
__all__ = ['ALIAS_ATEN', 'INPLACE_NEW', 'INPLACE_MATH_ATEN', 'CLONE_ATEN', 'RELU_LIKE_OPS', 'RELU_LIKE_MOD']
aten = torch.ops.aten
ALIAS_ATEN = [
aten.detach.default,
aten.t.default,
aten.transpose.int,
aten.view.default,
aten._unsafe_view.default,
aten._reshape_alias.default,
]
INPLACE_NEW = [
aten.empty_like.default,
aten.new_empty_strided.default,
]
INPLACE_MATH_ATEN = [
aten.add_.Tensor,
aten.sub_.Tensor,
aten.div_.Tensor,
aten.div_.Scalar,
aten.mul_.Tensor,
aten.bernoulli_.float,
]
CLONE_ATEN = [
aten.clone.default,
]
# See illustrations in
# https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/fx/profiler/constants.py
OUTPUT_SAVED_OPS = [
torch.nn.functional.relu,
torch.nn.functional.softmax,
]
OUTPUT_SAVED_MOD = [
torch.nn.ReLU,
torch.nn.Softmax,
]
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from typing import Dict, List
from torch.fx import Graph, Node
from .._compatibility import compatibility
from .memory_utils import activation_size, is_inplace
class Phase(Enum):
FORWARD = 0
BACKWARD = 1
PLACEHOLDER = 2
@compatibility(is_backward_compatible=True)
@dataclass
class GraphInfo:
"""
GraphInfo is a dataclass for MetaInfo, which measures
the execution memory cost and FLOPs with `MetaTensor`.
The dataflow analysis is conducted on a single node of the FX graph.
============================================================================
-------------------------------
| Node |
[fwd_in] are ---> | [fwd_in] [bwd_out] | <----- [bwd_out] marks the memory for `grad_out`.
placeholders saved for | | \__________ | |
backward. | | \ | |
| [fwd_tmp] ------> [bwd_tmp] | <-----
| | \_________ | | [bwd_tmp] marks the peak memory
| / \ \ | | in backward pass.
[x] is not counted ---> | [x] [fwd_tmp] -> [bwd_tmp] | <-----
in [fwd_tmp] because | | \_____ | |
it is not saved for | | \ | |
backward. | [fwd_out] \ | | <----- [fwd_out] is [fwd_in] for the next node.
-------------------------------
============================================================================
Attributes:
fwd_flop (int): The forward FLOPs of a certain node.
fwd_time (float): The real forward time (s) of a certain node.
bwd_flop (int): The backward FLOPs of a certain node.
bwd_time (float): The real backward time (s) of a certain node.
save_fwd_in (bool): The decision variable of whether to save the fwd_mem_out of parent nodes.
fwd_in (List): See the above illustration.
fwd_tmp (List): See the above illustration.
fwd_out (List): See the above illustration.
fwd_mem_tmp (int): See the above illustration.
fwd_mem_out (int): See the above illustration.
bwd_mem_tmp (int): See the above illustration.
bwd_mem_out (int): See the above illustration.
"""
# TODO(super-dainiu): remove redundant items; currently all of them are necessary for development
fwd_flop: int = 0
fwd_time: float = 0.0
bwd_flop: int = 0
bwd_time: float = 0.0
save_fwd_in: bool = False
fwd_in: List = field(default_factory=list)
fwd_tmp: List = field(default_factory=list)
fwd_out: List = field(default_factory=list)
fwd_mem_tmp: int = 0
fwd_mem_out: int = 0
bwd_mem_tmp: int = 0
bwd_mem_out: int = 0
def is_phase(n: Node, phase: Phase) -> bool:
assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!'
return n.meta['phase'] == phase
@compatibility(is_backward_compatible=False)
def autograd_graph_analysis(graph: Graph) -> GraphInfo:
"""Analyze the autograd node dependencies and find out the memory usage.
Basically the input graph should have all nodes marked for keyword `phase`.
Nodes should have attribute `out` indicating the output of each node.
============================================================================
Placeholder ----> p o <---- We need to keep track of grad out
|\________ |
↓ ↘|
f --------> b
|\ \_____ ↑
| \ ↘ /
f f ----> b <---- Not every forward result needs to be saved for backward
| \____ ↑
↘ ↘|
f ----> b <---- Backward can be freed as soon as it is no longer required.
↘ ↗
l
=============================================================================
Args:
graph (Graph): The autograd graph with nodes marked for keyword `phase`.
Returns:
graph_info (GraphInfo): Meta information for the dataflow.
"""
def _peak_memory(deps: Dict[Node, int]):
peak_mem = 0
for k, v in deps.items():
if v > 0 and is_phase(k, Phase.BACKWARD) and not all(map(is_inplace, k.users)) and not is_inplace(k):
peak_mem += activation_size(k.meta['saved_tensor'])
if v <= float('-inf') and is_phase(k, Phase.FORWARD):
peak_mem -= activation_size(k.meta['saved_tensor'])
return peak_mem
# deps is used to track all the memory dependencies of the graph.
deps = {}
graph_info = GraphInfo()
for n in graph.nodes:
n: Node
deps[n] = len(n.users)
# A forward tensor that is marked `save` but is also
# an input to `Phase.FORWARD` should be saved during forward.
# If the tensor is a placeholder, then it belongs to `fwd_mem_in`.
# Any `fwd_mem_in` should be kept in memory even if this function
# is checkpointed.
# Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
# the node, `fwd_mem_tmp` can be freed.
if is_phase(n, Phase.PLACEHOLDER):
graph_info.fwd_in += n.meta['saved_tensor']
if is_phase(n, Phase.FORWARD):
graph_info.fwd_tmp += n.meta['saved_tensor']
elif is_phase(n, Phase.BACKWARD):
if len(n.users):
graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
else:
# TODO: some of the bwd_mem_out might be model parameters.
# basically a backward node without users is a `grad_out` node
graph_info.bwd_mem_out += activation_size(n.meta['saved_tensor'])
for input_n in n.all_input_nodes:
if input_n in deps:
deps[input_n] -= 1
if deps[input_n] <= 0:
deps[input_n] = float('-inf')
return graph_info
from .profiler import profile_function, profile_method, profile_module
from .profiler_function import *
from .profiler_module import *
from .registry import meta_profiler_function, meta_profiler_module
from .shard_utils import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
from operator import add, floordiv, getitem, mul, neg, pos, setitem, sub
import torch
__all__ = ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD']
# TODO fill out the inplace ops
INPLACE_OPS = [
add,
sub,
mul,
floordiv,
neg,
pos,
getitem,
setitem,
getattr,
torch.Tensor.cpu,
]
# TODO: list all call_methods that are inplace here
INPLACE_METHOD = [
'transpose',
'permute',
# TODO: reshape may return a copy of the data if the data is not contiguous
'reshape',
'dim',
'flatten',
'size',
'view',
'unsqueeze',
'to',
'type',
]
# TODO: list all call_methods that are not inplace here
NON_INPLACE_METHOD = [
'chunk',
'contiguous',
'expand',
'mean',
'split',
]
from dataclasses import dataclass
from typing import Any, Callable, Dict, Tuple
import torch
from torch.fx.node import Argument, Target
from ..._compatibility import compatibility
from ..memory_utils import activation_size
from .constants import INPLACE_METHOD, INPLACE_OPS, NON_INPLACE_METHOD
from .registry import meta_profiler_function, meta_profiler_module
__all__ = ['profile_function', 'profile_module', 'profile_method']
# this is for compatibility use
@compatibility(is_backward_compatible=True)
@dataclass
class GraphInfo:
"""
GraphInfo is a dataclass for MetaInfo, which measures
the execution memory cost and FLOPs with `MetaTensor`.
The dataflow analysis is conducted on a single node of the FX graph.
============================================================================
-------------------------------
| Node |
[fwd_in] are ---> | [fwd_in] [bwd_out] | <----- [bwd_out] marks the memory for `grad_out`
placeholders saved for | | \__________ | |
backward. | | \ | |
| [fwd_tmp] ------> [bwd_tmp] | <-----
| | \_________ | | [bwd_tmp] marks the peak memory
| / \ \ | | in backward pass.
[x] is not counted ---> | [x] [fwd_tmp] -> [bwd_tmp] | <-----
in [fwd_tmp] because | | | \_____ | |
it is not saved for | | | \ | |
backward. -------------------------------
============================================================================
Attributes:
fwd_flop (int): The forward FLOPs of a certain node
bwd_flop (int): The backward FLOPs of a certain node.
fwd_mem_in (int): See the above illustration.
fwd_mem_tmp (int): See the above illustration.
bwd_mem_tmp (int): See the above illustration.
bwd_mem_out (int): See the above illustration.
"""
fwd_flop: int = 0
bwd_flop: int = 0
fwd_mem_in: int = 0
fwd_mem_tmp: int = 0
bwd_mem_tmp: int = 0
bwd_mem_out: int = 0
CALL_FUNCTION_MSG = \
"""
Colossal-AI does not yet support profiling for {}; you can patch it manually with the following code.\n
from colossalai.fx.profiler.experimental import meta_profiler_function
@meta_profiler_function.register(YOUR_FUNCTION)
def profile_YOUR_FUNCTION(input: torch.Tensor, *args) -> Tuple[int, int]:
flops = ...
macs = ...
return flops, macs
"""
CALL_METHOD_MSG = 'Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}'
CALL_MODULE_MSG = \
"""
Colossal-AI does not yet support profiling for {}; you can patch it manually with the following code.\n
from colossalai.fx.profiler.experimental import meta_profiler_module
@meta_profiler_module.register(YOUR_MODULE)
def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]:
flops = ...
macs = ...
return flops, macs
"""
@compatibility(is_backward_compatible=True)
def profile_function(target: 'Target') -> Callable:
"""
Wrap a `call_function` node or `torch.nn.functional` in order to
record the memory cost and FLOPs of the execution.
Unfortunately, backward memory cost and FLOPs are estimated results.
Warnings:
You may only use tensors with `device='meta'` with this wrapped function.
Only the original `torch.nn.functional` functions are supported.
Examples:
>>> input = torch.rand(100, 100, 100, 100, device='meta')
>>> func = torch.nn.functional.relu
>>> output, meta_info = profile_function(func)(input, inplace=False)
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
assert meta_profiler_function.has(target) or meta_profiler_function.has(
target.__name__), CALL_FUNCTION_MSG.format(target)
fwd_tmp = 0
fwd_out = 0
out = func(*args, **kwargs)
if target not in INPLACE_OPS and not kwargs.get('inplace', False):
fwd_out = activation_size(out)
if meta_profiler_function.has(target):
profiler = meta_profiler_function.get(target)
else:
profiler = meta_profiler_function.get(target.__name__)
fwd_flop, _ = profiler(*args, **kwargs)
return out, GraphInfo(fwd_flop, fwd_flop * 2, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)
f.__name__ = target.__name__
func = target
return f
@compatibility(is_backward_compatible=True)
def profile_method(target: 'Target') -> Callable:
"""
Wrap a `call_method` node to
record the memory cost and FLOPs of the execution.
Warnings:
This is not fully implemented and you may follow the error message to debug.
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
# args[0] is the `self` object for this method call
self_obj, *args_tail = args
# execute the method and return the result
assert isinstance(target, str), f'{target} instance is not str.'
out = getattr(self_obj, target)(*args_tail, **kwargs)
assert target in INPLACE_METHOD + NON_INPLACE_METHOD, CALL_METHOD_MSG.format(
target, INPLACE_METHOD, NON_INPLACE_METHOD)
# call_method has no parameters, is MOSTLY(?) inplace, and has no FLOPs or MACs.
fwd_tmp = 0 if target in INPLACE_METHOD else activation_size(out)
fwd_out = 0 if target not in INPLACE_METHOD else activation_size(out)
return out, GraphInfo(0, 0, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)
return f
@compatibility(is_backward_compatible=True)
def profile_module(module: torch.nn.Module) -> Callable:
"""
Wrap a `call_module` node or `torch.nn` in order to
record the memory cost and FLOPs of the execution.
Warnings:
You may only use tensors with `device='meta'` with this wrapped function.
Only the original `torch.nn` modules are supported.
Example:
>>> input = torch.rand(4, 3, 224, 224, device='meta')
>>> mod = torch.nn.Conv2d(3, 128, 3)
>>> output, meta_info = profile_module(mod)(input)
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
assert meta_profiler_module.has(type(module)), CALL_MODULE_MSG.format(type(module))
fwd_tmp = 0
fwd_out = 0
out = func(*args, **kwargs)
if getattr(module, 'inplace', False):
fwd_out = activation_size(out)
profiler = meta_profiler_module.get(type(module))
fwd_flop, _ = profiler(module, *args, **kwargs)
return out, GraphInfo(fwd_flop, fwd_flop * 2, fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)
f.__name__ = module.__class__.__name__
func = module.forward
return f
from .activation_function import *
from .arithmetic import *
from .embedding import *
from .linear import *
from .normalization import *
from .pooling import *
from .python_ops import *
from .torch_ops import *
from typing import Tuple
import torch
from ..registry import meta_profiler_function
# TODO: different activations have different FLOP counts; currently unused.
_multiplier = {
torch.nn.functional.relu: 1,
torch.nn.functional.prelu: 4,
torch.nn.functional.sigmoid: 4,
torch.nn.functional.tanh: 5,
torch.nn.functional.leaky_relu: 3,
torch.nn.functional.elu: 4,
torch.nn.functional.relu6: 2,
torch.nn.functional.gelu: 9,
torch.nn.functional.hardswish: 5,
torch.nn.functional.hardsigmoid: 4,
}
@meta_profiler_function.register(torch.nn.functional.leaky_relu)
@meta_profiler_function.register(torch.nn.functional.elu)
@meta_profiler_function.register(torch.nn.functional.gelu)
@meta_profiler_function.register(torch.nn.functional.relu6)
@meta_profiler_function.register(torch.nn.functional.prelu)
@meta_profiler_function.register(torch.nn.functional.relu)
@meta_profiler_function.register(torch.nn.functional.sigmoid)
@meta_profiler_function.register(torch.nn.functional.tanh)
@meta_profiler_function.register(torch.nn.functional.hardswish)
@meta_profiler_function.register(torch.nn.functional.hardsigmoid)
def torch_nn_func_non_linear_act(input: torch.Tensor, inplace: bool = False) -> Tuple[int, int]:
flops = input.numel()
macs = 0
return flops, macs
import operator
from functools import reduce
from typing import Any, Optional, Tuple, Union
import torch
from ..registry import meta_profiler_function
def _elementwise_flops_compute(input, other):
# copied from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L763
if not torch.is_tensor(input):
if torch.is_tensor(other):
return reduce(operator.mul, other.shape), 0
else:
return 1, 0
elif not torch.is_tensor(other):
return reduce(operator.mul, input.shape), 0
else:
dim_input = len(input.shape)
dim_other = len(other.shape)
max_dim = max(dim_input, dim_other)
final_shape = []
for i in range(max_dim):
in_i = input.shape[i] if i < dim_input else 1
ot_i = other.shape[i] if i < dim_other else 1
if in_i > ot_i:
final_shape.append(in_i)
else:
final_shape.append(ot_i)
flops = reduce(operator.mul, final_shape)
return flops, 0
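# A quick sanity check for the broadcast rule above (illustrative shapes): broadcasting
# an (8, 1) tensor against a (1, 4) tensor yields an (8, 4) output, so
#   _elementwise_flops_compute(torch.rand(8, 1), torch.rand(1, 4))   # -> (32, 0)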
@meta_profiler_function.register(torch.add)
@meta_profiler_function.register(torch.eq)
@meta_profiler_function.register(torch.sub)
@meta_profiler_function.register(torch.mul)
@meta_profiler_function.register(torch.floor_divide)
@meta_profiler_function.register('add') # for built-in op +
@meta_profiler_function.register('iadd') # for built-in op +=
@meta_profiler_function.register('eq') # for built-in op ==
@meta_profiler_function.register('sub') # for built-in op -
@meta_profiler_function.register('isub') # for built-in op -=
@meta_profiler_function.register('mul') # for built-in op *
@meta_profiler_function.register('imul') # for built-in op *=
@meta_profiler_function.register('floordiv') # for built-in op //
@meta_profiler_function.register('ifloordiv') # for built-in op //=
def torch_add_like_ops(input: Any, other: Any, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
return _elementwise_flops_compute(input, other)
@meta_profiler_function.register(torch.abs)
def torch_elementwise_op(input: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
flops = input.numel()
macs = 0
return flops, macs
@meta_profiler_function.register(torch.matmul)
@meta_profiler_function.register('matmul') # for built-in op @
@meta_profiler_function.register(torch.Tensor.matmul)
def torch_matmul(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
macs = reduce(operator.mul, input.shape) * other.shape[-1]
flops = 2 * macs
return flops, macs
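# For example (illustrative shapes): multiplying a (2, 3, 4) tensor by a (4, 5) tensor
# performs 2 * 3 * 4 * 5 = 120 MACs, i.e. 240 FLOPs, matching the formula above.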
@meta_profiler_function.register(torch.bmm)
def torch_bmm(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
macs = reduce(operator.mul, input.shape) * other.shape[-1]
flops = 2 * macs
return flops, macs
@meta_profiler_function.register(torch.var_mean)
def torch_var_mean(input: torch.Tensor,
dim: Union[int, Tuple[int, ...]],
unbiased: Optional[bool] = True,
keepdim: Optional[bool] = False,
*,
out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
assert out is None, 'saving to out is not supported yet'
flops = input.numel() * 3
macs = 0
return flops, macs
import torch
from typing import Optional
from ..registry import meta_profiler_function
@meta_profiler_function.register(torch.nn.functional.embedding)
def torch_nn_functional_embedding(
input: torch.Tensor,
weight: torch.Tensor,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False,
) -> torch.Tensor:
# F.embedding is a dictionary lookup, so technically it has 0 FLOPs. (https://discuss.pytorch.org/t/correct-way-to-calculate-flops-in-model/67198/6)
flops = 0
macs = 0
return flops, macs