Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
e532679c
Commit
e532679c
authored
Jan 10, 2023
by
oahzxl
Browse files
Merge branch 'main' of
https://github.com/oahzxl/ColossalAI
into chunk
parents
c1492e50
7d5640b9
Changes
762
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1992 additions
and
8 deletions
+1992
-8
colossalai/amp/torch_amp/_grad_scaler.py
colossalai/amp/torch_amp/_grad_scaler.py
+7
-5
colossalai/amp/torch_amp/torch_amp.py
colossalai/amp/torch_amp/torch_amp.py
+3
-3
colossalai/auto_parallel/checkpoint/__init__.py
colossalai/auto_parallel/checkpoint/__init__.py
+3
-0
colossalai/auto_parallel/checkpoint/build_c_ext.py
colossalai/auto_parallel/checkpoint/build_c_ext.py
+16
-0
colossalai/auto_parallel/checkpoint/ckpt_solver_base.py
colossalai/auto_parallel/checkpoint/ckpt_solver_base.py
+195
-0
colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py
colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py
+87
-0
colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.c
colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.c
+197
-0
colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py
colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py
+441
-0
colossalai/auto_parallel/checkpoint/operation.py
colossalai/auto_parallel/checkpoint/operation.py
+184
-0
colossalai/auto_parallel/meta_profiler/__init__.py
colossalai/auto_parallel/meta_profiler/__init__.py
+3
-0
colossalai/auto_parallel/meta_profiler/constants.py
colossalai/auto_parallel/meta_profiler/constants.py
+15
-0
colossalai/auto_parallel/meta_profiler/meta_registry/__init__.py
...lai/auto_parallel/meta_profiler/meta_registry/__init__.py
+6
-0
colossalai/auto_parallel/meta_profiler/meta_registry/activation.py
...i/auto_parallel/meta_profiler/meta_registry/activation.py
+74
-0
colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py
...lel/meta_profiler/meta_registry/binary_elementwise_ops.py
+66
-0
colossalai/auto_parallel/meta_profiler/meta_registry/conv.py
colossalai/auto_parallel/meta_profiler/meta_registry/conv.py
+137
-0
colossalai/auto_parallel/meta_profiler/meta_registry/linear.py
...salai/auto_parallel/meta_profiler/meta_registry/linear.py
+172
-0
colossalai/auto_parallel/meta_profiler/meta_registry/norm.py
colossalai/auto_parallel/meta_profiler/meta_registry/norm.py
+103
-0
colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py
...alai/auto_parallel/meta_profiler/meta_registry/pooling.py
+134
-0
colossalai/auto_parallel/meta_profiler/metainfo.py
colossalai/auto_parallel/meta_profiler/metainfo.py
+117
-0
colossalai/auto_parallel/meta_profiler/registry.py
colossalai/auto_parallel/meta_profiler/registry.py
+32
-0
No files found.
colossalai/amp/torch_amp/_grad_scaler.py
View file @
e532679c
...
...
@@ -3,16 +3,18 @@
# modified from https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.py
# to support tensor parallel
import
torch
from
collections
import
defaultdict
,
abc
import
warnings
from
collections
import
abc
,
defaultdict
from
enum
import
Enum
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
from
colossalai.context
import
ParallelMode
import
torch
import
torch.distributed
as
dist
from
colossalai.core
import
global_context
as
gpc
from
torch._utils
import
_flatten_dense_tensors
,
_unflatten_dense_tensors
from
packaging
import
version
from
torch._utils
import
_flatten_dense_tensors
,
_unflatten_dense_tensors
from
colossalai.context
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
class
_MultiDeviceReplicator
(
object
):
...
...
colossalai/amp/torch_amp/torch_amp.py
View file @
e532679c
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import
torch.nn
as
nn
import
torch.cuda.amp
as
torch_amp
import
torch.nn
as
nn
from
torch
import
Tensor
from
torch.nn.modules.loss
import
_Loss
from
torch.optim
import
Optimizer
from
._grad_scaler
import
GradScaler
from
colossalai.nn.optimizer
import
ColossalaiOptimizer
from
colossalai.utils
import
clip_grad_norm_fp32
from
._grad_scaler
import
GradScaler
class
TorchAMPOptimizer
(
ColossalaiOptimizer
):
"""A wrapper class which integrate Pytorch AMP with an optimizer
...
...
colossalai/auto_parallel/checkpoint/__init__.py
View file @
e532679c
from
.ckpt_solver_base
import
CheckpointSolverBase
from
.ckpt_solver_chen
import
CheckpointSolverChen
from
.ckpt_solver_rotor
import
CheckpointSolverRotor
colossalai/auto_parallel/checkpoint/build_c_ext.py
0 → 100644
View file @
e532679c
import os

from setuptools import Extension, setup

# The C source sits next to this build script.
this_dir = os.path.dirname(os.path.abspath(__file__))

# Single extension module: the rotor dynamic-programming solver written in C.
ext_modules = [Extension('rotorc', sources=[os.path.join(this_dir, 'ckpt_solver_rotor.c')])]

setup(
    name='rotor c extension',
    version='0.1',
    description='rotor c extension for faster dp computing',
    ext_modules=ext_modules,
)
colossalai/auto_parallel/checkpoint/ckpt_solver_base.py
0 → 100644
View file @
e532679c
from
abc
import
ABC
,
abstractmethod
from
copy
import
deepcopy
from
typing
import
Any
,
List
import
torch
from
torch.fx
import
Graph
,
Node
from
colossalai.auto_parallel.passes.runtime_apply_pass
import
(
runtime_apply
,
runtime_apply_for_iterable_object
,
runtime_comm_spec_apply
,
)
from
colossalai.fx.codegen.activation_checkpoint_codegen
import
ActivationCheckpointCodeGen
# Fixed: the original spelled the dunder with a trailing extra underscore
# (``__all___``), which silently defines a different name and leaves the
# module's star-export list undeclared.
__all__ = ['CheckpointSolverBase']
def
_copy_output
(
src
:
Graph
,
dst
:
Graph
):
"""Copy the output node from src to dst"""
for
n_src
,
n_dst
in
zip
(
src
.
nodes
,
dst
.
nodes
):
if
n_src
.
op
==
'output'
:
n_dst
.
meta
=
n_src
.
meta
def
_get_param_size
(
module
:
torch
.
nn
.
Module
):
"""Get the size of the parameters in the module"""
return
sum
([
p
.
numel
()
*
torch
.
tensor
([],
dtype
=
p
.
dtype
).
element_size
()
for
p
in
module
.
parameters
()])
class CheckpointSolverBase(ABC):

    def __init__(
        self,
        graph: Graph,
        free_memory: float = -1.0,
        requires_linearize: bool = False,
        cnode: List[str] = None,
        optim_multiplier: float = 1.0,
    ):
        """Base class for activation-checkpoint solvers.

        ``CheckpointSolverBase`` integrates information provided by the components
        and uses an existing solver to find a possible optimal strategies
        combination for the target computing graph.

        Existing solvers:
            Chen's greedy solver: https://arxiv.org/abs/1604.06174 (CheckpointSolverChen)
            Rotor solver: https://hal.inria.fr/hal-02352969 (CheckpointSolverRotor)

        Args:
            graph (Graph): The computing graph to be optimized.
            free_memory (float): Memory constraint for the solution.
            requires_linearize (bool): Whether the graph needs to be linearized.
            cnode (List[str], optional): Common node list, should be a subset of input. Default to None.
            optim_multiplier (float, optional): The multiplier of extra weight storage for the
                ``torch.optim.Optimizer``. Default to 1.0.

        Warnings:
            Meta information of the graph is required for any ``CheckpointSolver``.
        """
        # super-dainiu: this graph is a temporary graph which can refer to
        # the owning module, but we will return another deepcopy of it after
        # the solver is executed.
        self.graph = deepcopy(graph)
        self.graph.owning_module = graph.owning_module
        _copy_output(graph, self.graph)
        self.graph.set_codegen(ActivationCheckpointCodeGen())

        # Refuse graphs whose nodes carry no meta information: the solvers
        # rely entirely on profiled meta to cost their decisions.
        if any(len(node.meta) == 0 for node in self.graph.nodes):
            raise RuntimeError(
                "Nodes meta information hasn't been prepared! Please extract from graph before constructing the solver!"
            )

        # parameter memory = parameter size + optimizer extra weight storage
        self.free_memory = free_memory - _get_param_size(self.graph.owning_module) * (optim_multiplier + 1)
        self.cnode = cnode
        self.requires_linearize = requires_linearize
        if self.requires_linearize:
            self.node_list = self._linearize_graph()
        else:
            self.node_list = self.get_node_list()

    @abstractmethod
    def solve(self):
        """Solve the checkpointing problem and return the solution."""
        pass

    def get_node_list(self):
        """Return each graph node wrapped in its own single-element region."""
        return [[node] for node in self.graph.nodes]

    def _linearize_graph(self) -> List[List[Node]]:
        """Linearize the graph into a list of regions.

        Returns:
            List[List[Node]]: list of lists, each inner list of ``Node`` is one
            'node' of the linearized chain.

        Remarks:
            Inplace ops and shape-consistency ops are merged into the previous node.
        """
        # Common nodes are type of nodes that could be seen as attributes and remain
        # unchanged throughout the whole model, it will be used several times by
        # different blocks of model, so that it is hard for us to linearize the graph
        # when we encounter those kinds of nodes. We let users to annotate some of the
        # input as common node, such as attention mask, and the followings are some of
        # the ops that could actually be seen as common nodes. With our common node prop,
        # we could find some of the "real" common nodes (e.g. the real attention mask
        # used in BERT and GPT), the rule is simple, for node who's parents are all common
        # nodes or it's op belongs to the following operations, we view this node as a
        # newly born common node.

        # List of target name that could be seen as common node
        common_ops = ["getattr", "getitem", "size"]

        def _is_cop(target: Any) -> bool:
            """True if this op target counts as a common node."""
            if isinstance(target, str):
                return target in common_ops
            else:
                return target.__name__ in common_ops

        def _is_sink() -> bool:
            """True if every pending dependency can be freed here."""

            def _is_inplace(n: Node):
                """Read the inplace flag of a ``torch.fx.Node``."""
                inplace = False
                if n.op == "call_function":
                    inplace = n.kwargs.get("inplace", False)
                elif n.op == "call_module":
                    inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
                return inplace

            def _is_shape_consistency(n: Node):
                """True for shape-consistency nodes (``runtime_apply`` family)."""
                return n.target in [runtime_apply, runtime_apply_for_iterable_object, runtime_comm_spec_apply]

            return not sum([v for _, v in deps.items()]) and not any(map(_is_inplace, n.users)) \
                and not any(map(_is_shape_consistency, n.users))

        # make sure that item in cnode is valid
        if self.cnode:
            for name in self.cnode:
                try:
                    assert next(node for node in self.graph.nodes
                                if node.name == name).op == "placeholder", \
                        f"Common node {name} is not an input of the model."
                except StopIteration:
                    raise ValueError(f"Common node name {name} not in graph.")
        else:
            self.cnode = []

        deps = {}
        node_list = []
        region = []

        for n in self.graph.nodes:
            if n.op != "placeholder" and n.op != "output":
                for n_par in n.all_input_nodes:
                    if n_par.op != "placeholder" and n_par.name not in self.cnode:
                        deps[n_par] -= 1
                region.append(n)

                # if the node could free all dependencies in graph
                # we could begin a new node
                if _is_sink():
                    node_list.append(region)
                    region = []

                # propagate common node attr if possible
                if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode
                                                  ]) or _is_cop(n.target):
                    self.cnode.append(n.name)
                else:
                    deps[n] = len([user for user in n.users if user.op != "output"])
        return node_list
colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py
0 → 100644
View file @
e532679c
import
math
from
copy
import
deepcopy
from
typing
import
List
,
Set
,
Tuple
from
torch.fx
import
Graph
,
Node
from
colossalai.fx.profiler
import
calculate_fwd_in
,
calculate_fwd_tmp
from
.ckpt_solver_base
import
CheckpointSolverBase
__all__
=
[
'CheckpointSolverChen'
]
class CheckpointSolverChen(CheckpointSolverBase):

    def __init__(self, graph: Graph, cnode: List[str] = None, num_grids: int = 6):
        """Simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.

        Note that this algorithm targets memory optimization only, using techniques
        in appendix A.

        Usage:
            Assume that we have a ``GraphModule``, and we have already done the extractions
            to the graph to retrieve all information needed, then we could use the following
            code to find a solution using ``CheckpointSolverChen``:
            >>> solver = CheckpointSolverChen(gm.graph)
            >>> chen_graph = solver.solve()
            >>> gm.graph = chen_graph    # set the graph to a new graph

        Args:
            graph (Graph): The computing graph to be optimized.
            cnode (List[str], optional): Common node list, should be a subset of input. Defaults to None.
            num_grids (int, optional): Number of grids to search for b. Defaults to 6.
        """
        # BUG FIX: the base signature is (graph, free_memory, requires_linearize,
        # cnode, optim_multiplier). The original positional call
        # ``super().__init__(graph, 0, 0, True, cnode)`` bound
        # requires_linearize=0 (falsy, so the graph was never linearized) and
        # cnode=True. Use keywords so each argument lands where intended.
        super().__init__(graph, free_memory=0, requires_linearize=True, cnode=cnode)
        self.num_grids = num_grids

    def solve(self) -> Graph:
        """Solve the checkpointing problem using Algorithm 3.

        Returns:
            graph (Graph): The optimized graph, a copy of the original graph.
        """
        checkpointable_op = ['call_module', 'call_method', 'call_function', 'get_attr']
        ckpt = self.grid_search()
        for i, seg in enumerate(ckpt):
            for idx in range(*seg):
                nodes = self.node_list[idx]
                for n in nodes:
                    if n.op in checkpointable_op:
                        n.meta['activation_checkpoint'] = i
        return deepcopy(self.graph)

    def run_chen_greedy(self, b: int = 0) -> Tuple[Set, int]:
        """Run the greedy segmentation of Algorithm 3 with budget ``b``.

        Returns the list of checkpoint intervals and the new budget estimate
        ``floor(sqrt(x * y))``.
        """
        ckpt_intv = []
        temp = 0
        x = 0
        y = 0
        prev_idx = 2
        for idx, nodes in enumerate(self.node_list):
            for n in nodes:
                n: Node
                temp += calculate_fwd_in(n) + calculate_fwd_tmp(n)
                y = max(y, temp)
            if temp > b and idx > prev_idx:
                x += calculate_fwd_in(nodes[0])
                temp = 0
                ckpt_intv.append((prev_idx, idx + 1))
                prev_idx = idx + 1
        return ckpt_intv, math.floor(math.sqrt(x * y))

    def grid_search(self) -> Set:
        """
        Search ckpt strategy with b = 0, then run the allocation algorithm again with b = √xy.
        Grid search over [√2/2 b, √2 b] for ``ckpt_opt`` over ``num_grids`` as in appendix A.
        """
        _, b_approx = self.run_chen_greedy(0)
        b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))
        b_opt = math.inf
        # BUG FIX: initialize ckpt_opt so an empty search range cannot raise
        # UnboundLocalError, and clamp the step to >= 1 — the original
        # ``(b_max - b_min) // self.num_grids`` is 0 whenever the interval is
        # narrower than num_grids, which makes ``range`` raise ValueError.
        ckpt_opt = []
        step = max(1, (b_max - b_min) // self.num_grids)
        for b in range(b_min, b_max, step):
            ckpt_intv, b_approx = self.run_chen_greedy(b)
            if b_approx < b_opt:
                b_opt = b_approx
                ckpt_opt = ckpt_intv
        return ckpt_opt
colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.c
0 → 100644
View file @
e532679c
#define PY_SSIZE_T_CLEAN
#include <Python.h>
/* Convert a Python sequence of ints into a freshly calloc'd C long array.
 * The array holds one extra trailing slot set to 0. Returns NULL when the
 * argument is missing or not a sequence; caller owns the returned buffer. */
long* PySequenceToLongArray(PyObject* pylist) {
  if (!(pylist && PySequence_Check(pylist)))
    return NULL;
  Py_ssize_t length = PySequence_Size(pylist);
  long* out = (long*)calloc(length + 1, sizeof(long));
  for (Py_ssize_t idx = 0; idx < length; ++idx) {
    PyObject* element = PySequence_GetItem(pylist, idx);
    out[idx] = PyLong_AsLong(element);
    Py_DECREF(element); /* GetItem returns a new reference */
  }
  out[length] = 0; /* sentinel */
  return out;
}
/* Convert a Python sequence of floats into a freshly calloc'd C double array.
 * Mirrors PySequenceToLongArray: extra trailing slot set to 0, NULL on a
 * non-sequence argument, caller frees the result. */
double* PySequenceToDoubleArray(PyObject* pylist) {
  if (!(pylist && PySequence_Check(pylist)))
    return NULL;
  Py_ssize_t length = PySequence_Size(pylist);
  double* out = (double*)calloc(length + 1, sizeof(double));
  for (Py_ssize_t idx = 0; idx < length; ++idx) {
    PyObject* element = PySequence_GetItem(pylist, idx);
    out[idx] = PyFloat_AsDouble(element);
    Py_DECREF(element); /* GetItem returns a new reference */
  }
  out[length] = 0; /* sentinel */
  return out;
}
/* Fetch attribute `attributeName` from `container` and convert it to a
 * C long array (see PySequenceToLongArray). Caller owns the buffer. */
long* getLongArray(PyObject* container, const char* attributeName) {
  PyObject* sequence = PyObject_GetAttrString(container, attributeName);
  long* values = PySequenceToLongArray(sequence);
  Py_DECREF(sequence);
  return values;
}
/* Fetch attribute `attributeName` from `container` and convert it to a
 * C double array (see PySequenceToDoubleArray). Caller owns the buffer. */
double* getDoubleArray(PyObject* container, const char* attributeName) {
  PyObject* sequence = PyObject_GetAttrString(container, attributeName);
  double* values = PySequenceToDoubleArray(sequence);
  Py_DECREF(sequence);
  return values;
}
/* Compute the rotor dynamic-programming tables for a chain.
 *
 * Python signature: compute_table(chain, mmax) -> (cost_table, back_ptr)
 * `chain` must expose sequence attributes ftime, btime, x, xbar, ftmp, btmp
 * and a length; `mmax` is the number of discretized memory slots.
 *
 * cost_table[m][i][l] is the optimal time for segment [i, l] under budget m;
 * back_ptr[m][i][l] is (True,) for a chain checkpoint or (False, j) for a
 * leaf checkpoint at j.
 *
 * NOTE(review): the early `return NULL` paths leak the arrays fetched so
 * far; left as-is since errors here abort the solve anyway. */
static PyObject* computeTable(PyObject* self, PyObject* args) {
  PyObject* chainParam;
  int mmax;

  if (!PyArg_ParseTuple(args, "Oi", &chainParam, &mmax))
    return NULL;
  double* ftime = getDoubleArray(chainParam, "ftime");
  if (!ftime)
    return NULL;
  double* btime = getDoubleArray(chainParam, "btime");
  if (!btime)
    return NULL;
  long* x = getLongArray(chainParam, "x");
  if (!x)
    return NULL;
  long* xbar = getLongArray(chainParam, "xbar");
  if (!xbar)
    return NULL;
  /* BUG FIX: the original fetched "btmp" here, so ftmp silently aliased the
   * backward temporary memory instead of the forward one. */
  long* ftmp = getLongArray(chainParam, "ftmp");
  if (!ftmp)
    return NULL;
  long* btmp = getLongArray(chainParam, "btmp");
  if (!btmp)
    return NULL;
  long chainLength = PyObject_Length(chainParam);
  if (!chainLength)
    return NULL;

#define COST_TABLE(m, i, l) \
  costTable[(m) * (chainLength + 1) * (chainLength + 1) + \
            (i) * (chainLength + 1) + (l)]
  double* costTable = (double*)calloc((mmax + 1) * (chainLength + 1) * (chainLength + 1), sizeof(double));

#define BACK_PTR(m, i, l) \
  backPtr[(m) * (chainLength + 1) * (chainLength + 1) + \
          (i) * (chainLength + 1) + (l)]
  long* backPtr = (long*)calloc((mmax + 1) * (chainLength + 1) * (chainLength + 1), sizeof(long));

  /* Base case: intervals of length 0 (Equation (1)). */
  for (long m = 0; m <= mmax; ++m)
    for (long i = 0; i <= chainLength; ++i)
      if ((m >= x[i + 1] + xbar[i + 1] + btmp[i]) &&
          (m >= x[i + 1] + xbar[i + 1] + ftmp[i]))
        COST_TABLE(m, i, i) = ftime[i] + btime[i];
      else
        COST_TABLE(m, i, i) = INFINITY;

  /* Fill the tables by increasing interval length d. */
  for (long m = 0; m <= mmax; ++m)
    for (long d = 1; d <= chainLength; ++d) {
      for (long i = 0; i <= chainLength - d; ++i) {
        long idx = i + d;
        long mmin = x[idx + 1] + x[i + 1] + ftmp[i];
        if (idx > i + 1) {
          long maxCostFWD = 0;
          for (long j = i + 1; j < idx; j++) {
            maxCostFWD = fmaxl(maxCostFWD, x[j] + x[j + 1] + ftmp[j]);
          }
          mmin = fmaxl(mmin, x[idx + 1] + maxCostFWD);
        }
        if ((m >= mmin)) {
          /* Best leaf checkpoint: try every split point j in (i, idx]. */
          long bestLeaf = -1;
          double sumFw = 0;
          double bestLeafCost = INFINITY;
          for (long j = i + 1; j <= idx; ++j) {
            sumFw += ftime[j - 1];
            if (m >= x[j]) {
              double cost = sumFw + COST_TABLE(m - x[j], j, idx) + COST_TABLE(m, i, j - 1);
              if (cost < bestLeafCost) {
                bestLeafCost = cost;
                bestLeaf = j;
              }
            }
          }
          /* Chain checkpoint alternative. */
          double chainCost = INFINITY;
          if (m >= xbar[i + 1])
            chainCost = COST_TABLE(m, i, i) + COST_TABLE(m - xbar[i + 1], i + 1, idx);
          if (bestLeafCost <= chainCost) {
            COST_TABLE(m, i, idx) = bestLeafCost;
            BACK_PTR(m, i, idx) = bestLeaf;
          } else {
            COST_TABLE(m, i, idx) = chainCost;
            BACK_PTR(m, i, idx) = -1; /* -1 marks a chain checkpoint */
          }
        } else
          COST_TABLE(m, i, idx) = INFINITY;
      }
    }

  free(ftime);
  free(btime);
  free(x);
  free(xbar);
  free(ftmp);
  free(btmp);

  PyObject* pyCostTable = PyList_New(mmax + 1);
  PyObject* pyBackPtr = PyList_New(mmax + 1);

  // Convert the result into Python world
  for (long m = 0; m <= mmax; ++m) {
    PyObject* pyCostTable_m = PyList_New(chainLength + 1);
    PyList_SET_ITEM(pyCostTable, m, pyCostTable_m);
    PyObject* pyBackPtr_m = PyList_New(chainLength + 1);
    PyList_SET_ITEM(pyBackPtr, m, pyBackPtr_m);
    for (long i = 0; i <= chainLength; ++i) {
      PyObject* pyCostTable_m_i = PyDict_New();
      PyList_SET_ITEM(pyCostTable_m, i, pyCostTable_m_i);
      PyObject* pyBackPtr_m_i = PyDict_New();
      PyList_SET_ITEM(pyBackPtr_m, i, pyBackPtr_m_i);
      for (long l = i; l <= chainLength; ++l) {
        PyObject* pyVar_l = PyLong_FromLong(l);
        PyObject* pyCostTable_m_i_l = PyFloat_FromDouble(COST_TABLE(m, i, l));
        PyDict_SetItem(pyCostTable_m_i, pyVar_l, pyCostTable_m_i_l);
        Py_DECREF(pyCostTable_m_i_l);
        PyObject* pyBackPtr_m_i_l;
        if (BACK_PTR(m, i, l) < 0)
          pyBackPtr_m_i_l = Py_BuildValue("(O)", Py_True);
        else
          pyBackPtr_m_i_l = Py_BuildValue("(Ol)", Py_False, BACK_PTR(m, i, l));
        PyDict_SetItem(pyBackPtr_m_i, pyVar_l, pyBackPtr_m_i_l);
        Py_DECREF(pyBackPtr_m_i_l);
        Py_DECREF(pyVar_l);
      }
    }
  }
  free(costTable);
  free(backPtr);

  PyObject* result = PyTuple_Pack(2, pyCostTable, pyBackPtr);
  Py_DECREF(pyCostTable);
  Py_DECREF(pyBackPtr);
  return result;
}
/* Method table for the rotorc extension module. */
static PyMethodDef rotorMethods[] = {
    {"compute_table", computeTable, METH_VARARGS,
     "Compute the optimal table with the rotor algorithm."},
    {NULL, NULL, 0, NULL} /* Sentinel */
};

static struct PyModuleDef rotorModule = {
    PyModuleDef_HEAD_INIT,
    "rotorc", /* name of module */
    "A simple implementation of dynamic programming algorithm rotor with C in "
    "https://hal.inria.fr/hal-02352969. Some code are adapted from "
    "https://gitlab.inria.fr/hiepacs/rotor.",
    /* module documentation, may be NULL */
    -1, /* size of per-interpreter state of the module,
           or -1 if the module keeps state in global variables. */
    rotorMethods};

/* Module entry point invoked by the CPython import machinery. */
PyMODINIT_FUNC PyInit_rotorc(void) { return PyModule_Create(&rotorModule); }
colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py
0 → 100644
View file @
e532679c
from
copy
import
deepcopy
from
typing
import
Any
,
Dict
,
List
,
Tuple
from
torch
import
Tensor
from
torch.fx
import
Graph
,
Node
from
colossalai.auto_parallel.passes.runtime_apply_pass
import
runtime_apply
,
runtime_comm_spec_apply
from
colossalai.fx.codegen.activation_checkpoint_codegen
import
_find_nested_ckpt_regions
from
colossalai.fx.profiler
import
(
activation_size
,
calculate_bwd_time
,
calculate_fwd_out
,
calculate_fwd_time
,
calculate_fwd_tmp
,
)
from
colossalai.logging
import
get_dist_logger
from
.ckpt_solver_base
import
CheckpointSolverBase
from
.operation
import
Backward
,
Chain
,
ForwardCheck
,
ForwardEnable
,
ForwardNograd
,
Loss
,
Sequence
__all__
=
[
'CheckpointSolverRotor'
]
class
CheckpointSolverRotor
(
CheckpointSolverBase
):
def
__init__
(
self
,
graph
:
Graph
,
free_memory
:
float
=
-
1
,
cnode
:
List
[
str
]
=
None
,
memory_slots
:
int
=
500
,
optim_multiplier
:
float
=
1.0
):
"""This is the simple implementation of dynamic programming algorithm rotor
in https://hal.inria.fr/hal-02352969. Some code are adapted from
https://gitlab.inria.fr/hiepacs/rotor.
Usage:
Assume that we have a ``GraphModule``, and we have already done the extractions
to the graph to retrieve all information needed, then we could use the following
code to find a solution using ``CheckpointSolverRotor``:
>>> solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info(device=0)[0])
>>> rotor_graph = solver.solve(force_python=True) # otherwise use C solver
>>> gm.graph = rotor_graph # set the graph to a new graph
Args:
graph (Graph): The computing graph to be optimized.
free_memory (float, optional): Memory constraint for the solution, unit is byte.
Use ``torch.cuda.mem_get_info(device=0)[0]`` to estimate the free_memory. Defaults to -1.
cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None.
memory_slots (int, optional): Number of slots for discretizing memory budget. Defaults to 500.
optim_multiplier (float, optional): The multiplier of extra weight storage for the
``torch.optim.Optimizer``. Default to 1.0.
"""
super
().
__init__
(
graph
,
free_memory
,
True
,
cnode
,
optim_multiplier
)
self
.
memory_slots
=
memory_slots
# construct chain
unit
=
self
.
free_memory
//
self
.
memory_slots
self
.
chain
=
self
.
_construct_chain
(
self
.
graph
,
self
.
node_list
)
self
.
chain
.
discretize_all
(
unit
)
self
.
cost_table
=
None
self
.
back_ptr
=
None
self
.
sequence
=
None
def
solve
(
self
,
force_python
:
bool
=
False
,
verbose
:
bool
=
False
)
->
Graph
:
"""Solve the checkpointing problem using rotor algorithm.
Args:
force_python (bool, optional): Use Python version of solver, else use C version. Defaults to False.
verbose (bool, optional): Print verbose information. Defaults to False.
Returns:
graph (Graph): The optimized graph, should be a copy of the original graph.
"""
chain
=
self
.
chain
# compute cost table
if
force_python
:
self
.
cost_table
,
self
.
back_ptr
=
self
.
_compute_table
(
chain
,
self
.
memory_slots
)
else
:
self
.
cost_table
,
self
.
back_ptr
=
self
.
_compute_table_c
(
chain
,
self
.
memory_slots
)
if
verbose
:
self
.
print_chain
()
# backtrack
try
:
self
.
sequence
=
self
.
_backtrack
(
chain
,
0
,
len
(
chain
),
self
.
memory_slots
-
chain
.
x
[
0
],
self
.
cost_table
,
self
.
back_ptr
)
self
.
_annotate_from_sequence
(
self
.
sequence
,
self
.
node_list
)
except
ValueError
as
e
:
# using logger to annonce that the solver is failed
logger
=
get_dist_logger
()
logger
.
warning
(
f
'Checkpoint solver failed:
{
e
}
'
)
raise
ValueError
if
verbose
:
self
.
print_sequence
()
return
deepcopy
(
self
.
graph
)
def
print_chain
(
self
):
print
(
'[input]'
,
self
.
chain
.
x
[
0
],
self
.
chain
.
xbar
[
0
],
self
.
chain
.
ftmp
[
0
],
self
.
chain
.
btmp
[
0
])
for
idx
in
range
(
len
(
self
.
node_list
)
-
1
):
print
(
self
.
node_list
[
idx
],
self
.
chain
.
x
[
idx
+
1
],
self
.
chain
.
xbar
[
idx
+
1
],
self
.
chain
.
ftmp
[
idx
],
self
.
chain
.
btmp
[
idx
])
print
(
f
'Chain =
{
self
.
chain
}
'
)
def
print_sequence
(
self
):
print
(
f
'Sequence =
{
self
.
sequence
}
'
)
@
classmethod
def
_construct_chain
(
cls
,
graph
:
Graph
,
node_list
:
List
[
List
[
Node
]])
->
Chain
:
input_tensors
=
cls
.
_extract_input
(
graph
)
ftime
,
btime
,
ftmp
,
btmp
=
list
(),
list
(),
list
(),
list
()
xbar
,
x
=
[
activation_size
(
input_tensors
)],
[
activation_size
(
input_tensors
)]
for
node
in
node_list
:
node_info
=
cls
.
_extract_node_info
(
node
)
ftime
.
append
(
node_info
[
0
])
btime
.
append
(
node_info
[
1
])
x
.
append
(
node_info
[
2
])
xbar
.
append
(
node_info
[
3
])
ftmp
.
append
(
node_info
[
4
])
btmp
.
append
(
node_info
[
5
])
# currently we view loss backward temp as zero
btime
.
append
(
0
)
btmp
.
append
(
0
)
return
Chain
(
ftime
,
btime
,
x
,
xbar
,
ftmp
,
btmp
)
@
classmethod
def
_extract_node_info
(
cls
,
node
:
List
[
Node
])
->
Tuple
[
int
,
...]:
"""Extract node info from a list of nodes"""
xbar
=
0
ftime
=
0
btime
=
0
fwd_mem_peak
=
0
for
n
in
node
:
assert
isinstance
(
n
,
Node
),
f
'
{
n
}
is not a Node'
if
n
.
target
==
runtime_apply
or
n
.
target
==
runtime_comm_spec_apply
:
# in this case we need to calculate memory usage directly based on the statics that hooked in node.meta
xbar
+=
n
.
meta
[
'fwd_mem_out'
]
fwd_mem_peak
=
max
(
fwd_mem_peak
,
xbar
+
n
.
meta
[
'fwd_mem_tmp'
])
else
:
xbar
+=
calculate_fwd_tmp
(
n
)
+
calculate_fwd_out
(
n
)
fwd_mem_peak
=
max
(
fwd_mem_peak
,
xbar
+
n
.
meta
[
'fwd_mem_tmp'
]
+
cls
.
_extract_unused_output
(
n
))
# minimum flop count is required
ftime
+=
max
(
calculate_fwd_time
(
n
),
1.0
)
btime
+=
max
(
calculate_bwd_time
(
n
),
1.0
)
x
=
calculate_fwd_out
(
node
[
-
1
])
xbar
=
max
(
x
,
xbar
)
ftmp
=
fwd_mem_peak
-
xbar
btmp
=
cls
.
_extract_btmp
(
node
)
return
ftime
,
btime
,
x
,
xbar
,
ftmp
,
btmp
@
staticmethod
def
_extract_input
(
graph
:
Graph
)
->
Tuple
[
Tensor
,
...]:
"""Extract input tensors from a Graph"""
input_tensors
=
[]
for
node
in
graph
.
nodes
:
if
node
.
op
==
'placeholder'
:
input_tensors
.
append
(
node
.
meta
[
'fwd_out'
])
return
input_tensors
@
staticmethod
def
_extract_unused_output
(
node
:
Node
)
->
int
:
"""Extract unused output from `torch.fx.Node`"""
return
activation_size
(
node
.
meta
[
'fwd_out'
])
-
calculate_fwd_out
(
node
)
@
staticmethod
def
_extract_btmp
(
node
:
List
[
Node
])
->
int
:
"""Extract btmp from a list of nodes"""
def
_extract_deps_size
():
deps_size
=
0
for
k
,
v
in
deps
.
items
():
k
:
Node
if
v
>
0
:
deps_size
+=
k
.
meta
[
'bwd_mem_out'
]
if
v
==
float
(
'-inf'
):
deps_size
-=
calculate_fwd_tmp
(
k
)
+
calculate_fwd_out
(
k
)
return
deps_size
btmp
=
0
deps
=
{}
for
n
in
reversed
(
node
):
deps
[
n
]
=
len
(
n
.
all_input_nodes
)
btmp
=
max
(
btmp
,
_extract_deps_size
()
+
n
.
meta
[
'bwd_mem_tmp'
])
for
child
in
n
.
users
:
if
child
in
deps
:
deps
[
child
]
-=
1
if
deps
[
child
]
<=
0
:
deps
[
child
]
=
float
(
'-inf'
)
# free
return
btmp
@
staticmethod
def
_compute_table
(
chain
:
Chain
,
mmax
:
int
)
->
Tuple
:
"""Compute the table using dynamic programming. Returns the cost table and the backtracking pointer.
Args:
chain (Chain): A basic linearized structure for solving the dynamic programming problem.
mmax (int): Maximum number of memory slots.
Returns:
cost_table (List): cost_table[m][lhs][rhs] with lhs = 0...chain.length
and rhs = lhs...chain.length (lhs is not included) and m = 0...mmax
back_ptr (List): back_ptr[m][lhs][rhs] is (True,) if the optimal choice
is a chain checkpoint (False, j) if the optimal choice is a leaf checkpoint
of length j
"""
ftime
=
chain
.
ftime
+
[
0.0
]
btime
=
chain
.
btime
x
=
chain
.
x
+
[
0
]
xbar
=
chain
.
xbar
+
[
0
]
ftmp
=
chain
.
ftmp
+
[
0
]
btmp
=
chain
.
btmp
+
[
0
]
# Build table
cost_table
=
[[{}
for
_
in
range
(
len
(
chain
)
+
1
)]
for
_
in
range
(
mmax
+
1
)]
back_ptr
=
[[{}
for
_
in
range
(
len
(
chain
)
+
1
)]
for
_
in
range
(
mmax
+
1
)]
# Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation
# Initialize borders of the tables for lmax-lmin = 0
for
m
in
range
(
mmax
+
1
):
for
i
in
range
(
len
(
chain
)
+
1
):
limit
=
max
(
x
[
i
+
1
]
+
xbar
[
i
+
1
]
+
ftmp
[
i
],
x
[
i
+
1
]
+
xbar
[
i
+
1
]
+
btmp
[
i
])
if
m
>=
limit
:
# Equation (1)
cost_table
[
m
][
i
][
i
]
=
ftime
[
i
]
+
btime
[
i
]
else
:
cost_table
[
m
][
i
][
i
]
=
float
(
"inf"
)
# Compute everything
for
m
in
range
(
mmax
+
1
):
for
d
in
range
(
1
,
len
(
chain
)
+
1
):
for
i
in
range
(
len
(
chain
)
+
1
-
d
):
idx
=
i
+
d
mmin
=
x
[
idx
+
1
]
+
x
[
i
+
1
]
+
ftmp
[
i
]
if
idx
>
i
+
1
:
mmin
=
max
(
mmin
,
x
[
idx
+
1
]
+
max
(
x
[
j
]
+
x
[
j
+
1
]
+
ftmp
[
j
]
for
j
in
range
(
i
+
1
,
idx
)))
if
m
<
mmin
:
cost_table
[
m
][
i
][
idx
]
=
float
(
"inf"
)
else
:
leaf_checkpoints
=
[(
j
,
sum
(
ftime
[
i
:
j
])
+
cost_table
[
m
-
x
[
j
]][
j
][
idx
]
+
cost_table
[
m
][
i
][
j
-
1
])
for
j
in
range
(
i
+
1
,
idx
+
1
)
if
m
>=
x
[
j
]]
if
leaf_checkpoints
:
best_leaf
=
min
(
leaf_checkpoints
,
key
=
lambda
t
:
t
[
1
])
else
:
best_leaf
=
None
if
m
>=
xbar
[
i
+
1
]:
chain_checkpoint
=
cost_table
[
m
][
i
][
i
]
+
cost_table
[
m
-
xbar
[
i
+
1
]][
i
+
1
][
idx
]
else
:
chain_checkpoint
=
float
(
"inf"
)
if
best_leaf
and
best_leaf
[
1
]
<=
chain_checkpoint
:
cost_table
[
m
][
i
][
idx
]
=
best_leaf
[
1
]
back_ptr
[
m
][
i
][
idx
]
=
(
False
,
best_leaf
[
0
])
else
:
cost_table
[
m
][
i
][
idx
]
=
chain_checkpoint
back_ptr
[
m
][
i
][
idx
]
=
(
True
,)
return
cost_table
,
back_ptr
@
staticmethod
def
_compute_table_c
(
chain
:
Chain
,
mmax
:
int
)
->
Tuple
:
try
:
from
.rotorc
import
compute_table
# build module if module not found
except
ModuleNotFoundError
:
import
os
import
subprocess
import
sys
logger
=
get_dist_logger
()
logger
.
info
(
"rotorc hasn't been built! Building library..."
,
ranks
=
[
0
])
this_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
result
=
subprocess
.
Popen
(
[
f
"
{
sys
.
executable
}
"
,
f
"
{
os
.
path
.
join
(
this_dir
,
'build_c_ext.py'
)
}
"
,
"build_ext"
,
f
"--build-lib=
{
this_dir
}
"
],
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
,
)
if
result
.
wait
()
==
0
:
logger
.
info
(
"rotorc has been built!"
,
ranks
=
[
0
])
from
.rotorc
import
compute_table
else
:
logger
.
warning
(
"rotorc built failed! Using python version!"
,
ranks
=
[
0
])
return
CheckpointSolverRotor
.
_compute_table
(
chain
,
mmax
)
return
compute_table
(
chain
,
mmax
)
@staticmethod
def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any],
               back_ptr: List[Any]) -> "Sequence":
    """Backtrack the cost table and retrieve the optimal checkpointing strategy.

    Args:
        chain (Chain): A basic linearized structure for solving the dynamic programming problem.
        lhs (int): The left index of the interval to backtrack.
        rhs (int): The right index of the interval to backtrack.
        budget (int): The memory budget for processing this interval.
        cost_table (List[Any]): See ``._compute_table()`` for definitions
        back_ptr (List[Any]): See ``._compute_table()`` for definitions

    Raises:
        ValueError: Can not process the chain.

    Returns:
        sequence (Sequence): The sequence of executing nodes with checkpoints.
    """
    if budget <= 0:
        raise ValueError(f"Can not process a chain with negative memory {budget}")
    if cost_table[budget][lhs][rhs] == float("inf"):
        raise ValueError(f"Can not process this chain from index {lhs} to {rhs} with memory {budget}")

    result = Sequence()

    # Base case: a single position. The position past the end of the chain is
    # the loss computation; every other one is forward + immediate backward.
    if lhs == rhs:
        result += [Loss()] if lhs == len(chain) else [ForwardEnable(lhs), Backward(lhs)]
        return result

    decision = back_ptr[budget][lhs][rhs]
    if decision[0]:
        # "Chain" decision: keep everything for node lhs, recurse on the rest
        # with the budget reduced by xbar[lhs + 1].
        inner = CheckpointSolverRotor._backtrack(chain, lhs + 1, rhs,
                                                 budget - chain.xbar[lhs + 1], cost_table, back_ptr)
        result += [ForwardEnable(lhs), inner, Backward(lhs)]
    else:
        # "Leaf" decision: checkpoint at lhs, run without grad up to the chosen
        # leaf, solve the right part first, then recompute the left part.
        leaf = decision[1]
        result += [ForwardCheck(lhs)]
        result += [ForwardNograd(k) for k in range(lhs + 1, leaf)]
        result += [
            CheckpointSolverRotor._backtrack(chain, leaf, rhs, budget - chain.x[leaf], cost_table, back_ptr),
            CheckpointSolverRotor._backtrack(chain, lhs, leaf - 1, budget, cost_table, back_ptr),
        ]
    return result
@staticmethod
def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
    """Annotate the nodes in the ``node_list`` with activation checkpoint from the sequence.

    Args:
        sequence (Sequence): The sequence of executing nodes with activation checkpoint annotations.
        node_list (List[List[Node]]): The list of nodes to annotate.
    """
    op_list = sequence.list_operations()
    # The single Loss op splits the schedule into a forward half and a backward half.
    loss_op = next(op for op in op_list if isinstance(op, Loss))
    fwd_list = op_list[:op_list.index(loss_op)]
    bwd_list = op_list[op_list.index(loss_op) + 1:]
    ckpt_idx = 0       # running id of the current checkpoint region
    in_ckpt = False    # True while scanning inside a checkpoint region
    ckpt_region = []   # node indices collected for the current region

    # forward annotation: a region opens at a ForwardCheck and is flushed
    # (each collected node gets label [ckpt_idx]) at the next ForwardEnable,
    # or at a back-to-back ForwardCheck which immediately opens a new region.
    for idx, op in enumerate(fwd_list, 0):
        if in_ckpt:
            if isinstance(op, ForwardNograd):
                ckpt_region.append(idx)

            elif isinstance(op, ForwardEnable):
                in_ckpt = False
                for node_idx in ckpt_region:
                    for n in node_list[node_idx]:
                        n.meta['activation_checkpoint'] = [ckpt_idx]

                ckpt_idx += 1
                ckpt_region = []

            elif isinstance(op, ForwardCheck):
                # flush the previous region, then start a new one at idx
                for node_idx in ckpt_region:
                    for n in node_list[node_idx]:
                        n.meta['activation_checkpoint'] = [ckpt_idx]

                ckpt_idx += 1
                ckpt_region = [idx]

        else:
            if isinstance(op, ForwardCheck):
                in_ckpt = True
                ckpt_region.append(idx)

    # annotate the backward if there is any nested activation checkpoint:
    # forward ops appearing in the backward half are recomputation; nested
    # region ids are appended after the level-0 label assigned above.
    in_recompute = False
    for op in bwd_list:
        if in_recompute:
            if isinstance(op, ForwardNograd):
                # NOTE: in the backward half the node index comes from op.index,
                # not from the enumeration position
                ckpt_region.append(op.index)

            elif isinstance(op, ForwardEnable):
                for node_idx in ckpt_region:
                    for n in node_list[node_idx]:
                        n.meta['activation_checkpoint'].append(ckpt_idx)

                ckpt_idx += 1
                ckpt_region = []

            elif isinstance(op, ForwardCheck):
                for node_idx in ckpt_region:
                    for n in node_list[node_idx]:
                        n.meta['activation_checkpoint'].append(ckpt_idx)

                ckpt_idx += 1
                ckpt_region = [op.index]

            elif isinstance(op, Backward):
                for node_idx in ckpt_region:
                    for n in node_list[node_idx]:
                        n.meta['activation_checkpoint'].append(ckpt_idx)

                in_recompute = False

        else:
            if not isinstance(op, Backward):
                # any non-Backward op starts a recomputation stretch; nested
                # region ids restart from 0 for each stretch
                in_recompute = True
                ckpt_idx = 0
                ckpt_region = []
                if isinstance(op, ForwardCheck):
                    ckpt_region.append(op.index)

    # postprocess, make sure every activation checkpoint label in the
    # same activation checkpoint region (level = 0) has the same length
    op_list = []
    for node in node_list:
        op_list += node
    ckpt_regions = _find_nested_ckpt_regions(op_list)
    for (start_idx, end_idx) in ckpt_regions:
        nested_length = max(len(op_list[idx].meta['activation_checkpoint']) for idx in range(start_idx, end_idx + 1))
        for idx in range(start_idx, end_idx + 1):
            # pad shorter label lists with None so all labels in the region align
            op_list[idx].meta['activation_checkpoint'] += [None] * (nested_length -
                                                                    len(op_list[idx].meta['activation_checkpoint']))
colossalai/auto_parallel/checkpoint/operation.py
0 → 100644
View file @
e532679c
import
math
from
abc
import
ABC
from
typing
import
Any
,
Iterable
,
List
from
torch.utils._pytree
import
tree_map
class Chain:

    def __init__(self,
                 ftime: List[float],
                 btime: List[float],
                 x: List[int],
                 xbar: List[int],
                 ftmp: List[int],
                 btmp: List[int],
                 check_consistency: bool = True):
        """A linearized model description used by the dynamic-programming
        activation-checkpoint solver. See https://hal.inria.fr/hal-02352969.

        Args:
            ftime (List[float]): The forward time of each node.
            btime (List[float]): The backward time of each node.
            x (List[int]): The forward memory of each node (if save_output). Same as `a` in the paper.
            xbar (List[int]): The forward memory of each node (if save_all). Same as `a_bar` in the paper.
            ftmp (List[int]): The temporary forward memory of each node.
            btmp (List[int]): The temporary backward memory of each node, can be used to control memory budget.
            check_consistency (bool, optional): Check the lengths consistency for the `Chain`. Defaults to True.
        """
        self.ftime = ftime
        self.btime = btime
        self.x = x
        self.xbar = xbar
        self.ftmp = ftmp
        self.btmp = btmp
        if check_consistency and not self.check_lengths():
            raise AttributeError("In Chain, input lists do not have consistent lengths")

    def check_lengths(self):
        """Return True when every list has its expected length (``ftime``/``ftmp``
        have ``len(self)`` entries, the remaining lists one extra entry)."""
        n = len(self)
        expected = (
            (self.ftime, n),
            (self.btime, n + 1),
            (self.x, n + 1),
            (self.ftmp, n),
            (self.btmp, n + 1),
            (self.xbar, n + 1),
        )
        return all(len(seq) == size for seq, size in expected)

    def __repr__(self):
        rows = [(self.ftime[i], self.btime[i], self.x[i], self.xbar[i], self.ftmp[i], self.btmp[i])
                for i in range(len(self))]
        # trailing entry: only the "+1" lists have a value here
        tail = len(self)
        rows.append((None, self.btime[tail], self.x[tail], self.xbar[tail], None, self.btmp[tail]))
        return rows.__repr__()

    def __len__(self):
        return len(self.ftime)

    def discretize_all(self, unit: int):
        """Discretize the chain into a list of chains according to unit size."""

        def to_units(val):
            return math.ceil(val / unit)

        self.x = tree_map(to_units, self.x)
        self.xbar = tree_map(to_units, self.xbar)
        self.ftmp = tree_map(to_units, self.ftmp)
        self.btmp = tree_map(to_units, self.btmp)
class Operation(ABC):
    """Abstract base for one step in a checkpointing execution sequence.

    Subclasses define ``name`` and carry an ``index`` (an int, or an
    (start, end) tuple for range operations) identifying the chain node(s)
    the step refers to.
    """
    name = "Op"

    def __repr__(self) -> str:
        return f"{self.name}_{self.index}"

    def shift(self, value):
        """Offset this operation's index (or index pair) by ``value``."""
        idx = self.index
        if type(idx) is tuple:
            self.index = tuple(v + value for v in idx)
        else:
            self.index = idx + value
class Forward(Operation):
    """Forward computation of a single chain node (label ``F``)."""
    name = "F"

    def __init__(self, index):
        self.index = index

    def cost(self, chain: Chain):
        """Forward time of this node from ``chain``; unit cost when chain is None."""
        if chain is None:
            return 1
        return chain.ftime[self.index]
class ForwardEnable(Forward):
    """Forward step labeled ``Fe``; cost comes from ``Forward.cost``.

    NOTE(review): the name suggests "forward with grad enabled" (output kept
    for backward) — confirm against the solver's usage.
    """
    name = "Fe"
class ForwardNograd(Forward):
    """Forward step labeled ``Fn``; cost comes from ``Forward.cost``.

    NOTE(review): the name suggests "forward without grad" (nothing kept for
    backward) — confirm against the solver's usage.
    """
    name = "Fn"
class ForwardCheck(Forward):
    """Forward step labeled ``CF``; cost comes from ``Forward.cost``.

    NOTE(review): the name suggests "checkpointed forward" (input saved as a
    checkpoint) — confirm against the solver's usage.
    """
    name = "CF"
class Forwards(Operation):
    """Forward computation over an inclusive range of chain nodes."""

    def __init__(self, start, end):
        self.index = (start, end)

    def __repr__(self):
        start, end = self.index
        return "F_{i}->{j}".format(i=start, j=end)

    def cost(self, chain: Chain):
        """Summed forward time over the range; node count when chain is None."""
        start, end = self.index
        if chain is None:
            return end - start + 1
        return sum(chain.ftime[start:end + 1])
def isForward(op):
    """Return True only for exact ``Forward``/``Forwards`` instances.

    NOTE: this is an exact type check (not ``isinstance``), so the
    ``ForwardEnable``/``ForwardNograd``/``ForwardCheck`` subclasses do NOT match.
    """
    return type(op) in (Forward, Forwards)
class Backward(Operation):
    """Backward computation of a single chain node (label ``B``)."""
    name = "B"

    def __init__(self, index):
        self.index = index

    def cost(self, chain: Chain):
        """Backward time of this node from ``chain``; unit cost when chain is None."""
        if chain is None:
            return 1
        return chain.btime[self.index]
class Loss(Operation):
    """Marker for the loss computation; carries no index and has zero cost."""

    def __init__(self):
        pass

    def __repr__(self):
        return "L"

    def cost(self, chain):
        return 0
class MemoryAccess(Operation):
    """Base class for memory bookkeeping steps; always zero cost."""
    name = "MA"

    def __init__(self, index):
        self.index = index

    def cost(self, chain: Chain):
        return 0
class WriteMemory(MemoryAccess):
    """Memory-access step labeled ``WM`` (write); zero cost via ``MemoryAccess.cost``."""
    name = "WM"
class ReadMemory(MemoryAccess):
    """Memory-access step labeled ``RM`` (read); zero cost via ``MemoryAccess.cost``."""
    name = "RM"
class DiscardMemory(MemoryAccess):
    """Memory-access step labeled ``DM`` (discard); zero cost via ``MemoryAccess.cost``."""
    name = "DM"
class Sequence(list):
    """A list of ``Operation`` items and nested ``Sequence`` objects."""

    def __init__(self):
        super().__init__()

    def __repr__(self):
        return repr(self.list_operations())

    def list_operations(self):
        """Flatten nested sequences into a single flat list of operations."""
        flattened = []
        for item in self:
            if isinstance(item, Operation):
                flattened.append(item)
            else:
                assert isinstance(item, Sequence)
                flattened.extend(item.list_operations())
        return flattened
colossalai/auto_parallel/meta_profiler/__init__.py
0 → 100644
View file @
e532679c
from
.meta_registry
import
*
from
.metainfo
import
*
from
.registry
import
meta_register
colossalai/auto_parallel/meta_profiler/constants.py
0 → 100644
View file @
e532679c
import
operator
import
torch
import
torch.nn
as
nn
from
..tensor_shard.constants
import
*
# Constant tables consumed by the meta profiler's registry generators
# (e.g. ``meta_registry/binary_elementwise_ops.py`` imports NO_SAVE_ACTIVATION).

# list of inplace modules
INPLACE_MODULE = [nn.ReLU]

# list of inplace operations
INPLACE_OPS = [torch.flatten]

# list of operations that do not save forward activations
NO_SAVE_ACTIVATION = [torch.add, torch.sub, operator.add, operator.sub]
colossalai/auto_parallel/meta_profiler/meta_registry/__init__.py
0 → 100644
View file @
e532679c
from
.activation
import
*
from
.binary_elementwise_ops
import
*
from
.conv
import
*
from
.linear
import
*
from
.norm
import
*
from
.pooling
import
*
colossalai/auto_parallel/meta_profiler/meta_registry/activation.py
0 → 100644
View file @
e532679c
from
typing
import
List
,
Tuple
import
torch
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
MemoryCost
,
OperationDataType
,
TrainCycleItem
from
colossalai.fx.profiler.memory_utils
import
activation_size
from
colossalai.fx.profiler.opcount
import
flop_mapping
from
..registry
import
meta_register
__all__ = ["relu_meta_info"]


@meta_register.register(torch.nn.ReLU)
def relu_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor], List[torch.Tensor],
                                             List[torch.Tensor]]:
    """torch.nn.ReLU metainfo generator.

    In the traced aten graph the forward op is ``aten.relu.default`` and the
    backward op is ``aten.threshold_backward.default``; those are the two ops
    whose FLOPs are counted below.

    Returns:
        Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
            compute cost, memory cost, fwd_in, fwd_buffer, fwd_out
    """
    # args are OperationData-like objects: args[0] is the input, the OUTPUT-typed
    # entry is the op's output
    input_tensor = args[0].data
    output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
    is_inplace = kwargs.get("inplace", False)

    # construct input args for forward
    fwd_in_args = [input_tensor]

    # construct input args for backward
    bwd_in_args = [output_tensor]

    # calculate cost
    # the fwd op with compute cost is relu.default
    # the bwd op with compute cost is threshold_backward

    # calculate compute cost
    fwd_compute_cost = flop_mapping[torch.ops.aten.relu.default](fwd_in_args, (output_tensor,))
    bwd_compute_cost = flop_mapping[torch.ops.aten.threshold_backward.default](bwd_in_args, (input_tensor,))
    compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)

    # calculate memory cost
    # NOTE: the inplace ReLU don't have forward memory cost
    # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
    fwd_memory_cost = MemoryCost(
        activation=activation_size(input_tensor) if is_inplace else activation_size([output_tensor, input_tensor]),
        parameter=0,
        temp=0,
        buffer=0)

    bwd_memory_cost = MemoryCost(activation=activation_size(input_tensor), parameter=0, temp=0, buffer=0)

    # total cost is the sum of forward and backward cost
    # (temp/buffer are both 0 above, so only activation/parameter are summed)
    total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
                            parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)

    memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)

    # store fwd_in, fwd_buffer, fwd_out
    # NOTE: It might seems a little bit weird here, we just want to align it with the older version
    # of MetaInfoProp. In the future we might modify this part to make it clearer.
    fwd_in = []
    fwd_buffer = [torch.zeros_like(output_tensor, device='meta')]
    fwd_out = [torch.zeros_like(output_tensor, device='meta')]

    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py
0 → 100644
View file @
e532679c
from
typing
import
List
,
Tuple
import
torch
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
MemoryCost
,
OperationDataType
,
TrainCycleItem
from
colossalai.fx.profiler.memory_utils
import
activation_size
from
colossalai.fx.profiler.opcount
import
flop_mapping
from
..constants
import
BCAST_FUNC_OP
,
NO_SAVE_ACTIVATION
from
..registry
import
meta_register
__all__ = ['binary_elementwise_meta_info']


@meta_register.register(BCAST_FUNC_OP)
def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor],
                                                           List[torch.Tensor], List[torch.Tensor]]:
    """Meta information generator for binary elementwise operations

    NOTE: Some of the binary elementwise operations will discard the input activation after computation, as they
    don't need those tensors for back propagation, for example, if there are two tensors being sent for `torch.add`,
    they will be discarded right after add operation is done. We create a simple API in `MetaInfo` class to identify
    this behavior, it is critical for better memory estimation.

    Returns:
        Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
            compute cost, memory cost, fwd_in, fwd_buffer, fwd_out
    """
    # split the OperationData args into inputs and the single OUTPUT entry
    input_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT]
    output_op_data = next(filter(lambda arg: arg.type == OperationDataType.OUTPUT, args))

    # construct forward args for flop mapping
    fwd_in_args = [opdata.data for opdata in input_op_data]
    fwd_out_args = [output_op_data.data]

    # calculate cost

    # calculate compute cost
    # NOTE: we set bwd_compute_cost two times of fwd_compute_cost in this case
    # (aten.add.Tensor is used as the representative elementwise op for FLOPs)
    fwd_compute_cost = flop_mapping[torch.ops.aten.add.Tensor](fwd_in_args, fwd_out_args)
    bwd_compute_cost = fwd_compute_cost * 2
    compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)

    # calculate memory cost
    # only PARAM-typed inputs contribute to the parameter memory cost
    param_mem_cost = activation_size([arg.data for arg in input_op_data if arg.type == OperationDataType.PARAM])
    fwd_mem_cost = MemoryCost(
        activation=activation_size(output_op_data.data),
        parameter=param_mem_cost,
    )
    bwd_mem_cost = MemoryCost(
        activation=activation_size(fwd_in_args),
        parameter=param_mem_cost,
    )

    # total cost
    total_mem_cost = MemoryCost(
        activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
        parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
    )
    memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)

    # store fwd_in, fwd_buffer, fwd_out
    fwd_in = []
    fwd_buffer = []
    fwd_out = [torch.zeros_like(output_op_data.data, device='meta')]

    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
colossalai/auto_parallel/meta_profiler/meta_registry/conv.py
0 → 100644
View file @
e532679c
from
typing
import
Callable
,
Dict
,
List
,
Tuple
,
Union
import
torch
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
(
MemoryCost
,
OperationData
,
OperationDataType
,
ShardingStrategy
,
StrategiesVector
,
TrainCycleItem
,
)
from
colossalai.fx.profiler.memory_utils
import
activation_size
from
colossalai.fx.profiler.opcount
import
flop_mapping
from
colossalai.tensor.sharding_spec
import
ShardingSpec
from
..registry
import
meta_register
__all__ = ['convnd_meta_info']


@meta_register.register(torch.nn.Conv1d)
@meta_register.register(torch.nn.Conv2d)
@meta_register.register(torch.nn.Conv3d)
@meta_register.register(torch.nn.functional.conv1d)
@meta_register.register(torch.nn.functional.conv2d)
@meta_register.register(torch.nn.functional.conv3d)
def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor], List[torch.Tensor],
                                               List[torch.Tensor]]:
    """torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d meta info generator.

    In the traced aten graph the forward op is ``aten.convolution.default`` and
    the backward op is ``aten.convolution_backward.default`` (with an extra bias
    gradient output when the conv has a bias); those are the ops whose FLOPs are
    counted below.

    Returns:
        Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
            compute cost, memory cost, fwd_in, fwd_buffer, fwd_out
    """
    has_bias: bool = False
    input_tensor = args[0].data
    output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
    # 4 OperationData args means weight + bias were both passed
    if len(args) == 4:
        weight_tensors = [args[1].data, args[3].data]
    else:
        weight_tensors = [args[1].data]

    # check if conv has bias
    if len(weight_tensors) > 1:
        has_bias = True
        # bias tensor's shape only has one dimension
        if len(weight_tensors[0].shape) == 1:
            bias_tensor, weight_tensor = weight_tensors
        else:
            weight_tensor, bias_tensor = weight_tensors
    else:
        weight_tensor = weight_tensors[0]

    # construct input args for forward (aten.convolution.default takes 9 args)
    fwd_args = [None] * 9

    # weight and input
    fwd_args[0] = input_tensor
    fwd_args[1] = weight_tensor
    fwd_args[2] = bias_tensor if has_bias else None

    # transpose indicator should be set to False
    fwd_args[6] = False

    # construct input args for backward (aten.convolution_backward.default takes 11 args)
    bwd_args = [None] * 11

    # weight and input
    bwd_args[0] = output_tensor
    bwd_args[1] = input_tensor
    bwd_args[2] = weight_tensor
    # output_mask: which of (input, weight, bias) gradients are required
    bwd_args[-1] = [True, True, True] if has_bias else [True, True, False]

    # calculate cost
    # the fwd op with compute cost is convolution.default
    # the bwd op with compute cost is convolution_backward.default

    # calculate compute cost
    fwd_compute_cost = flop_mapping[torch.ops.aten.convolution.default](fwd_args, (output_tensor,))
    bwd_compute_cost = flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor, bias_tensor)) if has_bias else \
                       flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor))
    compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)

    # calculate memory cost
    # TODO: use profiler to check conv temp memory
    # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
    fwd_memory_cost = MemoryCost(
        activation=activation_size([input_tensor, output_tensor]),
        parameter=activation_size([weight_tensor, bias_tensor]) if has_bias else activation_size(weight_tensor),
        temp=0,
        buffer=0)

    bwd_memory_cost = MemoryCost(
        activation=activation_size([input_tensor, weight_tensor, bias_tensor])
        if has_bias else activation_size([input_tensor, weight_tensor]),
        parameter=activation_size([weight_tensor, bias_tensor]) if has_bias else activation_size(weight_tensor),
        temp=0,
        buffer=0)

    # total cost is the sum of forward and backward cost
    total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
                            parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)

    memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)

    # store fwd_in, fwd_buffer, fwd_out
    fwd_in = [torch.zeros_like(input_tensor, device='meta')]
    fwd_buffer = []
    fwd_out = [torch.zeros_like(output_tensor, device='meta')]

    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
colossalai/auto_parallel/meta_profiler/meta_registry/linear.py
0 → 100644
View file @
e532679c
from
typing
import
Callable
,
Dict
,
List
,
Tuple
,
Union
import
torch
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
(
MemoryCost
,
OperationData
,
OperationDataType
,
ShardingStrategy
,
StrategiesVector
,
TrainCycleItem
,
)
from
colossalai.fx.profiler.memory_utils
import
activation_size
from
colossalai.fx.profiler.opcount
import
flop_mapping
from
colossalai.tensor.sharding_spec
import
ShardingSpec
from
..registry
import
meta_register
__all__ = ['linear_meta_info']


@meta_register.register(torch.nn.functional.linear)
@meta_register.register(torch.nn.Linear)
def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor], List[torch.Tensor],
                                               List[torch.Tensor]]:
    """torch.nn.Linear & torch.nn.functional.linear meta info generator.

    NOTE: currently we separate the bias part from the biased linear ops, we will consider the memory consumption
    in add metainfo generator, but we will hold the bias mechanism in the linear metainfo generator for future use.

    In the traced aten graph, the biased forward is ``aten.addmm.default`` and
    its backward is two ``aten.mm.default`` plus one ``aten.sum.dim_IntList``;
    the bias-free forward is ``aten.mm.default`` with two ``aten.mm.default``
    in backward. Those are the ops whose FLOPs are counted below.

    Returns:
        Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
            compute cost, memory cost, fwd_in, fwd_buffer, fwd_out
    """
    has_bias: bool = False

    input_tensor = args[0].data
    output_tensor = args[2].data
    # 4 OperationData args means weight + bias were both passed
    if len(args) == 4:
        weight_tensors = [args[1].data, args[3].data]
    else:
        weight_tensors = [args[1].data]

    # process the dimension of input and output
    # flatten leading dims so FLOP counting sees plain 2D matmul shapes
    if len(input_tensor.shape) > 2:
        input_tensor: torch.Tensor
        input_tensor = input_tensor.view(-1, input_tensor.shape[-1])

    if len(output_tensor.shape) > 2:
        output_tensor: torch.Tensor
        output_tensor = output_tensor.view(-1, output_tensor.shape[-1])

    if len(weight_tensors) > 1:
        has_bias = True
        # weight is 2D, bias is 1D — disambiguate by rank
        if len(weight_tensors[0].shape) == 2:
            weight_tensor, bias_tensor = weight_tensors
        else:
            bias_tensor, weight_tensor = weight_tensors
    else:
        weight_tensor = weight_tensors[0]

    if has_bias:
        # calculate cost with bias
        # the fwd op with compute cost is addmm
        # the bwd op with compute cost is mm * 2 and sum.dim_IntList

        # calculate compute cost
        fwd_compute_cost = flop_mapping[torch.ops.aten.addmm.default](
            [bias_tensor, input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,))
        bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \
                           flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)) + \
                           flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,))
        compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
                                      bwd=bwd_compute_cost,
                                      total=fwd_compute_cost + bwd_compute_cost)

        # calculate memory cost
        # NOTE: Linear don't have buffer and temp in forward and backward phase
        # the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor and bias_tensor
        # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
        fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]),
                                     parameter=activation_size([weight_tensor, bias_tensor]),
                                     temp=0,
                                     buffer=0)

        # the backward activation cost is the size of input_tensor, weight_tensor and bias_tensor, parameter cost is 0
        bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, weight_tensor, bias_tensor]),
                                     parameter=activation_size([weight_tensor, bias_tensor]),
                                     temp=0,
                                     buffer=0)

        # total cost is to sum the forward and backward cost
        total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
                                parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)

        memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
    else:
        # calculate cost without bias
        # the fwd op with compute cost is mm
        # the bwd op with compute cost is mm * 2

        # calculate compute cost
        fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
            [input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,))
        bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \
                           flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,))
        compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
                                      bwd=bwd_compute_cost,
                                      total=fwd_compute_cost + bwd_compute_cost)

        # calculate memory cost
        # NOTE: Linear don't have buffer and temp in forward and backward phase
        # the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor
        # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
        fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]),
                                     parameter=activation_size(weight_tensor),
                                     temp=0,
                                     buffer=0)

        # the backward activation cost is the size of input_tensor and weight_tensor, parameter cost is 0
        bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, weight_tensor]),
                                     parameter=activation_size(weight_tensor),
                                     temp=0,
                                     buffer=0)

        # total cost is to sum the forward and backward cost
        total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
                                parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)

        memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)

    # store fwd_in, fwd_buffer, fwd_out
    fwd_in = [torch.zeros_like(input_tensor, device='meta')]
    fwd_buffer = []
    fwd_out = [torch.zeros_like(output_tensor, device='meta')]

    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
colossalai/auto_parallel/meta_profiler/meta_registry/norm.py
0 → 100644
View file @
e532679c
from
typing
import
Callable
,
Dict
,
List
,
Tuple
,
Union
import
torch
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
(
MemoryCost
,
OperationData
,
OperationDataType
,
ShardingStrategy
,
StrategiesVector
,
TrainCycleItem
,
)
from
colossalai.fx.profiler.memory_utils
import
activation_size
from
colossalai.fx.profiler.opcount
import
flop_mapping
from
colossalai.tensor.sharding_spec
import
ShardingSpec
from
..registry
import
meta_register
__all__
=
[
'batchnormnd_meta_info'
]
@meta_register.register(torch.nn.BatchNorm1d)
@meta_register.register(torch.nn.BatchNorm2d)
@meta_register.register(torch.nn.BatchNorm3d)
def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
    """Meta info generator for BatchNorm1d / BatchNorm2d / BatchNorm3d.

    Compute cost is derived from the aten ops ``cudnn_batch_norm.default``
    (forward) and ``cudnn_batch_norm_backward.default`` (backward) via
    ``flop_mapping``.

    Returns:
        Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
            compute cost, memory cost and the forward input / buffer /
            output meta tensors.
    """

    def _data_by_name(name):
        # First OperationData in args whose .name matches `name`.
        return next(op for op in args if op.name == name).data

    inp = args[0].data
    out = next(op for op in args if op.type == OperationDataType.OUTPUT).data
    weight = _data_by_name("weight")
    bias = _data_by_name("bias")
    running_mean = _data_by_name("running_mean")
    running_var = _data_by_name("running_var")
    num_batches = _data_by_name("num_batches_tracked")

    # Forward signature: (input, weight, bias, running_mean, running_var,
    # training, momentum, eps) -> (output, saved mean, saved inv std,
    # num batches tracked).
    fwd_inputs = [inp, weight, bias, running_mean, running_var, True, 0.1, 1e-5]
    fwd_outputs = [out, running_mean, running_var, num_batches]

    # Backward signature: (upstream grad, input, weight, running stats,
    # saved stats, eps, reserve) -> (input grad, weight grad, bias grad).
    bwd_inputs = [out, out, weight, running_mean, running_var, running_mean, running_var, 1e-5, num_batches]
    bwd_outputs = [inp, weight, bias]

    # Compute cost from the flop mapping of the two cudnn batch-norm ops.
    fwd_flop = flop_mapping[torch.ops.aten.cudnn_batch_norm.default](fwd_inputs, fwd_outputs)
    bwd_flop = flop_mapping[torch.ops.aten.cudnn_batch_norm_backward.default](bwd_inputs, bwd_outputs)
    compute_cost = TrainCycleItem(fwd=fwd_flop, bwd=bwd_flop, total=fwd_flop + bwd_flop)

    # Forward activation counts input, output plus saved mean / inv std;
    # the running statistics live in buffers.
    # NOTE: currently in SPMD solver we always believe that there will be a
    # new tensor created in forward.
    fwd_mem = MemoryCost(activation=activation_size([inp, out, running_mean, running_var]),
                         parameter=activation_size([weight, bias]),
                         temp=0,
                         buffer=activation_size([running_mean, running_var]))

    # Backward discards the saved mean / inv std, hence they are accounted
    # as temp memory there.
    bwd_mem = MemoryCost(activation=activation_size([inp]),
                         parameter=activation_size([weight, bias]),
                         temp=activation_size([running_mean, running_var]),
                         buffer=activation_size([running_mean, running_var]))

    # Total memory is the sum of the forward and backward phases.
    total_mem = MemoryCost(activation=fwd_mem.activation + bwd_mem.activation,
                           parameter=fwd_mem.parameter + bwd_mem.parameter)
    memory_cost = TrainCycleItem(fwd=fwd_mem, bwd=bwd_mem, total=total_mem)

    # Meta tensors describing forward inputs, buffers and outputs.
    fwd_in = [torch.zeros_like(inp, device='meta')]
    fwd_buffer = [torch.zeros_like(running_mean, device='meta'), torch.zeros_like(running_var, device='meta')]
    fwd_out = [torch.zeros_like(out, device='meta')]

    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py
0 → 100644
View file @
e532679c
from
typing
import
List
,
Tuple
import
torch
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
MemoryCost
,
OperationDataType
,
TrainCycleItem
from
colossalai.fx.profiler.memory_utils
import
activation_size
from
colossalai.fx.profiler.opcount
import
flop_mapping
from
..registry
import
meta_register
# Public API of this module.
__all__ = ["avgpool_meta_info", "maxpool_meta_info"]
@meta_register.register(torch.nn.AdaptiveAvgPool1d)
@meta_register.register(torch.nn.AdaptiveAvgPool2d)
@meta_register.register(torch.nn.AdaptiveAvgPool3d)
@meta_register.register(torch.flatten)
def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
    """Meta info generator for AdaptiveAvgPool (and ``torch.flatten``).

    Compute cost is derived from the aten ops
    ``_adaptive_avg_pool2d.default`` (forward) and
    ``_adaptive_avg_pool2d_backward.default`` (backward) via ``flop_mapping``.

    Returns:
        Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
            compute cost, memory cost and the forward input / buffer /
            output meta tensors.
    """
    inp = args[0].data
    out = next(op for op in args if op.type == OperationDataType.OUTPUT).data
    inplace = kwargs.get("inplace", False)

    # Compute cost from the adaptive-avg-pool forward and backward ops.
    fwd_flop = flop_mapping[torch.ops.aten._adaptive_avg_pool2d.default]([inp], [out])
    bwd_flop = flop_mapping[torch.ops.aten._adaptive_avg_pool2d_backward.default]([out], [inp])
    compute_cost = TrainCycleItem(fwd=fwd_flop, bwd=bwd_flop, total=fwd_flop + bwd_flop)

    # In-place execution allocates no new activation memory.
    if inplace:
        fwd_mem = MemoryCost()
        bwd_mem = MemoryCost()
    else:
        fwd_mem = MemoryCost(activation=activation_size(out))
        bwd_mem = MemoryCost(activation=activation_size(inp))

    # Total memory is the sum of the forward and backward phases.
    total_mem = MemoryCost(activation=fwd_mem.activation + bwd_mem.activation)
    mem_cost = TrainCycleItem(fwd=fwd_mem, bwd=bwd_mem, total=total_mem)

    # No saved inputs or buffers; only the output meta tensor is recorded.
    fwd_in = []
    fwd_buffer = []
    fwd_out = [torch.zeros_like(out, device='meta')]

    return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out
@meta_register.register(torch.nn.MaxPool1d)
@meta_register.register(torch.nn.MaxPool2d)
@meta_register.register(torch.nn.MaxPool3d)
def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
    """Meta info generator for MaxPool1d / MaxPool2d / MaxPool3d.

    Compute cost is derived from the aten ops
    ``max_pool2d_with_indices.default`` (forward) and
    ``max_pool2d_with_indices_backward.default`` (backward) via
    ``flop_mapping``.

    Returns:
        Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
            compute cost, memory cost and the forward input / buffer /
            output meta tensors.
    """
    inp = next(op for op in args if op.type == OperationDataType.ARG).data
    out = next(op for op in args if op.type == OperationDataType.OUTPUT).data

    # Max pooling saves an int64 index tensor shaped like the output for the
    # backward pass.
    indices = torch.zeros_like(out, device="meta", dtype=torch.int64)

    # Compute cost from the max-pool forward and backward ops.
    fwd_flop = flop_mapping[torch.ops.aten.max_pool2d_with_indices.default]([inp], [out])
    bwd_flop = flop_mapping[torch.ops.aten.max_pool2d_with_indices_backward.default]([out], [inp])
    compute_cost = TrainCycleItem(fwd=fwd_flop, bwd=bwd_flop, total=fwd_flop + bwd_flop)

    # Forward keeps input, output and the index tensor alive.
    # NOTE: currently in SPMD solver we always believe that there will be a
    # new tensor created in forward.
    fwd_mem = MemoryCost(activation=activation_size([inp, out, indices]))

    # The index matrix is discarded during backward, so it is subtracted from
    # the backward activation and accounted as temp memory instead.
    bwd_mem = MemoryCost(activation=activation_size(inp) - activation_size(indices),
                         temp=activation_size(indices))

    # Total memory is the sum of the forward and backward phases.
    total_mem = MemoryCost(activation=fwd_mem.activation + bwd_mem.activation, temp=bwd_mem.temp)
    mem_cost = TrainCycleItem(fwd=fwd_mem, bwd=bwd_mem, total=total_mem)

    # Meta tensors describing forward inputs, buffers and outputs.
    fwd_in = [torch.zeros_like(inp, device='meta')]
    fwd_buffer = [torch.zeros_like(indices, device='meta')]
    fwd_out = [torch.zeros_like(out, device='meta')]

    return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out
colossalai/auto_parallel/meta_profiler/metainfo.py
0 → 100644
View file @
e532679c
from
typing
import
Callable
,
List
import
torch
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
(
MemoryCost
,
OperationData
,
OperationDataType
,
ShardingStrategy
,
StrategiesVector
,
TrainCycleItem
,
)
from
colossalai.tensor.sharding_spec
import
ShardingSpec
from
.constants
import
INPLACE_MODULE
,
INPLACE_OPS
,
NO_SAVE_ACTIVATION
from
.registry
import
meta_register
# Public API of this module.
__all__ = ['MetaInfo']
class MetaInfo:
    """Meta information for a sharded operation.

    Given a ``ShardingStrategy`` and a target (module instance or function),
    this class shards every operand according to the strategy's sharding
    specs and queries the registered meta-info generator for the compute
    cost, memory cost and forward tensors.
    """

    def __init__(self, strategy: ShardingStrategy = None, target: Callable = None) -> None:
        # compute cost of forward and backward computation
        self.compute_cost: TrainCycleItem

        # memory cost of forward and backward phase
        self.memory_cost: TrainCycleItem

        # list of input tensors
        self.fwd_in: List[torch.Tensor]

        # list of buffer tensors
        self.fwd_buffer: List[torch.Tensor]

        # list of output tensors
        self.fwd_out: List[torch.Tensor]

        # sharding strategy
        self._strategy = strategy

        # target function
        self._target = target

        # compute metainfo if possible
        if self._strategy is not None and self._target is not None:
            self.compute_metainfo()

    @property
    def strategy(self) -> ShardingStrategy:
        return self._strategy

    @property
    def target(self) -> Callable:
        return self._target

    @strategy.setter
    def strategy(self, strategy: ShardingStrategy) -> None:
        self._strategy = strategy
        # recompute metainfo once both strategy and target are known
        if self._strategy is not None and self._target is not None:
            self.compute_metainfo()

    @target.setter
    def target(self, target: Callable) -> None:
        self._target = target
        # recompute metainfo once both strategy and target are known
        if self._strategy is not None and self._target is not None:
            self.compute_metainfo()

    def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor:
        """
        Compute sharded opdata based on the given data and sharding spec.

        Returns a new ``OperationData`` whose ``data`` is a meta tensor of
        the per-device sharded shape.
        """
        return OperationData(name=operation_data.name,
                             data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
                             type=operation_data.type,
                             logical_shape=operation_data.logical_shape)

    def compute_metainfo(self):
        """
        Compute meta info based on sharding strategy and the given target function.
        """
        assert meta_register.has(self._target.__class__) or meta_register.has(self._target), \
            f"Meta info for {self._target} is not registered."
        if meta_register.has(self._target.__class__):
            # module: the registry is keyed by the module class
            meta_func = meta_register.get(self._target.__class__)

            # check whether the target is in the list of targets whose
            # forward activation does not need to be saved
            save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION
        else:
            # function: the registry is keyed by the function object itself
            meta_func = meta_register.get(self._target)

            # BUGFIX: compare the function itself against NO_SAVE_ACTIVATION;
            # the previous check used ``self._target.__class__`` (i.e. the
            # ``function`` type), which can never appear in that list, so the
            # no-save path was unreachable for function targets.
            save_fwd_in = self._target not in NO_SAVE_ACTIVATION

        # construct args for meta_func: one sharded OperationData per sharding spec
        args = [self.compute_sharded_opdata(k, v) for k, v in self._strategy.sharding_specs.items()]

        # construct kwargs
        # BUGFIX: INPLACE_MODULE holds module *classes* while self._target is a
        # module *instance*, so membership must be tested against the class —
        # the previous instance test was always False.
        if type(self._target) in INPLACE_MODULE:
            kwargs = {'inplace': self._target.inplace}
        elif self._target in INPLACE_OPS:
            kwargs = {'inplace': True}
        else:
            kwargs = {'inplace': False}

        # compute metainfo with meta_func
        self.compute_cost, self.memory_cost, self.fwd_in, self.fwd_buffer, self.fwd_out = meta_func(*args, **kwargs)

        # process corner case for NO_SAVE_ACTIVATION
        if not save_fwd_in:
            self.fwd_in = []
colossalai/auto_parallel/meta_profiler/registry.py
0 → 100644
View file @
e532679c
# Public API of this module.
__all__ = ['Registry']
class Registry:
    """A minimal mapping from hashable keys (classes, functions, ...) to the
    callables registered for them.

    ``register`` is used as a decorator; ``get`` and ``has`` look entries up.
    """

    def __init__(self, name):
        # human-readable registry name, used in error messages
        self.name = name
        # key -> registered callable
        self.store = {}

    def register(self, source):
        """Return a decorator that maps ``source`` — or every element of it,
        when ``source`` is a list/tuple — to the decorated callable."""

        def wrapper(func):
            keys = source if isinstance(source, (list, tuple)) else (source,)
            for key in keys:
                self.store[key] = func
            return func

        return wrapper

    def get(self, source):
        """Return the callable registered for ``source``; raises
        ``AssertionError`` when it is unknown."""
        assert source in self.store, f'{source} not found in the {self.name} registry'
        return self.store[source]

    def has(self, source):
        """Return True when ``source`` has a registered entry."""
        return source in self.store
# Global registry mapping a module class or function to its meta-info
# generator; populated via @meta_register.register(...) decorators.
meta_register = Registry('meta')
Prev
1
2
3
4
5
6
7
…
39
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment