Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
b42d3d28
Unverified
Commit
b42d3d28
authored
Mar 07, 2023
by
Super Daniel
Committed by
GitHub
Mar 07, 2023
Browse files
[fx] remove deprecated algorithms. (#2312) (#2313)
parent
55dcd305
Changes
8
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
0 additions
and
1970 deletions
+0
-1970
colossalai/fx/passes/algorithms/__init__.py
colossalai/fx/passes/algorithms/__init__.py
+0
-4
colossalai/fx/passes/algorithms/build_c_ext.py
colossalai/fx/passes/algorithms/build_c_ext.py
+0
-15
colossalai/fx/passes/algorithms/ckpt_solver_chen.py
colossalai/fx/passes/algorithms/ckpt_solver_chen.py
+0
-98
colossalai/fx/passes/algorithms/ckpt_solver_pofo.py
colossalai/fx/passes/algorithms/ckpt_solver_pofo.py
+0
-537
colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
+0
-436
colossalai/fx/passes/algorithms/dynamic_programs.c
colossalai/fx/passes/algorithms/dynamic_programs.c
+0
-516
colossalai/fx/passes/algorithms/linearize.py
colossalai/fx/passes/algorithms/linearize.py
+0
-94
colossalai/fx/passes/algorithms/operation.py
colossalai/fx/passes/algorithms/operation.py
+0
-270
No files found.
colossalai/fx/passes/algorithms/__init__.py
deleted
100644 → 0
View file @
55dcd305
# Public re-exports for the activation-checkpoint solver algorithms package.
from .ckpt_solver_chen import chen_greedy
from .linearize import linearize
from .ckpt_solver_rotor import solver_rotor
from .ckpt_solver_pofo import solver_pofo
colossalai/fx/passes/algorithms/build_c_ext.py
deleted
100644 → 0
View file @
55dcd305
"""Build script for the `dynamic_programs_C_version` extension.

Compiles `dynamic_programs.c` (located next to this file) into a C extension
used as a faster backend for the rotor dynamic-programming solver.
"""
from setuptools import setup, Extension
import os

# Resolve paths relative to this file so the build works from any CWD.
this_dir = os.path.dirname(os.path.abspath(__file__))

ext_modules = [
    Extension(
        'dynamic_programs_C_version',
        sources=[os.path.join(this_dir, 'dynamic_programs.c')],
    )
]

setup(
    name='rotor c extension',
    version='0.1',
    description='rotor c extension for faster dp computing',
    ext_modules=ext_modules,
)
colossalai/fx/passes/algorithms/ckpt_solver_chen.py
deleted
100644 → 0
View file @
55dcd305
import math
from typing import List, Set, Tuple

import torch
from torch.fx import GraphModule, Node

from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp

__all__ = ['chen_greedy']

# fx node opcodes on which an activation checkpoint may be placed
CKPT_OP = ['call_module', 'call_method', 'call_function', 'get_attr']
def _all_potential_ckpt_nodes(gm: GraphModule) -> List:
    """
    In most existing frameworks of activation checkpoint, the forward graph is assumed to be linearized.
    """

    def is_sink():
        """
        If we can free all memories when executing a certain node, it is a sink.
        """
        return sum(deps.values()) == 0

    deps = {}
    ckpt_nodes = []
    for node in gm.graph.nodes:
        # every parent loses one pending consumer: free memory and dependencies
        for parent in node._input_nodes:
            deps[parent] -= 1

        # We can only put act_ckpt on these nodes
        if node.op in CKPT_OP and is_sink():
            ckpt_nodes.append(node)

        # add dependencies for future executions
        deps[node] = len(node.users)
    return ckpt_nodes
def chen_greedy(gm: GraphModule) -> GraphModule:
    """
    This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
    Note that this algorithm targets at memory optimization only, using techniques in appendix A.

    Usage:
        model = resnet18()
        input_sample = torch.rand(4, 3, 224, 224)
        gm = symbolic_trace(model)
        MetaInfoProp(gm).run(input_sample)
        gm = chen_greedy(gm)

    Args:
        gm (GraphModule): The module to add checkpoints

    Returns:
        GraphModule: the same module, with `activation_checkpoint` attributes
        set on the nodes chosen by the greedy search, and recompiled.
    """

    def grid_search(num_grids: int = 6) -> Set:
        """
        Search ckpt strategy with b = 0, then run the allocation algorithm again with b = sqrt(xy).
        Grid search over [sqrt(2)/2 b, sqrt(2) b] for ckpt_opt over num_grids as in appendix A.
        """
        # Run once with b = 0 to get the initial estimate; also keep its
        # intervals as a fallback so ckpt_opt is never unbound when the
        # grid loop below turns out to be empty.
        ckpt_opt, b_approx = run_chen_greedy(0)
        b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))
        b_opt = math.inf
        # Guard against a zero step (range() raises ValueError) when
        # b_max - b_min < num_grids, e.g. for very small models.
        step = max((b_max - b_min) // num_grids, 1)
        for b in range(b_min, b_max, step):
            ckpt_intv, b_approx = run_chen_greedy(b)
            if b_approx < b_opt:
                b_opt = b_approx
                ckpt_opt = ckpt_intv
        return ckpt_opt

    def run_chen_greedy(b: int = 0) -> Tuple[Set, int]:
        """
        This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
        """
        ckpt_nodes = _all_potential_ckpt_nodes(gm)
        ckpt_intv = []
        temp = 0
        x = 0
        y = 0
        prev_idx = 2
        for (idx, n) in enumerate(gm.graph.nodes):
            n: Node
            temp += calculate_fwd_in(n) + calculate_fwd_tmp(n)
            y = max(y, temp)
            if temp > b and n in ckpt_nodes:
                x += calculate_fwd_in(n)
                temp = 0
                ckpt_intv.append((prev_idx, idx + 1))
                prev_idx = idx + 1
        return ckpt_intv, math.floor(math.sqrt(x * y))

    gm.graph.lint()    # make sure nodes are in topological order
    ckpt = grid_search(num_grids=6)
    node_list = list(gm.graph.nodes)
    for i, seg in enumerate(ckpt):
        for idx in range(*seg):
            n = node_list[idx]
            if n.op in CKPT_OP:
                setattr(n, 'activation_checkpoint', i)
    gm.recompile()
    return gm
colossalai/fx/passes/algorithms/ckpt_solver_pofo.py
deleted
100644 → 0
View file @
55dcd305
This diff is collapsed.
Click to expand it.
colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
deleted
100644 → 0
View file @
55dcd305
import math
import sys
from typing import List, Tuple

from torch.fx import Node

from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.profiler import activation_size, calculate_fwd_out, calculate_fwd_tmp, parameter_size
from colossalai.logging import get_dist_logger

from .linearize import linearize
from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Sequence

# global variable to indicate whether the solver has failed
SOLVER_FAILED = False

# this is the python compute table code from rotor
# https://gitlab.inria.fr/hiepacs/rotor
# paper link: https://hal.inria.fr/hal-02352969
def _compute_table(chain: Chain, mmax) -> Tuple:
    """Returns the optimal table: a tuple containing:
    Opt[m][lmin][lmax] with lmin = 0...chain.length
    and lmax = lmin...chain.length (lmax is not included) and m = 0...mmax
    what[m][lmin][lmax] is (True,) if the optimal choice is a chain checkpoint
    (False, j) if the optimal choice is a leaf checkpoint of length j
    The computation uses dynamic programming"""

    fw = chain.fweight + [0]    ## forward time
    bw = chain.bweight    ## backward time, not used
    cw = chain.cweight + [0]    ## size of x (and of y)
    cbw = chain.cbweight + [0]    ## size of xbar
    fwd_mem_tmp = chain.fwd_mem_tmp + [0]
    bwd_mem_tmp = chain.bwd_mem_tmp + [0]

    # Build table
    opt = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
    what = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
    # Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation

    # Initialize borders of the tables for lmax-lmin = 0
    for m in range(mmax + 1):
        for i in range(chain.length + 1):
            # lmax - lmin = 0
            limit = max(cw[i + 1] + cbw[i + 1] + fwd_mem_tmp[i],
                        cw[i + 1] + cbw[i + 1] + bwd_mem_tmp[i])
            if m >= limit:    ## Equation (1)
                opt[m][i][i] = fw[i] + bw[i]
            else:
                opt[m][i][i] = float("inf")

    # Compute everything
    for m in range(mmax + 1):
        for d in range(1, chain.length + 1):
            for i in range(chain.length + 1 - d):
                # for idx in range(i+1, chain.length + 1):
                idx = i + d
                mmin = cw[idx + 1] + cw[i + 1] + fwd_mem_tmp[i]
                if idx > i + 1:
                    mmin = max(mmin,
                               cw[idx + 1] + max(cw[j] + cw[j + 1] + fwd_mem_tmp[j]
                                                 for j in range(i + 1, idx)))
                if m < mmin:
                    opt[m][i][idx] = float("inf")
                else:
                    leaf_checkpoints = [(j, sum(fw[i:j]) + opt[m - cw[j]][j][idx] + opt[m][i][j - 1])
                                        for j in range(i + 1, idx + 1)
                                        if m >= cw[j]]
                    if leaf_checkpoints:
                        best_leaf = min(leaf_checkpoints, key=lambda t: t[1])
                    else:
                        best_leaf = None
                    if m >= cbw[i + 1]:
                        chain_checkpoint = opt[m][i][i] + opt[m - cbw[i + 1]][i + 1][idx]
                    else:
                        chain_checkpoint = float("inf")
                    if best_leaf and best_leaf[1] <= chain_checkpoint:
                        opt[m][i][idx] = best_leaf[1]
                        what[m][i][idx] = (False, best_leaf[0])
                    else:
                        opt[m][i][idx] = chain_checkpoint
                        what[m][i][idx] = (True,)
    return (opt, what)
def _rec(chain: Chain, lmin, lmax, cmem, opt_table):
    """ chain : the class describing the AC graph
        lmin : index of the first forward to execute
        lmax : upper bound index of the last forward to execute (not included)
        cmem : number of available memory slots
    Return the optimal sequence of makespan Opt_hete[cmem][lmin][lmax-lmin]"""
    if cmem <= 0:
        raise ValueError("Can not process a chain with negative memory {cmem}".format(cmem=cmem))

    opt, what = opt_table
    sequence = Sequence(Function("Persistent", lmax - lmin, cmem))

    if opt[cmem][lmin][lmax] == float("inf"):
        # use the logger to announce that the solver has failed
        logger = get_dist_logger()
        logger.info(
            "Can not process this chain from index {lmin} to {lmax} with memory {cmem}".format(lmin=lmin,
                                                                                               lmax=lmax,
                                                                                               cmem=cmem))
        # set the global indicator SOLVER_FAILED to True
        global SOLVER_FAILED
        SOLVER_FAILED = True
        return sequence

    if lmin == lmax:
        if lmin == chain.length:
            sequence.insert(Loss())
        else:
            sequence.insert(ForwardEnable(lmin))
            sequence.insert(Backward(lmin))
        return sequence

    if what[cmem][lmin][lmax][0]:
        # chain checkpoint: keep xbar of lmin and recurse on the rest
        sequence.insert(ForwardEnable(lmin))
        sequence.insert_sequence(_rec(chain, lmin + 1, lmax, cmem - chain.cbweight[lmin + 1], opt_table))
        sequence.insert(Backward(lmin))
    else:
        # leaf checkpoint at j: run forward without grad up to j, then recurse
        j = what[cmem][lmin][lmax][1]
        sequence.insert(ForwardCheck(lmin))
        for k in range(lmin + 1, j):
            sequence.insert(ForwardNograd(k))
        sequence.insert_sequence(_rec(chain, j, lmax, cmem - chain.cweight[j], opt_table))
        sequence.insert_sequence(_rec(chain, lmin, j - 1, cmem, opt_table))
    return sequence
def _fwd_xbar(node: List[Node]) -> int:
    """Get the forward xbar of a node

    Args:
        node (List[Node]): List of torch.fx Node,
        indicates a node in linearized graph

    Returns:
        int: xbar size, unit Byte
    """
    return sum(calculate_fwd_tmp(n) + calculate_fwd_out(n) for n in node)
def
_fwd_time
(
node
:
List
[
Node
])
->
int
:
"""Get the foward time of a node
Args:
node (List[Node]): List of torch.fx Node,
indicates a node in linearized graph
Returns:
int: foward time, extimated by flops count
"""
fwd_time
=
0
for
n
in
node
:
# minimum flop count is needed
fwd_time
+=
max
(
n
.
meta
[
'fwd_flop'
],
1
)
return
fwd_time
def
_bwd_time
(
node
:
List
[
Node
])
->
int
:
"""Get the backward time of a node
Args:
node (List[Node]): List of torch.fx Node,
indicates a node in linearized graph
Returns:
int: backward time, extimated by flops count
"""
bwd_time
=
0
for
n
in
node
:
# minimum flop count is needed
bwd_time
+=
max
(
n
.
meta
[
'bwd_flop'
],
1
)
return
bwd_time
def _get_fwd_mem_tmp(node: List[Node]) -> int:
    """Get the forward temp memory of a node

    This could be done by subtracting the saved activation from all output of a node

    Args:
        node (List[Node]): List of torch.fx Node,
        indicates a node in linearized graph

    Returns:
        int: forward temp memory, unit Byte
    """
    last = node[-1]
    return activation_size(last.meta['fwd_out']) - calculate_fwd_out(last)
def _get_bwd_mem_tmp(node: List[Node]) -> int:
    """Get the backward temp memory of a node

    Args:
        node (List[Node]): List of torch.fx Node,
        indicates a node in linearized graph

    Returns:
        int: backward temp memory, unit Byte
    """

    def _get_deps_size():
        # Sum the backward output memory of every node that still has live
        # consumers; nodes marked freed (-inf) give back their fwd memory.
        deps_size = 0
        for k, v in deps.items():
            k: Node
            if v > 0:
                deps_size += k.meta['bwd_mem_out']
            if v == float('-inf'):
                deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k)
        return deps_size

    bwd_mem_tmp = 0
    deps = {}
    for n in reversed(node):
        deps[n] = len(n.all_input_nodes)
        bwd_mem_tmp = max(bwd_mem_tmp, _get_deps_size() + n.meta['bwd_mem_tmp'])
        for child in n.users:
            if child in deps:
                deps[child] -= 1
                if deps[child] <= 0:
                    deps[child] = float('-inf')    # free
    return bwd_mem_tmp
def _construct_chain(node_list: List[List[Node]], input) -> Chain:
    """Build the rotor ``Chain`` description of a linearized graph.

    Args:
        node_list (List[List[Node]]): linearized graph; each inner list is
            one stage of the chain.
        input: model input sample, used to size the first activation
            (measured with ``activation_size``).

    Returns:
        Chain: per-stage forward/backward times, activation sizes, xbar
        sizes and temporary memory, with a trailing zero-cost entry for
        the loss backward.
    """
    fwd_time = []
    bwd_time = []
    xbar_sizes = [activation_size(input)]
    x_sizes = [activation_size(input)]
    tmp_fwd = []
    tmp_bwd = []

    # NOTE: the original looped with enumerate() but never used the index.
    for node in node_list:
        fwd_time.append(_fwd_time(node))
        bwd_time.append(_bwd_time(node))
        x_sizes.append(calculate_fwd_out(node[-1]))
        xbar_sizes.append(max(x_sizes[-1], _fwd_xbar(node)))
        tmp_fwd.append(_get_fwd_mem_tmp(node))
        tmp_bwd.append(_get_bwd_mem_tmp(node))

    bwd_time.append(0)

    # currently we view loss backward temp as zero
    tmp_bwd.append(0)

    return Chain(fwd_time, bwd_time, x_sizes, xbar_sizes, tmp_fwd, tmp_bwd)
def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
    """Translate a rotor ``Sequence`` into ``activation_checkpoint`` labels
    on the fx nodes of the linearized graph."""
    op_list = sequence.list_operations()
    loss_op = next(op for op in op_list if isinstance(op, Loss))
    fwd_list = op_list[:op_list.index(loss_op)]
    bwd_list = op_list[op_list.index(loss_op) + 1:]
    ckpt_idx = 0
    in_ckpt = False
    ckpt_region = []

    # forward annotation
    for idx, op in enumerate(fwd_list, 0):
        if in_ckpt:
            if isinstance(op, ForwardNograd):
                ckpt_region.append(idx)
            elif isinstance(op, ForwardEnable):
                # checkpoint region ends here
                in_ckpt = False
                for node_idx in ckpt_region:
                    for n in node_list[node_idx]:
                        setattr(n, "activation_checkpoint", [ckpt_idx])
                ckpt_idx += 1
                ckpt_region = []
            elif isinstance(op, ForwardCheck):
                # close the current region and immediately open a new one
                for node_idx in ckpt_region:
                    for n in node_list[node_idx]:
                        setattr(n, "activation_checkpoint", [ckpt_idx])
                ckpt_idx += 1
                ckpt_region = [idx]
        else:
            if isinstance(op, ForwardCheck):
                in_ckpt = True
                ckpt_region.append(idx)

    # annotate the backward if there is any nested activation checkpoint
    in_recompute = False
    for op in bwd_list:
        if in_recompute:
            if isinstance(op, ForwardNograd):
                ckpt_region.append(op.index)
            elif isinstance(op, ForwardEnable):
                for node_idx in ckpt_region:
                    for n in node_list[node_idx]:
                        n.activation_checkpoint.append(ckpt_idx)
                ckpt_idx += 1
                ckpt_region = []
            elif isinstance(op, ForwardCheck):
                for node_idx in ckpt_region:
                    for n in node_list[node_idx]:
                        n.activation_checkpoint.append(ckpt_idx)
                ckpt_idx += 1
                ckpt_region = [op.index]
            elif isinstance(op, Backward):
                for node_idx in ckpt_region:
                    for n in node_list[node_idx]:
                        n.activation_checkpoint.append(ckpt_idx)
                in_recompute = False
        else:
            if not isinstance(op, Backward):
                # a forward op inside the backward pass starts a recompute
                in_recompute = True
                ckpt_idx = 0
                ckpt_region = []
                if isinstance(op, ForwardCheck):
                    ckpt_region.append(op.index)

    # postprocess, make sure every activation checkpoint label in the
    # same activation checkpoint region (level = 0) has the same length
    op_list = []
    for node in node_list:
        op_list += node
    ckpt_regions = _find_nested_ckpt_regions(op_list)
    for (start_idx, end_idx) in ckpt_regions:
        nested_length = max(
            len(op_list[idx].activation_checkpoint) for idx in range(start_idx, end_idx + 1))
        for idx in range(start_idx, end_idx + 1):
            op_list[idx].activation_checkpoint += [None] * (
                nested_length - len(op_list[idx].activation_checkpoint))
def solver_rotor(gm: ColoGraphModule,
                 data,
                 mem_limit: int,
                 mem_slots: int = 500,
                 cnode: List[str] = None,
                 eps: float = 0.0,
                 force_python: bool = False) -> ColoGraphModule:
    """solver that automatically find activation checkpoint in rotor's manner

    Args:
        gm (ColoGraphModule): ColoGraphModule generated by tracing model and MetaInfoProp.
        data (torch.Tensor): input data.
        mem_limit (int): memory budget in Byte.
        mem_slots (int, optional): number of slots for discretizing memory budget. Defaults to 500.
        cnode (List[Node], optional): common node list for linearize. Defaults to None.
        eps (float): epsilon for memory decay. Defaults to 0.0
        force_python (bool): force to use python version of dynamic programs

    Returns:
        ColoGraphModule: annotated ColoGraphModuled with __sequence__ attribute
    """
    logger = get_dist_logger()

    # try to import C version solver if force_python is not set
    if not force_python:
        try:
            from .dynamic_programs_C_version import persistent_compute_table
            CVERSION = True
        except ModuleNotFoundError:
            # build the extension in place if the module has not been built yet
            import os
            import subprocess
            logger.info("dynamic_programs_C_version hasn't been built! Building library...", ranks=[0])
            this_dir = os.path.dirname(os.path.abspath(__file__))
            result = subprocess.Popen(
                [
                    f"{sys.executable}", f"{os.path.join(this_dir, 'build_c_ext.py')}", "build_ext",
                    f"--build-lib={this_dir}"
                ],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
            if result.wait() == 0:
                logger.info("dynamic_programs_C_version has been built!", ranks=[0])
                from .dynamic_programs_C_version import persistent_compute_table
                CVERSION = True
            else:
                logger.info("dynamic_programs_C_version built failed! Using python version!", ranks=[0])
                CVERSION = False
    else:
        CVERSION = False

    # check if metainfoprop is done
    if any(len(node.meta) == 0 for node in gm.graph.nodes):
        raise RuntimeError(
            "Nodes meta information hasn't been prepared! Please run MetaInfoProp before calling solver!")

    # linearize the graph
    node_list = linearize(gm, cnode)

    # construct chain
    mem_unit = mem_limit * (1.0 - eps) // mem_slots
    chain: Chain = _construct_chain(node_list, data)
    chain._discretize(mem_unit)

    # use C version if possible
    if CVERSION and not force_python:
        logger.info("Using C version rotor solver!", ranks=[0])
        opt_table = persistent_compute_table(chain, mem_slots)
    else:
        opt_table = _compute_table(chain, mem_slots)
        logger.info("Using python version rotor solver!", ranks=[0])

    # found sequence
    sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table)

    # if solver failed, we don't need to annotate the graph
    if not SOLVER_FAILED:
        _annotate_from_sequence(sequence, node_list)

    # set __sequence__ attribute to GraphModule
    if SOLVER_FAILED:
        setattr(gm, "__sequence__", None)
    else:
        setattr(gm, "__sequence__", sequence)

    # set __opttable__ attribute to GraphModule
    setattr(gm, "__opttable__", opt_table[0])
    gm.recompile()
    return gm
colossalai/fx/passes/algorithms/dynamic_programs.c
deleted
100644 → 0
View file @
55dcd305
#define PY_SSIZE_T_CLEAN
#include <Python.h>

#include <assert.h>
#include <math.h>
/* Copy a Python sequence of ints into a freshly calloc'd long array.
 * The array has len+1 entries; the extra trailing slot is zeroed.
 * Returns NULL if pylist is NULL or not a sequence. Caller frees. */
long *PySequenceToLongArray(PyObject *pylist) {
  if (!(pylist && PySequence_Check(pylist)))
    return NULL;
  Py_ssize_t len = PySequence_Size(pylist);
  long *result = (long *)calloc(len + 1, sizeof(long));
  for (Py_ssize_t i = 0; i < len; ++i) {
    PyObject *item = PySequence_GetItem(pylist, i);
    result[i] = PyLong_AsLong(item);
    Py_DECREF(item);
  }
  result[len] = 0;
  return result;
}
/* Copy a Python sequence of floats into a freshly calloc'd double array.
 * The array has len+1 entries; the extra trailing slot is zeroed.
 * Returns NULL if pylist is NULL or not a sequence. Caller frees. */
double *PySequenceToDoubleArray(PyObject *pylist) {
  if (!(pylist && PySequence_Check(pylist)))
    return NULL;
  Py_ssize_t len = PySequence_Size(pylist);
  double *result = (double *)calloc(len + 1, sizeof(double));
  for (Py_ssize_t i = 0; i < len; ++i) {
    PyObject *item = PySequence_GetItem(pylist, i);
    result[i] = PyFloat_AsDouble(item);
    Py_DECREF(item);
  }
  result[len] = 0;
  return result;
}
/* Fetch attribute `attributeName` from `container` and convert it to a
 * calloc'd long array (see PySequenceToLongArray). Caller frees. */
long *getLongArray(PyObject *container, const char *attributeName) {
  PyObject *sequence = PyObject_GetAttrString(container, attributeName);
  long *result = PySequenceToLongArray(sequence);
  Py_DECREF(sequence);
  return result;
}
/* Fetch attribute `attributeName` from `container` and convert it to a
 * calloc'd double array (see PySequenceToDoubleArray). Caller frees. */
double *getDoubleArray(PyObject *container, const char *attributeName) {
  PyObject *sequence = PyObject_GetAttrString(container, attributeName);
  double *result = PySequenceToDoubleArray(sequence);
  Py_DECREF(sequence);
  return result;
}
/* C implementation of the rotor "persistent" dynamic-programming table.
 * Python signature: persistent_compute_table(chain, mmax) -> (opt, what)
 * where opt[m][i] and what[m][i] are dicts keyed by l (i <= l <= length).
 * what entries are (True,) for a chain checkpoint, (False, j) for a leaf
 * checkpoint at j — mirroring the pure-Python _compute_table. */
static PyObject *persistent_compute_table(PyObject *self, PyObject *args) {
  PyObject *chain_param;
  int mmax;

  if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax))
    return NULL;

  /* NOTE(review): on these early-return error paths the previously
   * allocated arrays are leaked; kept as-is to preserve control flow. */
  double *fw = getDoubleArray(chain_param, "fweight");
  if (!fw)
    return NULL;
  double *bw = getDoubleArray(chain_param, "bweight");
  if (!bw)
    return NULL;
  long *cw = getLongArray(chain_param, "cweight");
  if (!cw)
    return NULL;
  long *cbw = getLongArray(chain_param, "cbweight");
  if (!cbw)
    return NULL;
  /* BUGFIX: the original checked `!cbw` after these two getters
   * (copy-paste), so a failed conversion went undetected. */
  long *fwd_tmp = getLongArray(chain_param, "fwd_mem_tmp");
  if (!fwd_tmp)
    return NULL;
  long *bwd_tmp = getLongArray(chain_param, "bwd_mem_tmp");
  if (!bwd_tmp)
    return NULL;

  PyObject *chain_length_param = PyObject_GetAttrString(chain_param, "length");
  if (!chain_length_param)
    return NULL;
  long chain_length = PyLong_AsLong(chain_length_param);
  Py_DECREF(chain_length_param);

  // TODO: Can be optimized by only allocating memory for l >= i
  // TODO: float / int instead of double / long ?
#define OPT(m, i, l)                                                \
  opt[(m) * (chain_length + 1) * (chain_length + 1) +               \
      (i) * (chain_length + 1) + (l)]
  double *opt = (double *)calloc(
      (mmax + 1) * (chain_length + 1) * (chain_length + 1), sizeof(double));

#define WHAT(m, i, l)                                               \
  what[(m) * (chain_length + 1) * (chain_length + 1) +              \
       (i) * (chain_length + 1) + (l)]
  long *what = (long *)calloc(
      (mmax + 1) * (chain_length + 1) * (chain_length + 1), sizeof(long));

  /* Initialize the diagonal (single-stage chains). */
  for (long m = 0; m <= mmax; ++m)
    for (long i = 0; i <= chain_length; ++i)
      // TODO: Can be optimized to remove the IF by reordering loops
      if ((m >= cw[i + 1] + cbw[i + 1] + bwd_tmp[i]) &&
          (m >= cw[i + 1] + cbw[i + 1] + fwd_tmp[i]))
        OPT(m, i, i) = fw[i] + bw[i];
      else
        OPT(m, i, i) = INFINITY;

  /* Fill the table by increasing sub-chain length d. */
  for (long m = 0; m <= mmax; ++m)
    for (long d = 1; d <= chain_length; ++d) {
      for (long i = 0; i <= chain_length - d; ++i) {
        long idx = i + d;
        long mmin = cw[idx + 1] + cw[i + 1] + fwd_tmp[i];
        if (idx > i + 1) {
          long maxCostFWD = 0;
          for (long j = i + 1; j < idx; j++) {
            maxCostFWD = fmaxl(maxCostFWD, cw[j] + cw[j + 1] + fwd_tmp[j]);
          }
          mmin = fmaxl(mmin, cw[idx + 1] + maxCostFWD);
        }
        if ((m >= mmin)) {
          long bestLeaf = -1;
          double sumFw = 0;
          double bestLeafCost = INFINITY;
          /// sumFw + OPT(m-cw[i+1], i+1, l) + OPT(m, i, i); // Value for j =
          /// i+1
          for (long j = i + 1; j <= idx; ++j) {
            sumFw += fw[j - 1];
            if (m >= cw[j]) {
              double cost =
                  sumFw + OPT(m - cw[j], j, idx) + OPT(m, i, j - 1);
              if (cost < bestLeafCost) {
                bestLeafCost = cost;
                bestLeaf = j;
              }
            }
          }
          double chainCost = INFINITY;
          if (m >= cbw[i + 1])
            chainCost = OPT(m, i, i) + OPT(m - cbw[i + 1], i + 1, idx);
          if (bestLeafCost <= chainCost) {
            OPT(m, i, idx) = bestLeafCost;
            WHAT(m, i, idx) = bestLeaf;
          } else {
            OPT(m, i, idx) = chainCost;
            WHAT(m, i, idx) = -1; /* -1 encodes "chain checkpoint" */
          }
        } else
          OPT(m, i, idx) = INFINITY;
      }
    }

  free(fw);
  free(bw);
  free(cw);
  free(cbw);
  free(fwd_tmp);
  free(bwd_tmp);

  PyObject *res_opt = PyList_New(mmax + 1);
  PyObject *res_what = PyList_New(mmax + 1);

  // Convert the result into Python world
  for (long m = 0; m <= mmax; ++m) {
    PyObject *res_opt_m = PyList_New(chain_length + 1);
    PyList_SET_ITEM(res_opt, m, res_opt_m);
    PyObject *res_what_m = PyList_New(chain_length + 1);
    PyList_SET_ITEM(res_what, m, res_what_m);
    for (long i = 0; i <= chain_length; ++i) {
      PyObject *res_opt_m_i = PyDict_New();
      PyList_SET_ITEM(res_opt_m, i, res_opt_m_i);
      PyObject *res_what_m_i = PyDict_New();
      PyList_SET_ITEM(res_what_m, i, res_what_m_i);
      for (long l = i; l <= chain_length; ++l) {
        PyObject *res_l = PyLong_FromLong(l);
        PyObject *res_opt_m_i_l = PyFloat_FromDouble(OPT(m, i, l));
        PyDict_SetItem(res_opt_m_i, res_l, res_opt_m_i_l);
        Py_DECREF(res_opt_m_i_l);
        PyObject *res_what_m_i_l;
        long what_m_i_l = WHAT(m, i, l);
        if (what_m_i_l < 0)
          res_what_m_i_l = Py_BuildValue("(O)", Py_True);
        else
          res_what_m_i_l = Py_BuildValue("(Ol)", Py_False, what_m_i_l);
        PyDict_SetItem(res_what_m_i, res_l, res_what_m_i_l);
        Py_DECREF(res_what_m_i_l);
        Py_DECREF(res_l);
      }
    }
  }
  free(opt);
  free(what);

  PyObject *result = PyTuple_Pack(2, res_opt, res_what);
  Py_DECREF(res_opt);
  Py_DECREF(res_what);
  return result;
}
/* Flattened index into the 4-D floating table.
 * Conventions from the caller: i = L - s, j = t - s, k = l - t;
 * m_factor is the size of one m-slice. */
inline long floating_index_in_array(long m_factor, long m, long i, long j,
                                    long k) {
  return m * m_factor + (i * (i + 1) * (2 * i + 4)) / 12 + (i + 1) * j -
         (j * (j - 1)) / 2 + k;
}
/* Decision record for the floating table: a leaf checkpoint is identified
 * by the triple (sp, r, tp); sp < 0 encodes "chain checkpoint". */
typedef struct {
  long sp;
  long r;
  long tp;
} index_t;
/* C implementation of the rotor "floating" dynamic-programming table.
 * Python signature: floating_compute_table(chain, mmax) -> (opt, what)
 * where opt[m] and what[m] are dicts keyed by (s, t, l) triples.
 * NOTE(review): this variant reads the misspelled attributes "fweigth",
 * "bweigth", "cweigth", "cbweigth" — presumably matching the producer's
 * naming; kept byte-for-byte. */
static PyObject *floating_compute_table(PyObject *self, PyObject *args) {
  PyObject *chain_param;
  int mmax;

  if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax))
    return NULL;

  double *fw = getDoubleArray(chain_param, "fweigth");
  if (!fw)
    return NULL;
  double *bw = getDoubleArray(chain_param, "bweigth");
  if (!bw)
    return NULL;
  long *cw = getLongArray(chain_param, "cweigth");
  if (!cw)
    return NULL;
  long *cbw = getLongArray(chain_param, "cbweigth");
  if (!cbw)
    return NULL;
  long *fwd_tmp = getLongArray(chain_param, "fwd_tmp");
  if (!fwd_tmp)
    return NULL;
  long *bwd_tmp = getLongArray(chain_param, "bwd_tmp");
  if (!bwd_tmp)
    return NULL;

  PyObject *chain_length_param = PyObject_GetAttrString(chain_param, "length");
  if (!chain_length_param)
    return NULL;
  long chain_length = PyLong_AsLong(chain_length_param);
  Py_DECREF(chain_length_param);

  const long m_factor =
      (chain_length + 1) * (chain_length + 2) * (2 * chain_length + 6) / 12;

  // Defined for 0 <= s <= t <= l <= chain_length, for all m
#undef OPT
#define OPT(m, s, t, l)                                                    \
  opt[floating_index_in_array(m_factor, (m), chain_length - (s), (t) - (s), \
                              (l) - (t))]
  double *opt = (double *)calloc((mmax + 1) * m_factor, sizeof(double));

#undef WHAT
#define WHAT(m, s, t, l)                                                    \
  what[floating_index_in_array(m_factor, (m), chain_length - (s), (t) - (s), \
                               (l) - (t))]
  index_t *what = (index_t *)calloc((mmax + 1) * m_factor, sizeof(index_t));

  /* Prefix sums of forward times: partialSumsFW[i] = sum(fw[0:i]). */
  double *partialSumsFW =
      (double *)calloc(chain_length + 1, sizeof(double));
  double total = 0;
  for (long i = 0; i < chain_length; ++i) {
    partialSumsFW[i] = total;
    total += fw[i];
  }
  partialSumsFW[chain_length] = total;

  /* Initialize the diagonal (s == t == l). */
  for (long m = 0; m <= mmax; ++m)
    for (long i = 0; i <= chain_length; ++i) {
      // TODO: Can be optimized to remove the IF by reordering loops
      if ((m >= cw[i] + cw[i + 1] + cbw[i + 1] + bwd_tmp[i]) &&
          (m >= cw[i + 1] + cbw[i + 1] + fwd_tmp[i]))
        OPT(m, i, i, i) = fw[i] + bw[i];
      else
        OPT(m, i, i, i) = INFINITY;
    }

  for (long m = 0; m <= mmax; ++m)
    for (long d = 1; d <= chain_length; ++d) {
      // d = l - s
      for (long s = 0; s <= chain_length - d; ++s) {
        long l = s + d;
        long memNullFirst = cw[l + 1] + cw[s + 1] + fwd_tmp[s];
        long memNullSecond = 0;
        for (long j = s + 1; j < l; ++j) {
          long val = cw[j] + cw[j + 1] + fwd_tmp[j];
          if (val > memNullSecond)
            memNullSecond = val;
        }
        for (long t = s; t <= l; ++t) {
          double chainCost = INFINITY;
          if ((s == t) && (m >= cw[l + 1] + cbw[s + 1] + fwd_tmp[s]) &&
              (m >= cw[s] + cw[s + 1] + cbw[s + 1] + bwd_tmp[s])) {
            chainCost =
                OPT(m, s, s, s) + OPT(m - cbw[s + 1], s + 1, s + 1, l);
          }
          double bestLeafCost = INFINITY;
          index_t bestLeaf = {.sp = -1, .r = -1, .tp = -1};
          if (m >= memNullFirst && m >= cw[l + 1] + memNullSecond) {
            for (long r = s; r <= t; ++r)
              if (cw[s] <= cw[r])
                for (long tp = t + 1; tp <= l; ++tp)
                  for (long sp = r + 1; sp <= tp; ++sp) {
                    long mp = m - cw[r] + cw[s];
                    assert(mp >= 0);
                    if (mp >= cw[sp]) {
                      double value = partialSumsFW[sp] - partialSumsFW[s] +
                                     OPT(mp - cw[sp], sp, tp, l) +
                                     OPT(mp, r, t, tp - 1);
                      if (value < bestLeafCost) {
                        bestLeafCost = value;
                        bestLeaf.sp = sp;
                        bestLeaf.r = r;
                        bestLeaf.tp = tp;
                      }
                    }
                  }
          }
          if (bestLeaf.sp >= 0 && bestLeafCost <= chainCost) {
            OPT(m, s, t, l) = bestLeafCost;
            WHAT(m, s, t, l).sp = bestLeaf.sp;
            WHAT(m, s, t, l).r = bestLeaf.r;
            WHAT(m, s, t, l).tp = bestLeaf.tp;
          } else {
            OPT(m, s, t, l) = chainCost;
            WHAT(m, s, t, l).sp = -1;
          }
        }
      }
    }

  free(fw);
  free(bw);
  free(cw);
  free(cbw);
  free(fwd_tmp);
  free(bwd_tmp);

  PyObject *res_opt = PyList_New(mmax + 1);
  PyObject *res_what = PyList_New(mmax + 1);

  // Convert the result into Python world
  PyObject *true_tuple = Py_BuildValue("(O)", Py_True);
  for (long m = 0; m <= mmax; ++m) {
    PyObject *res_opt_m = PyDict_New();
    PyList_SET_ITEM(res_opt, m, res_opt_m);
    PyObject *res_what_m = PyDict_New();
    PyList_SET_ITEM(res_what, m, res_what_m);
    for (long s = 0; s <= chain_length; ++s)
      for (long t = s; t <= chain_length; ++t)
        for (long l = t; l <= chain_length; ++l) {
          PyObject *key = Py_BuildValue("(lll)", s, t, l);
          PyObject *value_opt = PyFloat_FromDouble(OPT(m, s, t, l));
          PyDict_SetItem(res_opt_m, key, value_opt);
          PyObject *value_what = true_tuple;
          index_t *idx_what = &WHAT(m, s, t, l);
          if (idx_what->sp >= 0)
            value_what = Py_BuildValue("(O(lll))", Py_False, idx_what->sp,
                                       idx_what->r, idx_what->tp);
          PyDict_SetItem(res_what_m, key, value_what);
          if (value_what != true_tuple)
            Py_DECREF(value_what);
          Py_DECREF(key);
          Py_DECREF(value_opt);
        }
  }
  Py_DECREF(true_tuple);
  free(opt);
  free(what);
  /* BUGFIX: partialSumsFW was never freed in the original (memory leak). */
  free(partialSumsFW);

  PyObject *result = PyTuple_Pack(2, res_opt, res_what);
  Py_DECREF(res_opt);
  Py_DECREF(res_what);
  return result;
}
/* Compute the optimal checkpointing table for the Griewank heterogeneous
 * model.
 *
 * Python signature: griewank_heterogeneous_compute_table(chain, mmax)
 *   chain : object exposing array attributes "fweigth", "bweigth",
 *           "cweigth", "cbweigth" and an integer attribute "length"
 *           (the misspellings are the real attribute names on the
 *           Python side).
 *   mmax  : memory budget (number of discrete memory slots).
 *
 * Returns a list of (mmax + 1) lists of dicts where result[m][i] maps
 * (l - i) -> OPT(m, i, l), the optimal makespan for segment [i, l] under
 * memory budget m, or NULL with a Python exception set on failure.
 *
 * Fix vs. original: the arrays returned by getDoubleArray/getLongArray are
 * heap-allocated and owned by this function; the original leaked them on
 * every early error return, and never checked PyLong_AsLong for failure.
 */
static PyObject *griewank_heterogeneous_compute_table(PyObject *self, PyObject *args) {
    PyObject *chain_param;
    int mmax;
    if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax))
        return NULL;

    /* Extract the per-stage cost arrays; free everything already obtained
     * whenever a later extraction fails. */
    double *fw = getDoubleArray(chain_param, "fweigth");
    if (!fw)
        return NULL;
    double *bw = getDoubleArray(chain_param, "bweigth");
    if (!bw) {
        free(fw);
        return NULL;
    }
    long *cw = getLongArray(chain_param, "cweigth");
    if (!cw) {
        free(fw);
        free(bw);
        return NULL;
    }
    long *cbw = getLongArray(chain_param, "cbweigth");
    if (!cbw) {
        free(fw);
        free(bw);
        free(cw);
        return NULL;
    }

    PyObject *chain_length_param = PyObject_GetAttrString(chain_param, "length");
    if (!chain_length_param) {
        free(fw);
        free(bw);
        free(cw);
        free(cbw);
        return NULL;
    }
    long chain_length = PyLong_AsLong(chain_length_param);
    Py_DECREF(chain_length_param);
    if (chain_length == -1 && PyErr_Occurred()) {
        /* "length" was not an integer: bail out before computing calloc
         * sizes from a garbage value. */
        free(fw);
        free(bw);
        free(cw);
        free(cbw);
        return NULL;
    }

    // TODO: Can be optimized by only allocating memory for l >= i
    // TODO: float / int instead of double / long ?
#undef OPT
#define OPT(m, i, l) \
    opt[(m) * (chain_length + 1) * (chain_length + 1) + \
        (i) * (chain_length + 1) + (l)]
    double *opt =
        (double *)calloc((mmax + 1) * (chain_length + 1) * (chain_length + 1), sizeof(double));

    // Compute partial sums
    double *sumfw = (double *)calloc(chain_length, sizeof(double));
    double *sumbw = (double *)calloc(chain_length + 1, sizeof(double));
    double *sumsumfw = (double *)calloc(chain_length, sizeof(double));

    /* sumfw[i]    = fw[0] + ... + fw[i]
     * sumbw[i]    = bw[0] + ... + bw[i]
     * sumsumfw[i] = sumfw[0] + ... + sumfw[i] */
    double total = 0;
    for (long i = 0; i < chain_length; ++i) {
        total += fw[i];
        sumfw[i] = total;
    }
    total = 0;
    for (long i = 0; i < chain_length + 1; ++i) {
        total += bw[i];
        sumbw[i] = total;
    }
    total = 0;
    for (long i = 0; i < chain_length; ++i) {
        total += sumfw[i];
        sumsumfw[i] = total;
    }

    /* Base cases: segments of length 0 and 1.  A case is feasible only when
     * the budget m covers the peak memory of that segment. */
    for (long m = 0; m <= mmax; ++m)
        for (long i = 0; i <= chain_length; ++i) {
            // TODO: Can be optimized to remove the IF by reordering loops
            if ((m >= cbw[i]) && (m >= cw[i] + cbw[i + 1]))
                OPT(m, i, i) = bw[i];
            else
                OPT(m, i, i) = INFINITY;
            if (i < chain_length) {
                long maxC = fmaxl(cw[i], cw[i + 1]);
                long maxCB = fmaxl(cbw[i + 1], cbw[i + 2] + maxC);
                if ((m >= cbw[i]) && (m >= cw[i] + maxCB))
                    OPT(m, i, i + 1) = fw[i] + bw[i] + bw[i + 1];
                else
                    OPT(m, i, i + 1) = INFINITY;
            }
        }

    /* General case: segments of length >= 2, filled for increasing l. */
    for (long m = 0; m <= mmax; ++m)
        for (long i = 0; i + 2 <= chain_length; ++i) {
            long mminCst = fmaxl(cbw[i], cbw[i + 1] + cw[i]);
            long maxCW_il = fmax(fmax(cw[i], cw[i + 1]), cw[i + 2]);
            long maxCostFWD = cw[i] + cbw[i + 2] + maxCW_il;
            for (long l = i + 2; l <= chain_length; ++l) {
                maxCW_il = fmax(maxCW_il, cw[l + 1]);
                maxCostFWD = fmaxl(maxCostFWD, cw[i] + cw[l + 1] + maxCW_il);
                long mmin = fmaxl(mminCst, maxCostFWD);
                if ((m >= mmin)) {
                    /* Option 1: no checkpoint inside [i, l] -- recompute the
                     * forward chain before every backward step. */
                    double noCheckpointCost = sumbw[l] - (i > 0 ? sumbw[i - 1] : 0);
                    noCheckpointCost += sumsumfw[l - 1] -
                        (i > 0 ? sumsumfw[i - 1] + (l - i) * sumfw[i - 1] : 0);
                    /* Option 2: checkpoint at some j in (i, l) and recurse
                     * on the two halves. */
                    double valueCost = INFINITY;
                    if (m >= cw[i]) {
                        double sumFwds = 0;
                        for (long j = i + 1; j < l; ++j) {
                            sumFwds += fw[j - 1];
                            valueCost = fmin(valueCost,
                                             sumFwds + OPT(m - cw[i], j, l) + OPT(m, i, j - 1));
                        }
                    }
                    OPT(m, i, l) = fmin(noCheckpointCost, valueCost);
                } else
                    OPT(m, i, l) = INFINITY;
            }
        }

    free(sumfw);
    free(sumbw);
    free(sumsumfw);
    free(fw);
    free(bw);
    free(cw);
    free(cbw);

    PyObject *res_opt = PyList_New(mmax + 1);
    // Convert the result into Python world
    for (long m = 0; m <= mmax; ++m) {
        PyObject *res_opt_m = PyList_New(chain_length + 1);
        /* PyList_SET_ITEM steals the reference, so no DECREF needed here. */
        PyList_SET_ITEM(res_opt, m, res_opt_m);
        for (long i = 0; i <= chain_length; ++i) {
            PyObject *res_opt_m_i = PyDict_New();
            PyList_SET_ITEM(res_opt_m, i, res_opt_m_i);
            for (long l = i; l <= chain_length; ++l) {
                PyObject *res_l = PyLong_FromLong(l - i);
                PyObject *res_opt_m_i_l = PyFloat_FromDouble(OPT(m, i, l));
                /* PyDict_SetItem does NOT steal references: drop ours. */
                PyDict_SetItem(res_opt_m_i, res_l, res_opt_m_i_l);
                Py_DECREF(res_opt_m_i_l);
                Py_DECREF(res_l);
            }
        }
    }
    free(opt);
    return res_opt;
}
/* Method table: maps the Python-visible function names of the
 * dynamic_programs_C_version module to their C implementations
 * (all defined earlier in this file). */
static PyMethodDef dynamic_programs_methods[] = {
    {"persistent_compute_table", persistent_compute_table, METH_VARARGS,
     "Compute the optimal table with the persistent algorithm."},
    {"floating_compute_table", floating_compute_table, METH_VARARGS,
     "Compute the optimal table with the floating algorithm."},
    {"griewank_heterogeneous_compute_table", griewank_heterogeneous_compute_table, METH_VARARGS,
     "Compute the optimal table for the Griewank Heterogeneous Model."},
    {NULL, NULL, 0, NULL} /* Sentinel */
};
/* Module definition for the "dynamic_programs_C_version" extension. */
static struct PyModuleDef dynamic_programs_module = {
    PyModuleDef_HEAD_INIT,
    "dynamic_programs_C_version", /* name of module */
    NULL, /* module documentation, may be NULL */
    -1, /* size of per-interpreter state of the module,
           or -1 if the module keeps state in global variables. */
    dynamic_programs_methods
};
PyMODINIT_FUNC
PyInit_dynamic_programs_C_version
(
void
)
{
return
PyModule_Create
(
&
dynamic_programs_module
);
}
colossalai/fx/passes/algorithms/linearize.py
deleted
100644 → 0
View file @
55dcd305
from
typing
import
List
,
Any
from
torch.fx
import
GraphModule
,
Node
from
colossalai.fx.profiler
import
is_inplace
# Common nodes are type of nodes that could be seen as attributes and remain
# unchanged throughout the whole model, it will be used several times by
# different blocks of model, so that it is hard for us to linearize the graph
# when we encounter those kinds of nodes. We let users to annotate some of the
# input as common node, such as attention mask, and the followings are some of
# the ops that could actually be seen as common nodes. With our common node prop,
# we could find some of the "real" common nodes (e.g. the real attention mask
# used in BERT and GPT), the rule is simple, for node who's parents are all common
# nodes or it's op belongs to the following operations, we view this node as a
# newly born common node.
# List of target name that could be seen as common node
COPS
=
[
"getattr"
,
"getitem"
,
"size"
]
def
_is_cop
(
target
:
Any
)
->
bool
:
"""Check if an op could be seen as common node
Args:
target (Any): node target
Returns:
bool
"""
if
isinstance
(
target
,
str
):
return
target
in
COPS
else
:
return
target
.
__name__
in
COPS
def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
    """Linearize the traced graph into a list of node regions.

    Args:
        gm (GraphModule): GraphModule derived by tracing
        cnode (List[str], optional): common node list; every entry must be
            the name of a placeholder (model input). Defaults to None.

    Returns:
        List[List[Node]]: List of list, each inside list of Node presents
        the actual 'node' in linearized manner.

    Raises:
        ValueError: if a name in ``cnode`` does not exist in the graph.
        AssertionError: if a name in ``cnode`` is not a placeholder.

    Remarks:
        We merge the inplace ops into the previous node: a region is only
        closed when no user of its last node is inplace.
    """

    def _is_sink() -> bool:
        """Check if we can free all dependencies

        Returns:
            bool
        """
        return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users))

    # make sure that item in cnode is valid
    if cnode:
        for name in cnode:
            try:
                assert next(node for node in gm.graph.nodes if node.name == name).op == "placeholder", \
                    f"common node {name} is not an input of the model"
            except StopIteration:
                raise ValueError(f"common node name {name} not in graph")
        # Fix: work on a copy. Common-node propagation below appends to this
        # list, and the caller's argument must not be mutated as a side effect.
        cnode = list(cnode)
    else:
        cnode = []

    deps = {}
    linearized_nodes = []
    region = []

    for n in gm.graph.nodes:
        if n.op != "placeholder" and n.op != "output":
            # consume one pending use from every non-common producer
            for n_par in n._input_nodes:
                if n_par.op != "placeholder" and n_par.name not in cnode:
                    deps[n_par] -= 1
            region.append(n)

            # if the node could free all dependencies in graph
            # we could begin a new node
            if _is_sink():
                linearized_nodes.append(region)
                region = []

            # propagate common node attr if possible
            if len(n._input_nodes) == len([node for node in n._input_nodes if node.name in cnode]) or _is_cop(n.target):
                cnode.append(n.name)
            else:
                deps[n] = len([user for user in n.users if user.op != "output"])

    return linearized_nodes
colossalai/fx/passes/algorithms/operation.py
deleted
100644 → 0
View file @
55dcd305
import
math
def
_discretize
(
mem_unit
,
values
):
return
[
math
.
ceil
(
value
/
mem_unit
)
for
value
in
values
]
class Chain:
    """Per-stage cost description of a linearized computation chain.

    Holds forward/backward time weights and memory weights for a chain of
    ``length`` forward stages (backward-side lists have one extra entry).
    """

    def __init__(self, fw, bw, cw, cbw, ftmp, btmp, check=True):
        self.fweight = fw
        self.bweight = bw
        self.cweight = cw
        self.cbweight = cbw
        self.fwd_mem_tmp = ftmp
        self.bwd_mem_tmp = btmp
        self.length = len(fw)
        # Optionally validate that every list agrees on the chain length.
        if check and not self.check_lengths():
            raise AttributeError("In Chain, input lists do not have consistent lengths")

    def check_lengths(self):
        """Return True when each cost list has its expected length."""
        expected = [
            (self.fweight, self.length),
            (self.bweight, self.length + 1),
            (self.cweight, self.length + 1),
            (self.fwd_mem_tmp, self.length),
            (self.bwd_mem_tmp, self.length + 1),
            (self.cbweight, self.length + 1),
        ]
        return all(len(seq) == size for seq, size in expected)

    def __repr__(self):
        # One tuple per stage; the final tuple has no forward-side entries.
        rows = [(self.fweight[k], self.bweight[k], self.cweight[k], self.cbweight[k],
                 self.fwd_mem_tmp[k], self.bwd_mem_tmp[k]) for k in range(self.length)]
        last = self.length
        rows.append((None, self.bweight[last], self.cweight[last], self.cbweight[last],
                     None, self.bwd_mem_tmp[last]))
        return repr(rows)

    def _discretize(self, mem_unit):
        """Discretize (in place) every memory-related list by ``mem_unit``."""
        for attr in ("cweight", "cbweight", "fwd_mem_tmp", "bwd_mem_tmp"):
            setattr(self, attr, _discretize(mem_unit, getattr(self, attr)))
class Operation:
    """Base class for schedule operations; ``index`` identifies the stage (or stage pair) acted on."""

    def shift(self, value):
        """Offset this operation's index by ``value`` (element-wise for tuple indices)."""
        # type() is used deliberately instead of isinstance() to mirror the
        # original exact-type check.
        if type(self.index) is tuple:
            self.index = tuple(map(lambda part: part + value, self.index))
        else:
            self.index += value
class Offload(Operation):
    """Schedule op that offloads the activation at ``index``; ``has_bar`` appends a barrier marker to the name."""

    def __init__(self, index, has_bar=False) -> None:
        super().__init__()
        self.index = index
        self.has_bar = has_bar
        self.name = "OffwBar" if has_bar else "Off"

    def __repr__(self):
        return f"{self.name}_{self.index}"
class Prefetch(Operation):
    """Schedule op that prefetches the activation at ``index``; ``has_bar`` appends a barrier marker to the name."""

    def __init__(self, index, has_bar=False) -> None:
        super().__init__()
        self.index = index
        self.has_bar = has_bar
        self.name = "PrewBar" if has_bar else "Pre"

    def __repr__(self):
        return f"{self.name}_{self.index}"
class Forward(Operation):
    """Plain forward computation of stage ``index``."""

    def __init__(self, index):
        self.index = index
        self.name = "F"

    def __repr__(self):
        return "{n}_{i}".format(n=self.name, i=self.index)

    def cost(self, chain: Chain):
        """Forward time of this stage, or unit cost when no chain is given."""
        if chain is None:
            return 1
        return chain.fweight[self.index]
class ForwardEnable(Forward):
    """Forward op displayed as "Fe" (presumably forward with autograd enabled — confirm against solver usage)."""

    def __init__(self, index):
        super().__init__(index)
        self.name = "Fe"
class ForwardNograd(Forward):
    """Forward op displayed as "Fn" (presumably forward without gradient tracking — confirm against solver usage)."""

    def __init__(self, index):
        super().__init__(index)
        self.name = "Fn"
class ForwardCheck(Forward):
    """Forward op displayed as "CF" (presumably a checkpointed forward — confirm against solver usage)."""

    def __init__(self, index):
        super().__init__(index)
        self.name = "CF"
class Forwards(Operation):
    """A contiguous run of forward stages from ``index[0]`` through ``index[1]`` inclusive."""

    def __init__(self, start, end):
        self.index = (start, end)

    def __repr__(self):
        return "F_{i}->{j}".format(i=self.index[0], j=self.index[1])

    def cost(self, chain: Chain):
        """Total forward time of the covered stages (stage count when no chain is given)."""
        begin, finish = self.index
        if chain is None:
            return finish - begin + 1
        return sum(chain.fweight[begin:finish + 1])
def isForward(op):
    """True only for exact Forward / Forwards instances (subclasses such as ForwardEnable are excluded on purpose, matching the original exact-type check)."""
    return type(op) in (Forward, Forwards)
class Backward(Operation):
    """Backward computation of stage ``index``."""

    def __init__(self, index):
        self.index = index

    def __repr__(self):
        return "B_{i}".format(i=self.index)

    def cost(self, chain: Chain):
        """Backward time of this stage, or unit cost when no chain is given."""
        if chain is None:
            return 1
        return chain.bweight[self.index]
class Loss(Operation):
    """Zero-cost marker separating the forward phase from the backward phase."""

    def __init__(self):
        # Loss has no stage index of its own.
        pass

    def __repr__(self):
        return "L"

    def cost(self, chain):
        return 0
class MemoryAccess(Operation):
    """Base for zero-cost memory bookkeeping ops; subclasses set ``self.name``."""

    def __init__(self, index):
        self.index = index

    def __repr__(self):
        return "{n}_{i}".format(n=self.name, i=self.index)

    def cost(self, chain: Chain):
        # Memory accesses are modeled as free in the makespan.
        return 0
class WriteMemory(MemoryAccess):
    """Memory-access op displayed as "WM" (write to memory)."""

    def __init__(self, index):
        super().__init__(index)
        self.name = "WM"
class ReadMemory(MemoryAccess):
    """Memory-access op displayed as "RM" (read from memory)."""

    def __init__(self, index):
        super().__init__(index)
        self.name = "RM"
class DiscardMemory(MemoryAccess):
    """Memory-access op displayed as "DM" (discard from memory)."""

    def __init__(self, index):
        super().__init__(index)
        self.name = "DM"
class Function:
    """Descriptor of a solver call: a name plus its positional arguments, with a precomputed comma-joined string form."""

    def __init__(self, name, *args):
        self.name = name
        self.args = args
        self.str_args = ','.join(map(str, args))

    def __repr__(self):
        return "{n}({args})".format(n=self.name, args=self.str_args)
class Sequence:
    """A (possibly nested) schedule: an ordered list of Operation and Sequence items.

    ``function`` is a Function descriptor recording which solver call produced
    this sequence (name and parameters); it is used for bookkeeping/display.
    """

    def __init__(self, function):
        self.sequence = []    # List of Operation and Sequence
        self.function = function    # Description of the function (name and parameters)

    def __repr__(self):
        return repr(self.list_operations())

    def list_operations(self):
        """Return the flattened list of Operations (nested Sequences expanded recursively)."""
        op_list = []
        for x in self.sequence:
            if isinstance(x, Operation):
                op_list.append(x)
            else:
                assert isinstance(x, Sequence)
                op_list += x.list_operations()
        return op_list

    def insert(self, operation):
        """Append a single Operation to the schedule."""
        self.sequence.append(operation)

    def remove(self, operation_index):
        """Delete the item at ``operation_index`` from the top-level schedule."""
        del self.sequence[operation_index]

    def insert_sequence(self, sequence):
        """Append a nested Sequence to the schedule."""
        self.sequence.append(sequence)

    def shift(self, value):
        """Shift the index of every item (recursively, via each item's own shift) by ``value``; returns self."""
        for x in self.sequence:
            x.shift(value)
        return self

    def remove_useless_write(self):
        """Drop a leading WriteMemory op, if present; returns self."""
        if self.sequence:
            if isinstance(self.sequence[0], WriteMemory):
                self.remove(0)
        return self

    def get_makespan(self, chain):
        """Total cost of all operations evaluated against ``chain``."""
        return sum(op.cost(chain) for op in self.list_operations())

    def without_suffix(self):
        """Strip the trailing ForwardEnable run before the Loss and its matching Backwards.

        Returns a pair ``(sequence, idx)``: the stripped Sequence and the start
        index of the removed ForwardEnable chain, or ``(self, None)`` when the
        op right before the Loss is not a ForwardEnable (nothing to strip).
        """
        ops = self.list_operations()
        # Index of the (first) Loss marker, separating forward from backward phase.
        end_of_first_phase = [i for i in range(len(ops)) if type(ops[i]) is Loss][0]
        try:
            # Last pre-Loss op that is NOT a ForwardEnable; everything after it
            # (up to the Loss) is the trailing ForwardEnable chain.
            last_idx = max(i for i in range(end_of_first_phase) if not type(ops[i]) is ForwardEnable)
        except ValueError:
            # The whole forward phase is ForwardEnable ops.
            last_idx = -1
        if last_idx == end_of_first_phase - 1:
            # No trailing ForwardEnable chain: nothing to strip.
            return (self, None)
        chain_length = ops[end_of_first_phase - 1].index    ## Some assumption here about the sequence (finishes with Forward_L
        start_of_fwd_enable_chain = ops[last_idx + 1].index    ## And starts with B_L), but should be fine in practice
        result = Sequence(Function("Strip", self.function.name, *self.function.args, start_of_fwd_enable_chain))
        # Copy the prefix up to (and excluding) the ForwardEnable chain.
        for i in range(last_idx + 1):
            result.insert(ops[i])
        result.insert(Loss())
        # Sanity check only: the ops right after the Loss must be the Backwards
        # matching the stripped ForwardEnable chain, in descending index order.
        for i in range(chain_length, start_of_fwd_enable_chain - 1, -1):
            position = end_of_first_phase + 1 + (chain_length - i)
            assert type(ops[position]) is Backward
            assert ops[position].index == i
        # Copy the remainder, skipping the matched Backward ops.
        for i in range(end_of_first_phase + 1 + 1 + chain_length - start_of_fwd_enable_chain, len(ops)):
            result.insert(ops[i])
        return (result, start_of_fwd_enable_chain)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment