Unverified Commit b42d3d28 authored by Super Daniel's avatar Super Daniel Committed by GitHub
Browse files

[fx] remove depreciated algorithms. (#2312) (#2313)

parent 55dcd305
from .ckpt_solver_chen import chen_greedy
from .linearize import linearize
from .ckpt_solver_rotor import solver_rotor
from .ckpt_solver_pofo import solver_pofo
from setuptools import setup, Extension
import os
this_dir = os.path.dirname(os.path.abspath(__file__))
ext_modules = [Extension(
'dynamic_programs_C_version',
sources=[os.path.join(this_dir, 'dynamic_programs.c')],
)]
setup(
name='rotor c extension',
version='0.1',
description='rotor c extension for faster dp computing',
ext_modules=ext_modules,
)
import math
from typing import List, Set, Tuple
import torch
from torch.fx import GraphModule, Node
from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp
__all__ = ['chen_greedy']
CKPT_OP = ['call_module', 'call_method', 'call_function', 'get_attr']
def _all_potential_ckpt_nodes(gm: GraphModule) -> List:
"""
In most existing frameworks of activation checkpoint, the forward graph is assumed to be linearized.
"""
def is_sink():
"""
If we can free all memories when executing a certain node, it is a sink.
"""
return not sum((v for k, v in deps.items()))
deps = {}
ckpt_nodes = []
for n in gm.graph.nodes:
for n_par in n._input_nodes:
deps[n_par] -= 1 # free memory and dependencies
# We can only put act_ckpt on these nodes
if n.op in CKPT_OP and is_sink():
ckpt_nodes.append(n)
deps[n] = len(n.users) # add dependencies for future executions
return ckpt_nodes
def chen_greedy(gm: GraphModule) -> GraphModule:
"""
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
Note that this algorithm targets at memory optimization only, using techniques in appendix A.
Usage:
model = resnet18()
input_sample = torch.rand(4, 3, 224, 224)
gm = symbolic_trace(model)
MetaInfoProp(gm).run(input_sample)
gm = chen_greedy(gm)
Args:
gm (GraphModule): The module to add checkpoints
"""
def grid_search(num_grids: int = 6) -> Set:
"""
Search ckpt strategy with b = 0, then run the allocation algorithm again with b = √xy.
Grid search over [√2/2 b, √2 b] for ckpt_opt over num_grids as in appendix A.
"""
_, b_approx = run_chen_greedy(0)
b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))
b_opt = math.inf
for b in range(b_min, b_max, (b_max - b_min) // num_grids):
ckpt_intv, b_approx = run_chen_greedy(b)
if b_approx < b_opt:
b_opt = b_approx
ckpt_opt = ckpt_intv
return ckpt_opt
def run_chen_greedy(b: int = 0) -> Tuple[Set, int]:
"""
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
"""
ckpt_nodes = _all_potential_ckpt_nodes(gm)
ckpt_intv = []
temp = 0
x = 0
y = 0
prev_idx = 2
for (idx, n) in enumerate(gm.graph.nodes):
n: Node
temp += calculate_fwd_in(n) + calculate_fwd_tmp(n)
y = max(y, temp)
if temp > b and n in ckpt_nodes:
x += calculate_fwd_in(n)
temp = 0
ckpt_intv.append((prev_idx, idx + 1))
prev_idx = idx + 1
return ckpt_intv, math.floor(math.sqrt(x * y))
gm.graph.lint() # make sure nodes are in topological order
ckpt = grid_search(num_grids=6)
node_list = list(gm.graph.nodes)
for i, seg in enumerate(ckpt):
for idx in range(*seg):
n = node_list[idx]
if n.op in CKPT_OP:
setattr(n, 'activation_checkpoint', i)
gm.recompile()
return gm
This diff is collapsed.
import math
import sys
from typing import List, Tuple
from torch.fx import Node
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.profiler import activation_size, calculate_fwd_out, calculate_fwd_tmp, parameter_size
from colossalai.logging import get_dist_logger
from .linearize import linearize
from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Sequence
# global vairable to indicate whether the solver is failed
SOLVER_FAILED = False
# this is the python compute table code from rotor
# https://gitlab.inria.fr/hiepacs/rotor
# paper link: https://hal.inria.fr/hal-02352969
def _compute_table(chain: Chain, mmax) -> Tuple:
"""Returns the optimal table: a tuple containing:
Opt[m][lmin][lmax] with lmin = 0...chain.length
and lmax = lmin...chain.length (lmax is not included) and m = 0...mmax
what[m][lmin][lmax] is (True,) if the optimal choice is a chain checkpoint
(False, j) if the optimal choice is a leaf checkpoint of length j
The computation uses dynamic programming"""
fw = chain.fweight + [0] ## forward time
bw = chain.bweight ## backward time, not used
cw = chain.cweight + [0] ## size of x (and of y)
cbw = chain.cbweight + [0] ## size of xbar
fwd_mem_tmp = chain.fwd_mem_tmp + [0]
bwd_mem_tmp = chain.bwd_mem_tmp + [0]
# Build table
opt = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
what = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
# Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation
# Initialize borders of the tables for lmax-lmin = 0
for m in range(mmax + 1):
for i in range(chain.length + 1):
#lmax-lmin = 0
limit = max(cw[i + 1] + cbw[i + 1] + fwd_mem_tmp[i], cw[i + 1] + cbw[i + 1] + bwd_mem_tmp[i])
if m >= limit: ## Equation (1)
opt[m][i][i] = fw[i] + bw[i]
else:
opt[m][i][i] = float("inf")
# Compute everything
for m in range(mmax + 1):
for d in range(1, chain.length + 1):
for i in range(chain.length + 1 - d):
# for idx in range(i+1, chain.length + 1):
idx = i + d
mmin = cw[idx + 1] + cw[i + 1] + fwd_mem_tmp[i]
if idx > i + 1:
mmin = max(mmin, cw[idx + 1] + max(cw[j] + cw[j + 1] + fwd_mem_tmp[j] for j in range(i + 1, idx)))
if m < mmin:
opt[m][i][idx] = float("inf")
else:
leaf_checkpoints = [(j, sum(fw[i:j]) + opt[m - cw[j]][j][idx] + opt[m][i][j - 1])
for j in range(i + 1, idx + 1)
if m >= cw[j]]
if leaf_checkpoints:
best_leaf = min(leaf_checkpoints, key=lambda t: t[1])
else:
best_leaf = None
if m >= cbw[i + 1]:
chain_checkpoint = opt[m][i][i] + opt[m - cbw[i + 1]][i + 1][idx]
else:
chain_checkpoint = float("inf")
if best_leaf and best_leaf[1] <= chain_checkpoint:
opt[m][i][idx] = best_leaf[1]
what[m][i][idx] = (False, best_leaf[0])
else:
opt[m][i][idx] = chain_checkpoint
what[m][i][idx] = (True,)
return (opt, what)
def _rec(chain: Chain, lmin, lmax, cmem, opt_table):
""" chain : the class describing the AC graph
lmin : index of the first forward to execute
lmax : upper bound index of the last forward to execute (not included)
cmem : number of available memory slots
Return the optimal sequence of makespan Opt_hete[cmem][lmin][lmax-lmin]"""
if cmem <= 0:
raise ValueError("Can not process a chain with negative memory {cmem}".format(cmem=cmem))
opt, what = opt_table
sequence = Sequence(Function("Persistent", lmax - lmin, cmem))
if opt[cmem][lmin][lmax] == float("inf"):
# using logger to annonce that the solver is failed
logger = get_dist_logger()
logger.info("Can not process this chain from index {lmin} to {lmax} with memory {cmem}".format(lmin=lmin,
lmax=lmax,
cmem=cmem))
# set global indicater SOLVER_FAILED to True
global SOLVER_FAILED
SOLVER_FAILED = True
return sequence
if lmin == lmax:
if lmin == chain.length:
sequence.insert(Loss())
else:
sequence.insert(ForwardEnable(lmin))
sequence.insert(Backward(lmin))
return sequence
if what[cmem][lmin][lmax][0]:
sequence.insert(ForwardEnable(lmin))
sequence.insert_sequence(_rec(chain, lmin + 1, lmax, cmem - chain.cbweight[lmin + 1], opt_table))
sequence.insert(Backward(lmin))
else:
j = what[cmem][lmin][lmax][1]
sequence.insert(ForwardCheck(lmin))
for k in range(lmin + 1, j):
sequence.insert(ForwardNograd(k))
sequence.insert_sequence(_rec(chain, j, lmax, cmem - chain.cweight[j], opt_table))
sequence.insert_sequence(_rec(chain, lmin, j - 1, cmem, opt_table))
return sequence
def _fwd_xbar(node: List[Node]) -> int:
"""Get the forward xbar of a node
Args:
node (List[Node]): List of torch.fx Node,
indicates a node in linearized graph
Returns:
int: xbar size, unit Byte
"""
xbar = 0
for n in node:
xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
return xbar
def _fwd_time(node: List[Node]) -> int:
"""Get the foward time of a node
Args:
node (List[Node]): List of torch.fx Node,
indicates a node in linearized graph
Returns:
int: foward time, extimated by flops count
"""
fwd_time = 0
for n in node:
# minimum flop count is needed
fwd_time += max(n.meta['fwd_flop'], 1)
return fwd_time
def _bwd_time(node: List[Node]) -> int:
"""Get the backward time of a node
Args:
node (List[Node]): List of torch.fx Node,
indicates a node in linearized graph
Returns:
int: backward time, extimated by flops count
"""
bwd_time = 0
for n in node:
# minimum flop count is needed
bwd_time += max(n.meta['bwd_flop'], 1)
return bwd_time
def _get_fwd_mem_tmp(node: List[Node]) -> int:
"""Get the forward temp memory of a node
This could be done by subtracting the saved activation from all output of a node
Args:
node (List[Node]): List of torch.fx Node,
indicates a node in linearized graph
Returns:
int: forward temp memory, unit Byte
"""
n = node[-1]
return activation_size(n.meta['fwd_out']) - calculate_fwd_out(n)
def _get_bwd_mem_tmp(node: List[Node]) -> int:
"""Get the backward temp memory of a node
Args:
node (List[Node]): List of torch.fx Node,
indicates a node in linearized graph
Returns:
int: backward temp memory, unit Byte
"""
def _get_deps_size():
deps_size = 0
for k, v in deps.items():
k: Node
if v > 0:
deps_size += k.meta['bwd_mem_out']
if v == float('-inf'):
deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k)
return deps_size
bwd_mem_tmp = 0
deps = {}
for n in reversed(node):
deps[n] = len(n.all_input_nodes)
bwd_mem_tmp = max(bwd_mem_tmp, _get_deps_size() + n.meta['bwd_mem_tmp'])
for child in n.users:
if child in deps:
deps[child] -= 1
if deps[child] <= 0:
deps[child] = float('-inf') # free
return bwd_mem_tmp
def _construct_chain(node_list: List[List[Node]], input) -> Chain:
fwd_time = []
bwd_time = []
xbar_sizes = [activation_size(input)]
x_sizes = [activation_size(input)]
tmp_fwd = []
tmp_bwd = []
for idx, node in enumerate(node_list):
fwd_time.append(_fwd_time(node))
bwd_time.append(_bwd_time(node))
x_sizes.append(calculate_fwd_out(node[-1]))
xbar_sizes.append(max(x_sizes[-1], _fwd_xbar(node)))
tmp_fwd.append(_get_fwd_mem_tmp(node))
tmp_bwd.append(_get_bwd_mem_tmp(node))
bwd_time.append(0)
# currently we view loss backward temp as zero
tmp_bwd.append(0)
return Chain(fwd_time, bwd_time, x_sizes, xbar_sizes, tmp_fwd, tmp_bwd)
def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
op_list = sequence.list_operations()
loss_op = next(op for op in op_list if isinstance(op, Loss))
fwd_list = op_list[:op_list.index(loss_op)]
bwd_list = op_list[op_list.index(loss_op) + 1:]
ckpt_idx = 0
in_ckpt = False
ckpt_region = []
# forward annotation
for idx, op in enumerate(fwd_list, 0):
if in_ckpt:
if isinstance(op, ForwardNograd):
ckpt_region.append(idx)
elif isinstance(op, ForwardEnable):
in_ckpt = False
for node_idx in ckpt_region:
for n in node_list[node_idx]:
setattr(n, "activation_checkpoint", [ckpt_idx])
ckpt_idx += 1
ckpt_region = []
elif isinstance(op, ForwardCheck):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
setattr(n, "activation_checkpoint", [ckpt_idx])
ckpt_idx += 1
ckpt_region = [idx]
else:
if isinstance(op, ForwardCheck):
in_ckpt = True
ckpt_region.append(idx)
# annotate the backward if there is any nested activation checkpoint
in_recompute = False
for op in bwd_list:
if in_recompute:
if isinstance(op, ForwardNograd):
ckpt_region.append(op.index)
elif isinstance(op, ForwardEnable):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.activation_checkpoint.append(ckpt_idx)
ckpt_idx += 1
ckpt_region = []
elif isinstance(op, ForwardCheck):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.activation_checkpoint.append(ckpt_idx)
ckpt_idx += 1
ckpt_region = [op.index]
elif isinstance(op, Backward):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.activation_checkpoint.append(ckpt_idx)
in_recompute = False
else:
if not isinstance(op, Backward):
in_recompute = True
ckpt_idx = 0
ckpt_region = []
if isinstance(op, ForwardCheck):
ckpt_region.append(op.index)
# postprocess, make sure every activation checkpoint label in the
# same activation checkpoint region (level = 0) has the same length
op_list = []
for node in node_list:
op_list += node
ckpt_regions = _find_nested_ckpt_regions(op_list)
for (start_idx, end_idx) in ckpt_regions:
nested_length = max(len(op_list[idx].activation_checkpoint) for idx in range(start_idx, end_idx + 1))
for idx in range(start_idx, end_idx + 1):
op_list[idx].activation_checkpoint += [None] * (nested_length - len(op_list[idx].activation_checkpoint))
def solver_rotor(gm: ColoGraphModule,
data,
mem_limit: int,
mem_slots: int = 500,
cnode: List[str] = None,
eps: float = 0.0,
force_python: bool = False) -> ColoGraphModule:
"""solver that automatically find activation checkpoint in rotor's manner
Args:
gm (ColoGraphModule): ColoGraphModule generated by tracing model and MetaInfoProp.
data (torch.Tensor): input data.
mem_limit (int): memory budget in Byte.
mem_slots (int, optional): number of slots for discretizing memory budget. Defaults to 500.
cnode (List[Node], optional): common node list for linearize. Defaults to None.
eps (float): epsilon for memory decay. Defaults to 0.0
force_python (bool): force to use python version of dynamic programs
Returns:
ColoGraphModule: annotated ColoGraphModuled with __sequence__ attribute
"""
# try to import C version solver if force_python is not set
logger = get_dist_logger()
if not force_python:
try:
from .dynamic_programs_C_version import persistent_compute_table
CVERSION = True
# build module if module not found
except ModuleNotFoundError:
import os
import subprocess
logger.info("dynamic_programs_C_version hasn't been built! Building library...", ranks=[0])
this_dir = os.path.dirname(os.path.abspath(__file__))
result = subprocess.Popen(
[
f"{sys.executable}", f"{os.path.join(this_dir, 'build_c_ext.py')}", "build_ext",
f"--build-lib={this_dir}"
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
if result.wait() == 0:
logger.info("dynamic_programs_C_version has been built!", ranks=[0])
from .dynamic_programs_C_version import persistent_compute_table
CVERSION = True
else:
logger.info("dynamic_programs_C_version built failed! Using python version!", ranks=[0])
CVERSION = False
else:
CVERSION = False
# check if metainfoprop is done
if any(len(node.meta) == 0 for node in gm.graph.nodes):
raise RuntimeError(
"Nodes meta information hasn't been prepared! Please run MetaInfoProp before calling solver!")
# linearize the graph
node_list = linearize(gm, cnode)
# construct chain
mem_unit = mem_limit * (1.0 - eps) // mem_slots
chain: Chain = _construct_chain(node_list, data)
chain._discretize(mem_unit)
# use C version if possible
if CVERSION and not force_python:
logger.info("Using C version rotor solver!", ranks=[0])
opt_table = persistent_compute_table(chain, mem_slots)
else:
opt_table = _compute_table(chain, mem_slots)
logger.info("Using python version rotor solver!", ranks=[0])
# found sequence
sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table)
# if solver failed, we don't need to annotate the graph
if not SOLVER_FAILED:
_annotate_from_sequence(sequence, node_list)
# set __sequence__ attribute to GraphModule
if SOLVER_FAILED:
setattr(gm, "__sequence__", None)
else:
setattr(gm, "__sequence__", sequence)
# set __opttable__ attribute to GraphModule
setattr(gm, "__opttable__", opt_table[0])
gm.recompile()
return gm
#define PY_SSIZE_T_CLEAN
#include <Python.h>
long* PySequenceToLongArray(PyObject* pylist) {
if (!(pylist && PySequence_Check(pylist))) return NULL;
Py_ssize_t len = PySequence_Size(pylist);
long* result = (long*)calloc(len + 1, sizeof(long));
for (Py_ssize_t i = 0; i < len; ++i) {
PyObject* item = PySequence_GetItem(pylist, i);
result[i] = PyLong_AsLong(item);
Py_DECREF(item);
}
result[len] = 0;
return result;
}
double* PySequenceToDoubleArray(PyObject* pylist) {
if (!(pylist && PySequence_Check(pylist))) return NULL;
Py_ssize_t len = PySequence_Size(pylist);
double* result = (double*)calloc(len + 1, sizeof(double));
for (Py_ssize_t i = 0; i < len; ++i) {
PyObject* item = PySequence_GetItem(pylist, i);
result[i] = PyFloat_AsDouble(item);
Py_DECREF(item);
}
result[len] = 0;
return result;
}
long* getLongArray(PyObject* container, const char* attributeName) {
PyObject* sequence = PyObject_GetAttrString(container, attributeName);
long* result = PySequenceToLongArray(sequence);
Py_DECREF(sequence);
return result;
}
double* getDoubleArray(PyObject* container, const char* attributeName) {
PyObject* sequence = PyObject_GetAttrString(container, attributeName);
double* result = PySequenceToDoubleArray(sequence);
Py_DECREF(sequence);
return result;
}
static PyObject* persistent_compute_table(PyObject* self, PyObject* args) {
PyObject* chain_param;
int mmax;
if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax)) return NULL;
double* fw = getDoubleArray(chain_param, "fweight");
if (!fw) return NULL;
double* bw = getDoubleArray(chain_param, "bweight");
if (!bw) return NULL;
long* cw = getLongArray(chain_param, "cweight");
if (!cw) return NULL;
long* cbw = getLongArray(chain_param, "cbweight");
if (!cbw) return NULL;
long* fwd_tmp = getLongArray(chain_param, "fwd_mem_tmp");
if (!cbw) return NULL;
long* bwd_tmp = getLongArray(chain_param, "bwd_mem_tmp");
if (!cbw) return NULL;
PyObject* chain_length_param = PyObject_GetAttrString(chain_param, "length");
if (!chain_length_param) return NULL;
long chain_length = PyLong_AsLong(chain_length_param);
Py_DECREF(chain_length_param);
// TODO: Can be optimized by only allocating memory for l >= i
// TODO: float / int instead of double / long ?
#define OPT(m, i, l) \
opt[(m) * (chain_length + 1) * (chain_length + 1) + \
(i) * (chain_length + 1) + (l)]
double* opt = (double*)calloc(
(mmax + 1) * (chain_length + 1) * (chain_length + 1), sizeof(double));
#define WHAT(m, i, l) \
what[(m) * (chain_length + 1) * (chain_length + 1) + \
(i) * (chain_length + 1) + (l)]
long* what = (long*)calloc(
(mmax + 1) * (chain_length + 1) * (chain_length + 1), sizeof(long));
for (long m = 0; m <= mmax; ++m)
for (long i = 0; i <= chain_length; ++i)
// TODO: Can be optimized to remove the IF by reordering loops
if ((m >= cw[i + 1] + cbw[i + 1] + bwd_tmp[i]) &&
(m >= cw[i + 1] + cbw[i + 1] + fwd_tmp[i]))
OPT(m, i, i) = fw[i] + bw[i];
else
OPT(m, i, i) = INFINITY;
for (long m = 0; m <= mmax; ++m)
for (long d = 1; d <= chain_length; ++d) {
for (long i = 0; i <= chain_length - d; ++i) {
long idx = i + d;
long mmin = cw[idx + 1] + cw[i + 1] + fwd_tmp[i];
if (idx > i + 1) {
long maxCostFWD = 0;
for (long j = i + 1; j < idx; j++) {
maxCostFWD = fmaxl(maxCostFWD, cw[j] + cw[j + 1] + fwd_tmp[j]);
}
mmin = fmaxl(mmin, cw[idx + 1] + maxCostFWD);
}
if ((m >= mmin)) {
long bestLeaf = -1;
double sumFw = 0;
double bestLeafCost = INFINITY;
/// sumFw + OPT(m-cw[i+1], i+1, l) + OPT(m, i, i); // Value for j =
/// i+1
for (long j = i + 1; j <= idx; ++j) {
sumFw += fw[j - 1];
if (m >= cw[j]) {
double cost = sumFw + OPT(m - cw[j], j, idx) + OPT(m, i, j - 1);
if (cost < bestLeafCost) {
bestLeafCost = cost;
bestLeaf = j;
}
}
}
double chainCost = INFINITY;
if (m >= cbw[i + 1])
chainCost = OPT(m, i, i) + OPT(m - cbw[i + 1], i + 1, idx);
if (bestLeafCost <= chainCost) {
OPT(m, i, idx) = bestLeafCost;
WHAT(m, i, idx) = bestLeaf;
} else {
OPT(m, i, idx) = chainCost;
WHAT(m, i, idx) = -1;
}
} else
OPT(m, i, idx) = INFINITY;
}
}
free(fw);
free(bw);
free(cw);
free(cbw);
free(fwd_tmp);
free(bwd_tmp);
PyObject* res_opt = PyList_New(mmax + 1);
PyObject* res_what = PyList_New(mmax + 1);
// Convert the result into Python world
for (long m = 0; m <= mmax; ++m) {
PyObject* res_opt_m = PyList_New(chain_length + 1);
PyList_SET_ITEM(res_opt, m, res_opt_m);
PyObject* res_what_m = PyList_New(chain_length + 1);
PyList_SET_ITEM(res_what, m, res_what_m);
for (long i = 0; i <= chain_length; ++i) {
PyObject* res_opt_m_i = PyDict_New();
PyList_SET_ITEM(res_opt_m, i, res_opt_m_i);
PyObject* res_what_m_i = PyDict_New();
PyList_SET_ITEM(res_what_m, i, res_what_m_i);
for (long l = i; l <= chain_length; ++l) {
PyObject* res_l = PyLong_FromLong(l);
PyObject* res_opt_m_i_l = PyFloat_FromDouble(OPT(m, i, l));
PyDict_SetItem(res_opt_m_i, res_l, res_opt_m_i_l);
Py_DECREF(res_opt_m_i_l);
PyObject* res_what_m_i_l;
long what_m_i_l = WHAT(m, i, l);
if (what_m_i_l < 0)
res_what_m_i_l = Py_BuildValue("(O)", Py_True);
else
res_what_m_i_l = Py_BuildValue("(Ol)", Py_False, what_m_i_l);
PyDict_SetItem(res_what_m_i, res_l, res_what_m_i_l);
Py_DECREF(res_what_m_i_l);
Py_DECREF(res_l);
}
}
}
free(opt);
free(what);
PyObject* result = PyTuple_Pack(2, res_opt, res_what);
Py_DECREF(res_opt);
Py_DECREF(res_what);
return result;
}
// long i = L - s, j = t - s, k = l - t
inline long floating_index_in_array(long m_factor, long m, long i, long j,
long k) {
return m * m_factor + (i * (i + 1) * (2 * i + 4)) / 12 + (i + 1) * j -
(j * (j - 1)) / 2 + k;
}
typedef struct {
long sp;
long r;
long tp;
} index_t;
static PyObject* floating_compute_table(PyObject* self, PyObject* args) {
PyObject* chain_param;
int mmax;
if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax)) return NULL;
double* fw = getDoubleArray(chain_param, "fweigth");
if (!fw) return NULL;
double* bw = getDoubleArray(chain_param, "bweigth");
if (!bw) return NULL;
long* cw = getLongArray(chain_param, "cweigth");
if (!cw) return NULL;
long* cbw = getLongArray(chain_param, "cbweigth");
if (!cbw) return NULL;
long* fwd_tmp = getLongArray(chain_param, "fwd_tmp");
if (!fwd_tmp) return NULL;
long* bwd_tmp = getLongArray(chain_param, "bwd_tmp");
if (!bwd_tmp) return NULL;
PyObject* chain_length_param = PyObject_GetAttrString(chain_param, "length");
if (!chain_length_param) return NULL;
long chain_length = PyLong_AsLong(chain_length_param);
Py_DECREF(chain_length_param);
const long m_factor =
(chain_length + 1) * (chain_length + 2) * (2 * chain_length + 6) / 12;
// Defined for 0 <= s <= t <= l <= chain_length, for all m
#undef OPT
#define OPT(m, s, t, l) \
opt[floating_index_in_array(m_factor, (m), chain_length - (s), (t) - (s), \
(l) - (t))]
double* opt = (double*)calloc((mmax + 1) * m_factor, sizeof(double));
#undef WHAT
#define WHAT(m, s, t, l) \
what[floating_index_in_array(m_factor, (m), chain_length - (s), (t) - (s), \
(l) - (t))]
index_t* what = (index_t*)calloc((mmax + 1) * m_factor, sizeof(index_t));
double* partialSumsFW = (double*)calloc(chain_length + 1, sizeof(double));
double total = 0;
for (long i = 0; i < chain_length; ++i) {
partialSumsFW[i] = total;
total += fw[i];
}
partialSumsFW[chain_length] = total;
for (long m = 0; m <= mmax; ++m)
for (long i = 0; i <= chain_length; ++i) {
// TODO: Can be optimized to remove the IF by reordering loops
if ((m >= cw[i] + cw[i + 1] + cbw[i + 1] + bwd_tmp[i]) &&
(m >= cw[i + 1] + cbw[i + 1] + fwd_tmp[i]))
OPT(m, i, i, i) = fw[i] + bw[i];
else
OPT(m, i, i, i) = INFINITY;
}
for (long m = 0; m <= mmax; ++m)
for (long d = 1; d <= chain_length; ++d) { // d = l - s
for (long s = 0; s <= chain_length - d; ++s) {
long l = s + d;
long memNullFirst = cw[l + 1] + cw[s + 1] + fwd_tmp[s];
long memNullSecond = 0;
for (long j = s + 1; j < l; ++j) {
long val = cw[j] + cw[j + 1] + fwd_tmp[j];
if (val > memNullSecond) memNullSecond = val;
}
for (long t = s; t <= l; ++t) {
double chainCost = INFINITY;
if ((s == t) && (m >= cw[l + 1] + cbw[s + 1] + fwd_tmp[s]) &&
(m >= cw[s] + cw[s + 1] + cbw[s + 1] + bwd_tmp[s])) {
chainCost = OPT(m, s, s, s) + OPT(m - cbw[s + 1], s + 1, s + 1, l);
}
double bestLeafCost = INFINITY;
index_t bestLeaf = {.sp = -1, .r = -1, .tp = -1};
if (m >= memNullFirst && m >= cw[l + 1] + memNullSecond) {
for (long r = s; r <= t; ++r)
if (cw[s] <= cw[r])
for (long tp = t + 1; tp <= l; ++tp)
for (long sp = r + 1; sp <= tp; ++sp) {
long mp = m - cw[r] + cw[s];
assert(mp >= 0);
if (mp >= cw[sp]) {
double value = partialSumsFW[sp] - partialSumsFW[s] +
OPT(mp - cw[sp], sp, tp, l) +
OPT(mp, r, t, tp - 1);
if (value < bestLeafCost) {
bestLeafCost = value;
bestLeaf.sp = sp;
bestLeaf.r = r;
bestLeaf.tp = tp;
}
}
}
}
if (bestLeaf.sp >= 0 && bestLeafCost <= chainCost) {
OPT(m, s, t, l) = bestLeafCost;
WHAT(m, s, t, l).sp = bestLeaf.sp;
WHAT(m, s, t, l).r = bestLeaf.r;
WHAT(m, s, t, l).tp = bestLeaf.tp;
} else {
OPT(m, s, t, l) = chainCost;
WHAT(m, s, t, l).sp = -1;
}
}
}
}
free(fw);
free(bw);
free(cw);
free(cbw);
free(fwd_tmp);
free(bwd_tmp);
PyObject* res_opt = PyList_New(mmax + 1);
PyObject* res_what = PyList_New(mmax + 1);
// Convert the result into Python world
PyObject* true_tuple = Py_BuildValue("(O)", Py_True);
for (long m = 0; m <= mmax; ++m) {
PyObject* res_opt_m = PyDict_New();
PyList_SET_ITEM(res_opt, m, res_opt_m);
PyObject* res_what_m = PyDict_New();
PyList_SET_ITEM(res_what, m, res_what_m);
for (long s = 0; s <= chain_length; ++s)
for (long t = s; t <= chain_length; ++t)
for (long l = t; l <= chain_length; ++l) {
PyObject* key = Py_BuildValue("(lll)", s, t, l);
PyObject* value_opt = PyFloat_FromDouble(OPT(m, s, t, l));
PyDict_SetItem(res_opt_m, key, value_opt);
PyObject* value_what = true_tuple;
index_t* idx_what = &WHAT(m, s, t, l);
if (idx_what->sp >= 0)
value_what = Py_BuildValue("(O(lll))", Py_False, idx_what->sp,
idx_what->r, idx_what->tp);
PyDict_SetItem(res_what_m, key, value_what);
if (value_what != true_tuple) Py_DECREF(value_what);
Py_DECREF(key);
Py_DECREF(value_opt);
}
}
Py_DECREF(true_tuple);
free(opt);
free(what);
PyObject* result = PyTuple_Pack(2, res_opt, res_what);
Py_DECREF(res_opt);
Py_DECREF(res_what);
return result;
}
static PyObject* griewank_heterogeneous_compute_table(PyObject* self,
PyObject* args) {
PyObject* chain_param;
int mmax;
if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax)) return NULL;
double* fw = getDoubleArray(chain_param, "fweigth");
if (!fw) return NULL;
double* bw = getDoubleArray(chain_param, "bweigth");
if (!bw) return NULL;
long* cw = getLongArray(chain_param, "cweigth");
if (!cw) return NULL;
long* cbw = getLongArray(chain_param, "cbweigth");
if (!cbw) return NULL;
PyObject* chain_length_param = PyObject_GetAttrString(chain_param, "length");
if (!chain_length_param) return NULL;
long chain_length = PyLong_AsLong(chain_length_param);
Py_DECREF(chain_length_param);
// TODO: Can be optimized by only allocating memory for l >= i
// TODO: float / int instead of double / long ?
#undef OPT
#define OPT(m, i, l) \
opt[(m) * (chain_length + 1) * (chain_length + 1) + \
(i) * (chain_length + 1) + (l)]
double* opt = (double*)calloc(
(mmax + 1) * (chain_length + 1) * (chain_length + 1), sizeof(double));
// Compute partial sums
double* sumfw = (double*)calloc(chain_length, sizeof(double));
double* sumbw = (double*)calloc(chain_length + 1, sizeof(double));
double* sumsumfw = (double*)calloc(chain_length, sizeof(double));
double total = 0;
for (long i = 0; i < chain_length; ++i) {
total += fw[i];
sumfw[i] = total;
}
total = 0;
for (long i = 0; i < chain_length + 1; ++i) {
total += bw[i];
sumbw[i] = total;
}
total = 0;
for (long i = 0; i < chain_length; ++i) {
total += sumfw[i];
sumsumfw[i] = total;
}
for (long m = 0; m <= mmax; ++m)
for (long i = 0; i <= chain_length; ++i) {
// TODO: Can be optimized to remove the IF by reordering loops
if ((m >= cbw[i]) && (m >= cw[i] + cbw[i + 1]))
OPT(m, i, i) = bw[i];
else
OPT(m, i, i) = INFINITY;
if (i < chain_length) {
long maxC = fmaxl(cw[i], cw[i + 1]);
long maxCB = fmaxl(cbw[i + 1], cbw[i + 2] + maxC);
if ((m >= cbw[i]) && (m >= cw[i] + maxCB))
OPT(m, i, i + 1) = fw[i] + bw[i] + bw[i + 1];
else
OPT(m, i, i + 1) = INFINITY;
}
}
for (long m = 0; m <= mmax; ++m)
for (long i = 0; i + 2 <= chain_length; ++i) {
long mminCst = fmaxl(cbw[i], cbw[i + 1] + cw[i]);
long maxCW_il = fmax(fmax(cw[i], cw[i + 1]), cw[i + 2]);
long maxCostFWD = cw[i] + cbw[i + 2] + maxCW_il;
for (long l = i + 2; l <= chain_length; ++l) {
maxCW_il = fmax(maxCW_il, cw[l + 1]);
maxCostFWD = fmaxl(maxCostFWD, cw[i] + cw[l + 1] + maxCW_il);
long mmin = fmaxl(mminCst, maxCostFWD);
if ((m >= mmin)) {
double noCheckpointCost = sumbw[l] - (i > 0 ? sumbw[i - 1] : 0);
noCheckpointCost +=
sumsumfw[l - 1] -
(i > 0 ? sumsumfw[i - 1] + (l - i) * sumfw[i - 1] : 0);
double valueCost = INFINITY;
if (m >= cw[i]) {
double sumFwds = 0;
for (long j = i + 1; j < l; ++j) {
sumFwds += fw[j - 1];
valueCost = fmin(
valueCost, sumFwds + OPT(m - cw[i], j, l) + OPT(m, i, j - 1));
}
}
OPT(m, i, l) = fmin(noCheckpointCost, valueCost);
} else
OPT(m, i, l) = INFINITY;
}
}
free(sumfw);
free(sumbw);
free(sumsumfw);
free(fw);
free(bw);
free(cw);
free(cbw);
PyObject* res_opt = PyList_New(mmax + 1);
// Convert the result into Python world
for (long m = 0; m <= mmax; ++m) {
PyObject* res_opt_m = PyList_New(chain_length + 1);
PyList_SET_ITEM(res_opt, m, res_opt_m);
for (long i = 0; i <= chain_length; ++i) {
PyObject* res_opt_m_i = PyDict_New();
PyList_SET_ITEM(res_opt_m, i, res_opt_m_i);
for (long l = i; l <= chain_length; ++l) {
PyObject* res_l = PyLong_FromLong(l - i);
PyObject* res_opt_m_i_l = PyFloat_FromDouble(OPT(m, i, l));
PyDict_SetItem(res_opt_m_i, res_l, res_opt_m_i_l);
Py_DECREF(res_opt_m_i_l);
Py_DECREF(res_l);
}
}
}
free(opt);
return res_opt;
}
static PyMethodDef dynamic_programs_methods[] = {
{"persistent_compute_table", persistent_compute_table, METH_VARARGS,
"Compute the optimal table with the persistent algorithm."},
{"floating_compute_table", floating_compute_table, METH_VARARGS,
"Compute the optimal table with the floating algorithm."},
{"griewank_heterogeneous_compute_table",
griewank_heterogeneous_compute_table, METH_VARARGS,
"Compute the optimal table for the Griewank Heterogeneous Model."},
{NULL, NULL, 0, NULL} /* Sentinel */
};
static struct PyModuleDef dynamic_programs_module = {
PyModuleDef_HEAD_INIT, "dynamic_programs_C_version", /* name of module */
NULL, /* module documentation, may be NULL */
-1, /* size of per-interpreter state of the module,
or -1 if the module keeps state in global variables. */
dynamic_programs_methods};
PyMODINIT_FUNC PyInit_dynamic_programs_C_version(void) {
return PyModule_Create(&dynamic_programs_module);
}
from typing import List, Any
from torch.fx import GraphModule, Node
from colossalai.fx.profiler import is_inplace
# Common nodes are type of nodes that could be seen as attributes and remain
# unchanged throughout the whole model, it will be used several times by
# different blocks of model, so that it is hard for us to linearize the graph
# when we encounter those kinds of nodes. We let users to annotate some of the
# input as common node, such as attention mask, and the followings are some of
# the ops that could actually be seen as common nodes. With our common node prop,
# we could find some of the "real" common nodes (e.g. the real attention mask
# used in BERT and GPT), the rule is simple, for node who's parents are all common
# nodes or it's op belongs to the following operations, we view this node as a
# newly born common node.
# List of target name that could be seen as common node
COPS = ["getattr", "getitem", "size"]
def _is_cop(target: Any) -> bool:
"""Check if an op could be seen as common node
Args:
target (Any): node target
Returns:
bool
"""
if isinstance(target, str):
return target in COPS
else:
return target.__name__ in COPS
def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
"""Linearizing the graph
Args:
gm (GraphModule): GraphModule derived by tracing
cnode (List[str], optional): common node List, should be the subset of input. Default to None.
Returns:
List[List[Node]]: List of list, each inside list of Node presents
the actual 'node' in linearized manner.
Remarks:
We merge the inplace ops into the previous node.
"""
def _is_sink() -> bool:
"""Check if we can free all dependencies
Returns:
bool
"""
return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users))
# make sure that item in cnode is valid
if cnode:
for name in cnode:
try:
assert next(node for node in gm.graph.nodes if node.name == name).op == "placeholder", \
f"common node {name} is not an input of the model"
except StopIteration:
raise ValueError(f"common node name {name} not in graph")
else:
cnode = []
deps = {}
linearized_nodes = []
region = []
for n in gm.graph.nodes:
if n.op != "placeholder" and n.op != "output":
for n_par in n._input_nodes:
if n_par.op != "placeholder" and n_par.name not in cnode:
deps[n_par] -= 1
region.append(n)
# if the node could free all dependencies in graph
# we could begin a new node
if _is_sink():
linearized_nodes.append(region)
region = []
# propagate common node attr if possible
if len(n._input_nodes) == len([node for node in n._input_nodes if node.name in cnode]) or _is_cop(n.target):
cnode.append(n.name)
else:
deps[n] = len([user for user in n.users if user.op != "output"])
return linearized_nodes
import math
def _discretize(mem_unit, values):
return [math.ceil(value / mem_unit) for value in values]
class Chain:
def __init__(self, fw, bw, cw, cbw, ftmp, btmp, check=True):
self.fweight = fw
self.bweight = bw
self.cweight = cw
self.cbweight = cbw
self.fwd_mem_tmp = ftmp
self.bwd_mem_tmp = btmp
self.length = len(fw)
if check and not self.check_lengths():
raise AttributeError("In Chain, input lists do not have consistent lengths")
def check_lengths(self):
return ((len(self.fweight) == self.length) and (len(self.bweight) == self.length + 1)
and (len(self.cweight) == self.length + 1) and (len(self.fwd_mem_tmp) == self.length)
and (len(self.bwd_mem_tmp) == self.length + 1) and (len(self.cbweight) == self.length + 1))
def __repr__(self):
chain_list = []
for i in range(self.length):
chain_list.append((self.fweight[i], self.bweight[i], self.cweight[i], self.cbweight[i], self.fwd_mem_tmp[i],
self.bwd_mem_tmp[i]))
i = self.length
chain_list.append((None, self.bweight[i], self.cweight[i], self.cbweight[i], None, self.bwd_mem_tmp[i]))
return chain_list.__repr__()
def _discretize(self, mem_unit):
self.cweight = _discretize(mem_unit, self.cweight)
self.cbweight = _discretize(mem_unit, self.cbweight)
self.fwd_mem_tmp = _discretize(mem_unit, self.fwd_mem_tmp)
self.bwd_mem_tmp = _discretize(mem_unit, self.bwd_mem_tmp)
class Operation:
def shift(self, value):
if type(self.index) is tuple:
self.index = tuple(x + value for x in self.index)
else:
self.index += value
class Offload(Operation):
def __init__(self, index, has_bar=False) -> None:
super().__init__()
self.index = index
self.name = "Off"
self.has_bar = has_bar
if self.has_bar:
self.name += "wBar"
def __repr__(self):
return f"{self.name}_{self.index}"
class Prefetch(Operation):
def __init__(self, index, has_bar=False) -> None:
super().__init__()
self.index = index
self.name = "Pre"
self.has_bar = has_bar
if self.has_bar:
self.name += "wBar"
def __repr__(self):
return f"{self.name}_{self.index}"
class Forward(Operation):
def __init__(self, index):
self.index = index
self.name = "F"
def __repr__(self):
return "{n}_{i}".format(n=self.name, i=self.index)
def cost(self, chain: Chain):
if chain is not None:
return chain.fweight[self.index]
else:
return 1
class ForwardEnable(Forward):
def __init__(self, index):
super().__init__(index)
self.name = "Fe"
class ForwardNograd(Forward):
def __init__(self, index):
super().__init__(index)
self.name = "Fn"
class ForwardCheck(Forward):
def __init__(self, index):
super().__init__(index)
self.name = "CF"
class Forwards(Operation):
def __init__(self, start, end):
self.index = (start, end)
def __repr__(self):
return "F_{i}->{j}".format(i=self.index[0], j=self.index[1])
def cost(self, chain: Chain):
if chain is not None:
return sum(chain.fweight[self.index[0]:self.index[1] + 1])
else:
return (self.index[1] - self.index[0] + 1)
def isForward(op):
return type(op) is Forward or type(op) is Forwards
class Backward(Operation):
def __init__(self, index):
self.index = index
def __repr__(self):
return "B_{i}".format(i=self.index)
def cost(self, chain: Chain):
if chain is not None:
return chain.bweight[self.index]
else:
return 1
class Loss(Operation):
def __init__(self):
pass
def __repr__(self):
return "L"
def cost(self, chain):
return 0
class MemoryAccess(Operation):
def __init__(self, index):
self.index = index
def __repr__(self):
return "{n}_{i}".format(n=self.name, i=self.index)
def cost(self, chain: Chain):
return 0
class WriteMemory(MemoryAccess):
def __init__(self, index):
super().__init__(index)
self.name = "WM"
class ReadMemory(MemoryAccess):
def __init__(self, index):
super().__init__(index)
self.name = "RM"
class DiscardMemory(MemoryAccess):
def __init__(self, index):
super().__init__(index)
self.name = "DM"
class Function:
def __init__(self, name, *args):
self.name = name
self.args = args
self.str_args = ','.join(str(v) for v in self.args)
def __repr__(self):
return "{n}({args})".format(n=self.name, args=self.str_args)
class Sequence:
def __init__(self, function):
self.sequence = [] #List of Operation and Sequence
self.function = function #Description the function (name and parameters)
def __repr__(self):
return repr(self.list_operations())
def list_operations(self):
op_list = []
for x in self.sequence:
if isinstance(x, Operation):
op_list.append(x)
else:
assert isinstance(x, Sequence)
op_list += x.list_operations()
return op_list
def insert(self, operation):
self.sequence.append(operation)
def remove(self, operation_index):
del self.sequence[operation_index]
def insert_sequence(self, sequence):
self.sequence.append(sequence)
def shift(self, value):
for x in self.sequence:
x.shift(value)
return self
def remove_useless_write(self):
if self.sequence:
if isinstance(self.sequence[0], WriteMemory):
self.remove(0)
return self
def get_makespan(self, chain):
return sum(op.cost(chain) for op in self.list_operations())
def without_suffix(self):
ops = self.list_operations()
end_of_first_phase = [i for i in range(len(ops)) if type(ops[i]) is Loss][0]
try:
last_idx = max(i for i in range(end_of_first_phase) if not type(ops[i]) is ForwardEnable)
except ValueError:
last_idx = -1
if last_idx == end_of_first_phase - 1:
return (self, None)
chain_length = ops[end_of_first_phase -
1].index ## Some assumption here about the sequence (finishes with Forward_L
start_of_fwd_enable_chain = ops[last_idx + 1].index ## And starts with B_L), but should be fine in practice
result = Sequence(Function("Strip", self.function.name, *self.function.args, start_of_fwd_enable_chain))
for i in range(last_idx + 1):
result.insert(ops[i])
result.insert(Loss())
for i in range(chain_length, start_of_fwd_enable_chain - 1, -1):
position = end_of_first_phase + 1 + (chain_length - i)
assert type(ops[position]) is Backward
assert ops[position].index == i
for i in range(end_of_first_phase + 1 + 1 + chain_length - start_of_fwd_enable_chain, len(ops)):
result.insert(ops[i])
return (result, start_of_fwd_enable_chain)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment