Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
b42d3d28
Unverified
Commit
b42d3d28
authored
Mar 07, 2023
by
Super Daniel
Committed by
GitHub
Mar 07, 2023
Browse files
[fx] remove depreciated algorithms. (#2312) (#2313)
parent
55dcd305
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
0 additions
and
1970 deletions
+0
-1970
colossalai/fx/passes/algorithms/__init__.py
colossalai/fx/passes/algorithms/__init__.py
+0
-4
colossalai/fx/passes/algorithms/build_c_ext.py
colossalai/fx/passes/algorithms/build_c_ext.py
+0
-15
colossalai/fx/passes/algorithms/ckpt_solver_chen.py
colossalai/fx/passes/algorithms/ckpt_solver_chen.py
+0
-98
colossalai/fx/passes/algorithms/ckpt_solver_pofo.py
colossalai/fx/passes/algorithms/ckpt_solver_pofo.py
+0
-537
colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
+0
-436
colossalai/fx/passes/algorithms/dynamic_programs.c
colossalai/fx/passes/algorithms/dynamic_programs.c
+0
-516
colossalai/fx/passes/algorithms/linearize.py
colossalai/fx/passes/algorithms/linearize.py
+0
-94
colossalai/fx/passes/algorithms/operation.py
colossalai/fx/passes/algorithms/operation.py
+0
-270
No files found.
colossalai/fx/passes/algorithms/__init__.py
deleted
100644 → 0
View file @
55dcd305
from
.ckpt_solver_chen
import
chen_greedy
from
.linearize
import
linearize
from
.ckpt_solver_rotor
import
solver_rotor
from
.ckpt_solver_pofo
import
solver_pofo
colossalai/fx/passes/algorithms/build_c_ext.py
deleted
100644 → 0
View file @
55dcd305
from
setuptools
import
setup
,
Extension
import
os
this_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
ext_modules
=
[
Extension
(
'dynamic_programs_C_version'
,
sources
=
[
os
.
path
.
join
(
this_dir
,
'dynamic_programs.c'
)],
)]
setup
(
name
=
'rotor c extension'
,
version
=
'0.1'
,
description
=
'rotor c extension for faster dp computing'
,
ext_modules
=
ext_modules
,
)
colossalai/fx/passes/algorithms/ckpt_solver_chen.py
deleted
100644 → 0
View file @
55dcd305
import
math
from
typing
import
List
,
Set
,
Tuple
import
torch
from
torch.fx
import
GraphModule
,
Node
from
colossalai.fx.profiler
import
calculate_fwd_in
,
calculate_fwd_tmp
__all__
=
[
'chen_greedy'
]
CKPT_OP
=
[
'call_module'
,
'call_method'
,
'call_function'
,
'get_attr'
]
def
_all_potential_ckpt_nodes
(
gm
:
GraphModule
)
->
List
:
"""
In most existing frameworks of activation checkpoint, the forward graph is assumed to be linearized.
"""
def
is_sink
():
"""
If we can free all memories when executing a certain node, it is a sink.
"""
return
not
sum
((
v
for
k
,
v
in
deps
.
items
()))
deps
=
{}
ckpt_nodes
=
[]
for
n
in
gm
.
graph
.
nodes
:
for
n_par
in
n
.
_input_nodes
:
deps
[
n_par
]
-=
1
# free memory and dependencies
# We can only put act_ckpt on these nodes
if
n
.
op
in
CKPT_OP
and
is_sink
():
ckpt_nodes
.
append
(
n
)
deps
[
n
]
=
len
(
n
.
users
)
# add dependencies for future executions
return
ckpt_nodes
def
chen_greedy
(
gm
:
GraphModule
)
->
GraphModule
:
"""
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
Note that this algorithm targets at memory optimization only, using techniques in appendix A.
Usage:
model = resnet18()
input_sample = torch.rand(4, 3, 224, 224)
gm = symbolic_trace(model)
MetaInfoProp(gm).run(input_sample)
gm = chen_greedy(gm)
Args:
gm (GraphModule): The module to add checkpoints
"""
def
grid_search
(
num_grids
:
int
=
6
)
->
Set
:
"""
Search ckpt strategy with b = 0, then run the allocation algorithm again with b = √xy.
Grid search over [√2/2 b, √2 b] for ckpt_opt over num_grids as in appendix A.
"""
_
,
b_approx
=
run_chen_greedy
(
0
)
b_min
,
b_max
=
math
.
floor
(
b_approx
/
math
.
sqrt
(
2
)),
math
.
ceil
(
b_approx
*
math
.
sqrt
(
2
))
b_opt
=
math
.
inf
for
b
in
range
(
b_min
,
b_max
,
(
b_max
-
b_min
)
//
num_grids
):
ckpt_intv
,
b_approx
=
run_chen_greedy
(
b
)
if
b_approx
<
b_opt
:
b_opt
=
b_approx
ckpt_opt
=
ckpt_intv
return
ckpt_opt
def
run_chen_greedy
(
b
:
int
=
0
)
->
Tuple
[
Set
,
int
]:
"""
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
"""
ckpt_nodes
=
_all_potential_ckpt_nodes
(
gm
)
ckpt_intv
=
[]
temp
=
0
x
=
0
y
=
0
prev_idx
=
2
for
(
idx
,
n
)
in
enumerate
(
gm
.
graph
.
nodes
):
n
:
Node
temp
+=
calculate_fwd_in
(
n
)
+
calculate_fwd_tmp
(
n
)
y
=
max
(
y
,
temp
)
if
temp
>
b
and
n
in
ckpt_nodes
:
x
+=
calculate_fwd_in
(
n
)
temp
=
0
ckpt_intv
.
append
((
prev_idx
,
idx
+
1
))
prev_idx
=
idx
+
1
return
ckpt_intv
,
math
.
floor
(
math
.
sqrt
(
x
*
y
))
gm
.
graph
.
lint
()
# make sure nodes are in topological order
ckpt
=
grid_search
(
num_grids
=
6
)
node_list
=
list
(
gm
.
graph
.
nodes
)
for
i
,
seg
in
enumerate
(
ckpt
):
for
idx
in
range
(
*
seg
):
n
=
node_list
[
idx
]
if
n
.
op
in
CKPT_OP
:
setattr
(
n
,
'activation_checkpoint'
,
i
)
gm
.
recompile
()
return
gm
colossalai/fx/passes/algorithms/ckpt_solver_pofo.py
deleted
100644 → 0
View file @
55dcd305
import
copy
import
math
from
typing
import
List
,
Tuple
import
torch
from
colossalai.fx
import
is_compatible_with_meta
from
colossalai.fx.codegen.activation_checkpoint_codegen
import
\
_find_nested_ckpt_regions
from
colossalai.fx.graph_module
import
ColoGraphModule
from
colossalai.fx.passes.algorithms.ckpt_solver_rotor
import
(
_compute_table
,
_construct_chain
,
_rec
)
from
colossalai.fx.passes.meta_info_prop
import
MetaInfoProp
from
colossalai.fx.profiler
import
parameter_size
from
torch.fx
import
GraphModule
,
Node
from
.linearize
import
linearize
from
.operation
import
(
Backward
,
Chain
,
ForwardCheck
,
ForwardEnable
,
ForwardNograd
,
Function
,
Loss
,
Offload
,
Prefetch
,
Sequence
)
INF
=
float
(
"inf"
)
def
_normalize_flops
(
chain
:
Chain
,
flops
)
->
Chain
:
"""
Normalize flops
"""
for
i
in
range
(
chain
.
length
):
chain
.
fweight
[
i
]
/=
flops
chain
.
bweight
[
i
]
/=
flops
return
chain
class
PofoTable
:
"""PofoTable
The PofoTable contains the necessary components to store intermediate results
of dynamic programming and the operations alone the way.
"""
def
__init__
(
self
,
chain_length
:
int
,
mem_slots
:
int
):
"""Init pofo table
The pofo table contains two tables, opt and what, indicating values and
operations.
Args:
chain_length (int): chain length
mem_slots (int): number of memory slots
"""
self
.
length
=
chain_length
self
.
mem_slots
=
mem_slots
# initializing tables
# the first bool indicates whether the input has bar
# opt table is for value, opt[True/False][i][A][(df, db)] = OCx(i, A, df, db)
# what table is for decision, what[True/False][i][A][(df, db)] = (is_enable, is_offload, index)
# where is_enable indicates whether we enable the gradient, is_offload indicates whether we
# offload the input, index indicates the end of F_\empty sequence if is_enable = False
self
.
opt
=
{
False
:
[[{}
for
_
in
range
(
mem_slots
+
1
)]
for
_
in
range
(
self
.
length
+
1
)],
True
:
[[{}
for
_
in
range
(
mem_slots
+
1
)]
for
_
in
range
(
self
.
length
+
1
)]
}
self
.
what
=
{
False
:
[[{}
for
_
in
range
(
mem_slots
+
1
)]
for
_
in
range
(
self
.
length
+
1
)],
True
:
[[{}
for
_
in
range
(
mem_slots
+
1
)]
for
_
in
range
(
self
.
length
+
1
)]
}
def
_get_value
(
self
,
state
,
table
,
default
):
i
,
act_size
,
df
,
db
,
input_has_bar
=
state
if
act_size
+
df
>
self
.
mem_slots
or
act_size
+
db
>
self
.
mem_slots
:
return
default
try
:
return
table
[
input_has_bar
][
i
][
act_size
][(
df
,
db
)]
except
KeyError
:
print
(
f
"state not found
{
state
}
"
)
def
get_opt
(
self
,
state
):
return
self
.
_get_value
(
state
,
self
.
opt
,
INF
)
def
get_what
(
self
,
state
):
return
self
.
_get_value
(
state
,
self
.
what
,
INF
)
def
set_value
(
self
,
state
,
opt
,
what
):
i
,
act_size
,
df
,
db
,
input_has_bar
=
state
self
.
opt
[
input_has_bar
][
i
][
act_size
][(
df
,
db
)]
=
opt
self
.
what
[
input_has_bar
][
i
][
act_size
][(
df
,
db
)]
=
what
class
PofoSolver
:
"""PofoSolver that executes algorithm mentioned in https://proceedings.neurips.cc/paper/2021/hash/c8461bf13fca8a2b9912ab2eb1668e4b-Abstract.html
The new pofo solver is based on paper Efficient Combination of Rematerialization and Offloading for Training DNNs
and it's code given in the supplemental. Currently we doesn't use the whole set up in the original paper and reuse
rotor solver for the backward sequence as suggested in supplemental. The solver now is able to find strategy with offload.
"""
def
__init__
(
self
,
chain
:
Chain
,
max_memory
:
int
,
bandwidth
,
mem_slots
:
int
)
->
None
:
self
.
chain
=
chain
self
.
length
=
chain
.
length
self
.
max_memory
=
max_memory
self
.
mem_slots
=
mem_slots
self
.
mem_unit
=
max_memory
/
mem_slots
self
.
bandwidth
=
bandwidth
self
.
disc_chain
=
copy
.
deepcopy
(
self
.
chain
)
self
.
disc_chain
.
_discretize
(
self
.
mem_unit
)
self
.
rotor_table
=
_compute_table
(
self
.
disc_chain
,
mem_slots
)
self
.
_compute_pofo_table
()
def
_discretize
(
self
,
*
values
)
->
Tuple
:
return
tuple
(
math
.
ceil
(
value
/
self
.
mem_unit
)
for
value
in
values
)
def
_undiscretize
(
self
,
*
discrete_values
)
->
Tuple
:
if
len
(
discrete_values
)
==
1
:
return
discrete_values
[
0
]
*
self
.
mem_unit
else
:
return
tuple
(
d
*
self
.
mem_unit
for
d
in
discrete_values
)
def
_mmax_all
(
self
,
idx
:
int
):
"""
Calculate the maximum memory usage of Fi_all
"""
return
self
.
chain
.
cbweight
[
idx
+
1
]
+
self
.
chain
.
fwd_mem_tmp
[
idx
]
def
_mmax_b
(
self
,
idx
:
int
):
"""
Calculate the maximum memory usage of Bi
"""
return
self
.
chain
.
cbweight
[
idx
+
1
]
+
self
.
chain
.
cweight
[
idx
+
1
]
+
self
.
chain
.
cweight
[
idx
]
+
self
.
chain
.
bwd_mem_tmp
[
idx
]
def
_mmax_ng
(
self
,
i
:
int
,
j
:
int
):
"""
Calculate the maximum memory usage of CF_i, F_i+1\empty, ... F_j\empty
"""
res
=
self
.
chain
.
cweight
[
j
+
1
]
+
self
.
chain
.
fwd_mem_tmp
[
j
]
if
j
>
i
:
res
+=
self
.
chain
.
cweight
[
j
]
return
res
def
_rotor_estimated_bwd
(
self
,
i
,
j
,
m
,
delta
):
compute
=
self
.
rotor_table
[
0
][
math
.
floor
((
m
-
self
.
chain
.
cweight
[
i
])
/
self
.
mem_unit
)][
i
][
j
]
comm
=
delta
/
self
.
bandwidth
return
(
max
(
compute
,
comm
)
+
compute
+
comm
)
/
2
def
_rotor_estimated_bwd_sequence
(
self
,
i
,
j
,
m
,
delta
):
return
_rec
(
self
.
disc_chain
,
i
,
j
,
math
.
floor
((
m
-
self
.
chain
.
cweight
[
i
])
/
self
.
mem_unit
),
self
.
rotor_table
)
def
_common_values_enable
(
self
,
state
:
Tuple
):
idx
,
act_size
,
df
,
db
,
input_has_bar
=
state
input_size
=
self
.
chain
.
cbweight
[
idx
]
if
input_has_bar
else
self
.
chain
.
cweight
[
idx
]
mf
=
act_size
+
df
+
input_size
mb
=
act_size
+
db
+
input_size
mem_avail
=
self
.
max_memory
-
act_size
-
input_size
f_usage
=
self
.
_mmax_all
(
idx
)
b_usage
=
self
.
_mmax_b
(
idx
)
# infeasible
if
f_usage
>
mem_avail
or
b_usage
>
mem_avail
:
return
None
# calculate idle time
eps_f_beta
=
max
(
0
,
f_usage
-
self
.
max_memory
+
mf
)
eps_b_beta
=
max
(
0
,
b_usage
-
self
.
max_memory
+
mb
)
idle_time
=
(
eps_f_beta
+
eps_b_beta
)
/
self
.
bandwidth
# calculate offload and prefetch data
offload_data
=
self
.
chain
.
fweight
[
idx
]
*
self
.
bandwidth
+
eps_f_beta
prefetch_data
=
self
.
chain
.
bweight
[
idx
]
*
self
.
bandwidth
+
eps_b_beta
# total_time
total_time
=
self
.
chain
.
fweight
[
idx
]
+
self
.
chain
.
bweight
[
idx
]
+
idle_time
return
(
offload_data
,
prefetch_data
,
total_time
,
idle_time
)
def
_common_values_nograd
(
self
,
state
:
Tuple
,
j
:
int
,
iterative
:
bool
=
False
):
i
,
act_size
,
df
,
db
,
input_has_bar
=
state
# compute new epsilon_tmp and sum_fwds
if
iterative
:
self
.
epsilon_tmp
=
max
(
self
.
epsilon_tmp
,
self
.
_mmax_ng
(
i
,
j
)
-
self
.
bandwidth
*
self
.
sum_fwds
)
self
.
sum_fwds
+=
self
.
chain
.
fweight
[
j
]
else
:
self
.
epsilon_tmp
=
max
(
self
.
_mmax_ng
(
i
,
k
)
-
self
.
bandwidth
*
sum
(
self
.
chain
.
fweight
[
i
:
k
])
for
k
in
range
(
i
,
j
+
1
))
self
.
sum_fwds
=
sum
(
self
.
chain
.
fweight
[
i
:
j
+
1
])
input_size
=
self
.
chain
.
cbweight
[
i
]
if
input_has_bar
else
self
.
chain
.
cweight
[
i
]
mf
=
act_size
+
df
+
input_size
mem_avail
=
self
.
max_memory
-
act_size
-
input_size
# if infeasible
if
max
(
self
.
_mmax_ng
(
i
,
k
)
for
k
in
range
(
i
,
self
.
length
))
>
mem_avail
:
return
None
eps_f_beta
=
max
(
0
,
self
.
epsilon_tmp
-
self
.
max_memory
+
mf
)
offload_data
=
self
.
sum_fwds
*
self
.
bandwidth
+
eps_f_beta
# TODO: Implement the precise backward recompute sequence mentioned in the paper
# currently we will use an approximate way to get the backward time
time_backward
=
self
.
_rotor_estimated_bwd
(
i
,
j
,
mem_avail
,
db
)
prefetch_data
=
time_backward
*
self
.
bandwidth
idle_time
=
eps_f_beta
/
self
.
bandwidth
total_time
=
self
.
sum_fwds
+
idle_time
+
time_backward
return
(
offload_data
,
prefetch_data
,
total_time
,
idle_time
)
def
_new_values
(
self
,
state
:
Tuple
,
do_offload
:
bool
,
common_values
:
Tuple
)
->
Tuple
:
"""Generate new values for next state
Args:
state (Tuple): undiscretized states
do_offload (bool): bool type indicates whether we need to do offload
common_values (Tuple): common values (offload_data, prefetch_data, total_time, idle_time)
Returns:
Tuple: (new_act_size, new_df, new_db)
"""
idx
,
act_size
,
df
,
db
,
input_has_bar
=
state
offload_data
,
prefetch_data
,
*
_
=
common_values
input_size
=
self
.
chain
.
cbweight
[
idx
]
if
input_has_bar
else
self
.
chain
.
cweight
[
idx
]
if
do_offload
:
new_act_size
=
act_size
new_df
=
max
(
0
,
df
+
input_size
-
offload_data
)
new_db
=
max
(
0
,
db
-
prefetch_data
)
+
input_size
else
:
new_act_size
=
act_size
+
input_size
new_df
=
max
(
0
,
df
-
offload_data
)
new_db
=
max
(
0
,
db
-
prefetch_data
)
return
(
new_act_size
,
new_df
,
new_db
)
def
_compute_pofo_table
(
self
):
self
.
table
=
PofoTable
(
self
.
length
,
self
.
mem_slots
)
# initializing the loss
for
act_size
in
range
(
self
.
mem_slots
+
1
):
for
df
in
range
(
self
.
mem_slots
-
act_size
+
1
):
for
db
in
range
(
self
.
mem_slots
-
act_size
+
1
):
# undiscretize for idle time calculation
origin_values
=
self
.
_undiscretize
(
act_size
,
df
,
db
)
for
input_has_bar
in
(
False
,
True
):
disc_state
=
(
self
.
length
,
act_size
,
df
,
db
,
input_has_bar
)
state
=
(
self
.
length
,
*
origin_values
,
input_has_bar
)
common_values
=
self
.
_common_values_enable
(
state
)
# if no feasible choice
if
common_values
is
None
:
self
.
table
.
set_value
(
disc_state
,
INF
,
None
)
continue
# if there is feasible choice
new_act_size
,
new_df
,
new_db
=
self
.
_new_values
(
state
,
False
,
common_values
)
eps_g
=
(
new_df
+
new_db
)
/
self
.
bandwidth
total_time
=
common_values
[
2
]
+
eps_g
self
.
table
.
set_value
(
disc_state
,
total_time
,
(
True
,
False
))
# main loop
for
i
in
reversed
(
range
(
self
.
length
)):
for
act_size
in
range
(
self
.
mem_slots
+
1
):
for
df
in
range
(
self
.
mem_slots
-
act_size
+
1
):
for
db
in
range
(
self
.
mem_slots
-
act_size
+
1
):
# undiscretize for idle time calculation
origin_values
=
self
.
_undiscretize
(
act_size
,
df
,
db
)
for
input_has_bar
in
(
False
,
True
):
best_result
=
INF
best_choice
=
None
disc_state
=
(
i
,
act_size
,
df
,
db
,
input_has_bar
)
state
=
(
i
,
*
origin_values
,
input_has_bar
)
# case 1: start with F_all
vals_enable
=
self
.
_common_values_enable
(
state
)
if
vals_enable
is
not
None
:
for
do_offload
in
(
True
,
False
):
new_state
=
self
.
_new_values
(
state
,
do_offload
,
vals_enable
)
new_state
=
(
i
+
1
,
*
self
.
_discretize
(
*
new_state
),
True
)
total_time
=
vals_enable
[
2
]
results_all
=
self
.
table
.
get_opt
(
new_state
)
+
total_time
if
results_all
<
best_result
:
best_result
=
results_all
best_choice
=
(
True
,
do_offload
)
# case 2: start with F_ck
self
.
sum_fwds
=
0
self
.
epsilon_tmp
=
0
for
j
in
range
(
i
,
self
.
length
):
vals_nograd
=
self
.
_common_values_nograd
(
state
,
j
,
True
)
# if infeasible
if
vals_nograd
is
None
:
continue
for
do_offload
in
(
True
,
False
):
new_state
=
self
.
_new_values
(
state
,
do_offload
,
vals_nograd
)
new_state
=
(
j
+
1
,
*
self
.
_discretize
(
*
new_state
),
False
)
total_time
=
vals_nograd
[
2
]
result_nograd
=
total_time
+
self
.
table
.
get_opt
(
new_state
)
if
result_nograd
<
best_result
:
best_result
=
result_nograd
best_choice
=
(
False
,
do_offload
,
j
)
self
.
table
.
set_value
(
disc_state
,
best_result
,
best_choice
)
def
pofo_rec
(
self
,
disc_state
):
i
,
act_size
,
df
,
db
,
input_has_bar
=
disc_state
result
=
Sequence
(
Function
(
"pofo"
,
*
disc_state
))
what
=
self
.
table
.
get_what
(
disc_state
)
state
=
self
.
_undiscretize
(
act_size
,
df
,
db
)
state
=
(
i
,
*
state
,
input_has_bar
)
i
,
act_size
,
df
,
db
,
input_has_bar
=
state
if
what
is
None
:
return
None
# if loss
if
i
==
self
.
length
:
result
.
insert
(
Loss
())
return
result
if
what
[
0
]:
do_offload
=
what
[
1
]
values
=
self
.
_common_values_enable
(
state
)
new_state
=
self
.
_discretize
(
*
self
.
_new_values
(
state
,
do_offload
,
values
))
new_state
=
(
i
+
1
,
*
new_state
,
True
)
if
do_offload
:
result
.
insert
(
Offload
(
i
,
input_has_bar
))
result
.
insert
(
ForwardEnable
(
i
))
result
.
insert_sequence
(
self
.
pofo_rec
(
new_state
))
if
do_offload
:
result
.
insert
(
Prefetch
(
i
,
input_has_bar
))
result
.
insert
(
Backward
(
i
))
else
:
_
,
do_offload
,
j
=
what
values
=
self
.
_common_values_nograd
(
state
,
j
)
new_state
=
self
.
_discretize
(
*
self
.
_new_values
(
state
,
do_offload
,
values
))
new_state
=
(
j
+
1
,
*
new_state
,
False
)
if
do_offload
:
result
.
insert
(
Offload
(
i
,
input_has_bar
))
result
.
insert
(
ForwardCheck
(
i
))
for
k
in
range
(
i
+
1
,
j
+
1
):
result
.
insert
(
ForwardNograd
(
k
))
result
.
insert_sequence
(
self
.
pofo_rec
(
new_state
))
if
do_offload
:
result
.
insert
(
Prefetch
(
i
,
input_has_bar
))
m
=
self
.
max_memory
-
act_size
-
(
self
.
chain
.
cbweight
[
i
]
if
input_has_bar
else
self
.
chain
.
cweight
[
i
])
#TODO: Implement the precise backward recompute sequence mentioned in the paper
result
.
insert_sequence
(
self
.
_rotor_estimated_bwd_sequence
(
i
,
j
,
m
,
db
))
return
result
def
_annotate_from_pofo_sequence
(
sequence
:
Sequence
,
node_list
:
List
[
List
[
Node
]]):
op_list
=
sequence
.
list_operations
()
loss_op
=
next
(
op
for
op
in
op_list
if
isinstance
(
op
,
Loss
))
fwd_list
=
op_list
[:
op_list
.
index
(
loss_op
)]
bwd_list
=
op_list
[
op_list
.
index
(
loss_op
)
+
1
:]
ckpt_idx
=
0
in_ckpt
=
False
ckpt_region
=
[]
# forward annotation
for
op
in
fwd_list
:
if
in_ckpt
:
if
isinstance
(
op
,
ForwardNograd
):
ckpt_region
.
append
(
op
.
index
)
elif
isinstance
(
op
,
ForwardEnable
):
in_ckpt
=
False
for
node_idx
in
ckpt_region
:
for
n
in
node_list
[
node_idx
]:
setattr
(
n
,
"activation_checkpoint"
,
[
ckpt_idx
])
ckpt_idx
+=
1
ckpt_region
=
[]
elif
isinstance
(
op
,
ForwardCheck
):
for
node_idx
in
ckpt_region
:
for
n
in
node_list
[
node_idx
]:
setattr
(
n
,
"activation_checkpoint"
,
[
ckpt_idx
])
ckpt_idx
+=
1
ckpt_region
=
[
op
.
index
]
else
:
if
isinstance
(
op
,
ForwardCheck
):
in_ckpt
=
True
ckpt_region
.
append
(
op
.
index
)
# annotate the backward if there is any nested activation checkpoint
in_recompute
=
False
for
op
in
bwd_list
:
if
in_recompute
:
if
isinstance
(
op
,
ForwardNograd
):
ckpt_region
.
append
(
op
.
index
)
elif
isinstance
(
op
,
ForwardEnable
):
for
node_idx
in
ckpt_region
:
for
n
in
node_list
[
node_idx
]:
n
.
activation_checkpoint
.
append
(
ckpt_idx
)
ckpt_idx
+=
1
ckpt_region
=
[]
elif
isinstance
(
op
,
ForwardCheck
):
for
node_idx
in
ckpt_region
:
for
n
in
node_list
[
node_idx
]:
n
.
activation_checkpoint
.
append
(
ckpt_idx
)
ckpt_idx
+=
1
ckpt_region
=
[
op
.
index
]
elif
isinstance
(
op
,
Backward
):
for
node_idx
in
ckpt_region
:
for
n
in
node_list
[
node_idx
]:
n
.
activation_checkpoint
.
append
(
ckpt_idx
)
in_recompute
=
False
else
:
if
not
isinstance
(
op
,
Backward
):
in_recompute
=
True
ckpt_idx
=
0
ckpt_region
=
[]
if
isinstance
(
op
,
ForwardCheck
):
ckpt_region
.
append
(
op
.
index
)
# postprocess, make sure every activation checkpoint label in the
# same activation checkpoint region (level = 0) has the same length
op_list
=
[]
for
node
in
node_list
:
op_list
+=
node
ckpt_regions
=
_find_nested_ckpt_regions
(
op_list
)
for
(
start_idx
,
end_idx
)
in
ckpt_regions
:
nested_length
=
max
(
len
(
op_list
[
idx
].
activation_checkpoint
)
for
idx
in
range
(
start_idx
,
end_idx
+
1
))
for
idx
in
range
(
start_idx
,
end_idx
+
1
):
op_list
[
idx
].
activation_checkpoint
+=
[
None
]
*
(
nested_length
-
len
(
op_list
[
idx
].
activation_checkpoint
))
# annotate the offload
offload_idx
=
0
for
idx
,
op
in
enumerate
(
fwd_list
):
if
isinstance
(
op
,
Offload
):
# corner case: offload input
if
op
.
index
==
0
:
if
isinstance
(
fwd_list
[
idx
+
1
],
ForwardCheck
):
for
n
in
node_list
[
op
.
index
]:
setattr
(
n
,
"activation_offload"
,
True
)
else
:
for
n
in
node_list
[
op
.
index
]:
setattr
(
n
,
"activation_offload"
,
(
offload_idx
,
True
,
False
))
offload_idx
+=
1
else
:
if
op
.
has_bar
:
# annotate previous node
if
hasattr
(
node_list
[
op
.
index
-
1
][
0
],
"activation_offload"
):
for
n
in
node_list
[
op
.
index
-
1
]:
n
.
activation_offload
[
-
1
]
=
True
else
:
for
n
in
node_list
[
op
.
index
-
1
]:
setattr
(
n
,
"activation_offload"
,
[
offload_idx
,
False
,
True
])
offload_idx
+=
1
# annotate this node
if
isinstance
(
fwd_list
[
idx
+
1
],
ForwardCheck
):
for
n
in
node_list
[
op
.
index
]:
setattr
(
n
,
"activation_offload"
,
True
)
else
:
for
n
in
node_list
[
op
.
index
]:
setattr
(
n
,
"activation_offload"
,
[
offload_idx
,
True
,
False
])
offload_idx
+=
1
def
solver_pofo
(
gm
:
ColoGraphModule
,
data
,
bandwidth
,
flops
,
mem_limit
:
int
,
mem_slots
:
int
=
50
,
cnode
:
List
[
str
]
=
None
,
eps
:
float
=
0.0
)
->
ColoGraphModule
:
"""Solver that combine offload and activation checkpoint
Reference: https://proceedings.neurips.cc/paper/2021/hash/c8461bf13fca8a2b9912ab2eb1668e4b-Abstract.html
Args:
gm (ColoGraphModule): ColoGraphModule derived from tracer
data: input of the model
bandwidth: offload bandwidth, unit Byte/s
flops: FLOPS of device, unit FLOPs/s
mem_limit (int): memory limit, unit Byte
mem_slots (int, optional): number of memory slots. Defaults to 500.
cnode (List[str], optional): common node for linearize. Defaults to None.
eps (float, optional): epsilon for memory decay. Defaults to 0.02.
Returns:
ColoGraphModule: annotated graph module
"""
node_list
=
linearize
(
gm
,
cnode
)
mem_limit
-=
parameter_size
(
gm
)
# prepare data
if
is_compatible_with_meta
():
from
colossalai.fx.profiler
import
MetaTensor
data
=
MetaTensor
(
data
,
fake_device
=
next
(
gm
.
parameters
()).
device
)
MetaInfoProp
(
gm
).
run
(
data
)
chain
:
Chain
=
_construct_chain
(
node_list
,
data
)
chain
=
_normalize_flops
(
chain
,
flops
)
# currently we view loss as an op without expense
chain
.
cbweight
.
append
(
0
)
chain
.
cweight
.
append
(
0
)
chain
.
fwd_mem_tmp
.
append
(
0
)
chain
.
bwd_mem_tmp
.
append
(
0
)
chain
.
fweight
.
append
(
0
)
chain
.
bweight
.
append
(
0
)
solver
=
PofoSolver
(
chain
,
mem_limit
,
bandwidth
,
mem_slots
)
first_state
=
(
0
,
0
,
0
,
0
,
False
)
sequence
=
solver
.
pofo_rec
(
first_state
)
if
sequence
==
None
:
raise
ValueError
(
f
"Cannot solve sequence with
{
mem_limit
}
Bytes memory"
)
_annotate_from_pofo_sequence
(
sequence
,
node_list
)
setattr
(
gm
,
"__sequence__"
,
sequence
)
return
gm
colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
deleted
100644 → 0
View file @
55dcd305
import
math
import
sys
from
typing
import
List
,
Tuple
from
torch.fx
import
Node
from
colossalai.fx.codegen.activation_checkpoint_codegen
import
_find_nested_ckpt_regions
from
colossalai.fx.graph_module
import
ColoGraphModule
from
colossalai.fx.profiler
import
activation_size
,
calculate_fwd_out
,
calculate_fwd_tmp
,
parameter_size
from
colossalai.logging
import
get_dist_logger
from
.linearize
import
linearize
from
.operation
import
Backward
,
Chain
,
ForwardCheck
,
ForwardEnable
,
ForwardNograd
,
Function
,
Loss
,
Sequence
# global vairable to indicate whether the solver is failed
SOLVER_FAILED
=
False
# this is the python compute table code from rotor
# https://gitlab.inria.fr/hiepacs/rotor
# paper link: https://hal.inria.fr/hal-02352969
def
_compute_table
(
chain
:
Chain
,
mmax
)
->
Tuple
:
"""Returns the optimal table: a tuple containing:
Opt[m][lmin][lmax] with lmin = 0...chain.length
and lmax = lmin...chain.length (lmax is not included) and m = 0...mmax
what[m][lmin][lmax] is (True,) if the optimal choice is a chain checkpoint
(False, j) if the optimal choice is a leaf checkpoint of length j
The computation uses dynamic programming"""
fw
=
chain
.
fweight
+
[
0
]
## forward time
bw
=
chain
.
bweight
## backward time, not used
cw
=
chain
.
cweight
+
[
0
]
## size of x (and of y)
cbw
=
chain
.
cbweight
+
[
0
]
## size of xbar
fwd_mem_tmp
=
chain
.
fwd_mem_tmp
+
[
0
]
bwd_mem_tmp
=
chain
.
bwd_mem_tmp
+
[
0
]
# Build table
opt
=
[[{}
for
_
in
range
(
chain
.
length
+
1
)]
for
_
in
range
(
mmax
+
1
)]
what
=
[[{}
for
_
in
range
(
chain
.
length
+
1
)]
for
_
in
range
(
mmax
+
1
)]
# Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation
# Initialize borders of the tables for lmax-lmin = 0
for
m
in
range
(
mmax
+
1
):
for
i
in
range
(
chain
.
length
+
1
):
#lmax-lmin = 0
limit
=
max
(
cw
[
i
+
1
]
+
cbw
[
i
+
1
]
+
fwd_mem_tmp
[
i
],
cw
[
i
+
1
]
+
cbw
[
i
+
1
]
+
bwd_mem_tmp
[
i
])
if
m
>=
limit
:
## Equation (1)
opt
[
m
][
i
][
i
]
=
fw
[
i
]
+
bw
[
i
]
else
:
opt
[
m
][
i
][
i
]
=
float
(
"inf"
)
# Compute everything
for
m
in
range
(
mmax
+
1
):
for
d
in
range
(
1
,
chain
.
length
+
1
):
for
i
in
range
(
chain
.
length
+
1
-
d
):
# for idx in range(i+1, chain.length + 1):
idx
=
i
+
d
mmin
=
cw
[
idx
+
1
]
+
cw
[
i
+
1
]
+
fwd_mem_tmp
[
i
]
if
idx
>
i
+
1
:
mmin
=
max
(
mmin
,
cw
[
idx
+
1
]
+
max
(
cw
[
j
]
+
cw
[
j
+
1
]
+
fwd_mem_tmp
[
j
]
for
j
in
range
(
i
+
1
,
idx
)))
if
m
<
mmin
:
opt
[
m
][
i
][
idx
]
=
float
(
"inf"
)
else
:
leaf_checkpoints
=
[(
j
,
sum
(
fw
[
i
:
j
])
+
opt
[
m
-
cw
[
j
]][
j
][
idx
]
+
opt
[
m
][
i
][
j
-
1
])
for
j
in
range
(
i
+
1
,
idx
+
1
)
if
m
>=
cw
[
j
]]
if
leaf_checkpoints
:
best_leaf
=
min
(
leaf_checkpoints
,
key
=
lambda
t
:
t
[
1
])
else
:
best_leaf
=
None
if
m
>=
cbw
[
i
+
1
]:
chain_checkpoint
=
opt
[
m
][
i
][
i
]
+
opt
[
m
-
cbw
[
i
+
1
]][
i
+
1
][
idx
]
else
:
chain_checkpoint
=
float
(
"inf"
)
if
best_leaf
and
best_leaf
[
1
]
<=
chain_checkpoint
:
opt
[
m
][
i
][
idx
]
=
best_leaf
[
1
]
what
[
m
][
i
][
idx
]
=
(
False
,
best_leaf
[
0
])
else
:
opt
[
m
][
i
][
idx
]
=
chain_checkpoint
what
[
m
][
i
][
idx
]
=
(
True
,)
return
(
opt
,
what
)
def
_rec
(
chain
:
Chain
,
lmin
,
lmax
,
cmem
,
opt_table
):
""" chain : the class describing the AC graph
lmin : index of the first forward to execute
lmax : upper bound index of the last forward to execute (not included)
cmem : number of available memory slots
Return the optimal sequence of makespan Opt_hete[cmem][lmin][lmax-lmin]"""
if
cmem
<=
0
:
raise
ValueError
(
"Can not process a chain with negative memory {cmem}"
.
format
(
cmem
=
cmem
))
opt
,
what
=
opt_table
sequence
=
Sequence
(
Function
(
"Persistent"
,
lmax
-
lmin
,
cmem
))
if
opt
[
cmem
][
lmin
][
lmax
]
==
float
(
"inf"
):
# using logger to annonce that the solver is failed
logger
=
get_dist_logger
()
logger
.
info
(
"Can not process this chain from index {lmin} to {lmax} with memory {cmem}"
.
format
(
lmin
=
lmin
,
lmax
=
lmax
,
cmem
=
cmem
))
# set global indicater SOLVER_FAILED to True
global
SOLVER_FAILED
SOLVER_FAILED
=
True
return
sequence
if
lmin
==
lmax
:
if
lmin
==
chain
.
length
:
sequence
.
insert
(
Loss
())
else
:
sequence
.
insert
(
ForwardEnable
(
lmin
))
sequence
.
insert
(
Backward
(
lmin
))
return
sequence
if
what
[
cmem
][
lmin
][
lmax
][
0
]:
sequence
.
insert
(
ForwardEnable
(
lmin
))
sequence
.
insert_sequence
(
_rec
(
chain
,
lmin
+
1
,
lmax
,
cmem
-
chain
.
cbweight
[
lmin
+
1
],
opt_table
))
sequence
.
insert
(
Backward
(
lmin
))
else
:
j
=
what
[
cmem
][
lmin
][
lmax
][
1
]
sequence
.
insert
(
ForwardCheck
(
lmin
))
for
k
in
range
(
lmin
+
1
,
j
):
sequence
.
insert
(
ForwardNograd
(
k
))
sequence
.
insert_sequence
(
_rec
(
chain
,
j
,
lmax
,
cmem
-
chain
.
cweight
[
j
],
opt_table
))
sequence
.
insert_sequence
(
_rec
(
chain
,
lmin
,
j
-
1
,
cmem
,
opt_table
))
return
sequence
def
_fwd_xbar
(
node
:
List
[
Node
])
->
int
:
"""Get the forward xbar of a node
Args:
node (List[Node]): List of torch.fx Node,
indicates a node in linearized graph
Returns:
int: xbar size, unit Byte
"""
xbar
=
0
for
n
in
node
:
xbar
+=
calculate_fwd_tmp
(
n
)
+
calculate_fwd_out
(
n
)
return
xbar
def
_fwd_time
(
node
:
List
[
Node
])
->
int
:
"""Get the foward time of a node
Args:
node (List[Node]): List of torch.fx Node,
indicates a node in linearized graph
Returns:
int: foward time, extimated by flops count
"""
fwd_time
=
0
for
n
in
node
:
# minimum flop count is needed
fwd_time
+=
max
(
n
.
meta
[
'fwd_flop'
],
1
)
return
fwd_time
def
_bwd_time
(
node
:
List
[
Node
])
->
int
:
"""Get the backward time of a node
Args:
node (List[Node]): List of torch.fx Node,
indicates a node in linearized graph
Returns:
int: backward time, extimated by flops count
"""
bwd_time
=
0
for
n
in
node
:
# minimum flop count is needed
bwd_time
+=
max
(
n
.
meta
[
'bwd_flop'
],
1
)
return
bwd_time
def
_get_fwd_mem_tmp
(
node
:
List
[
Node
])
->
int
:
"""Get the forward temp memory of a node
This could be done by subtracting the saved activation from all output of a node
Args:
node (List[Node]): List of torch.fx Node,
indicates a node in linearized graph
Returns:
int: forward temp memory, unit Byte
"""
n
=
node
[
-
1
]
return
activation_size
(
n
.
meta
[
'fwd_out'
])
-
calculate_fwd_out
(
n
)
def
_get_bwd_mem_tmp
(
node
:
List
[
Node
])
->
int
:
"""Get the backward temp memory of a node
Args:
node (List[Node]): List of torch.fx Node,
indicates a node in linearized graph
Returns:
int: backward temp memory, unit Byte
"""
def
_get_deps_size
():
deps_size
=
0
for
k
,
v
in
deps
.
items
():
k
:
Node
if
v
>
0
:
deps_size
+=
k
.
meta
[
'bwd_mem_out'
]
if
v
==
float
(
'-inf'
):
deps_size
-=
calculate_fwd_tmp
(
k
)
+
calculate_fwd_out
(
k
)
return
deps_size
bwd_mem_tmp
=
0
deps
=
{}
for
n
in
reversed
(
node
):
deps
[
n
]
=
len
(
n
.
all_input_nodes
)
bwd_mem_tmp
=
max
(
bwd_mem_tmp
,
_get_deps_size
()
+
n
.
meta
[
'bwd_mem_tmp'
])
for
child
in
n
.
users
:
if
child
in
deps
:
deps
[
child
]
-=
1
if
deps
[
child
]
<=
0
:
deps
[
child
]
=
float
(
'-inf'
)
# free
return
bwd_mem_tmp
def
_construct_chain
(
node_list
:
List
[
List
[
Node
]],
input
)
->
Chain
:
fwd_time
=
[]
bwd_time
=
[]
xbar_sizes
=
[
activation_size
(
input
)]
x_sizes
=
[
activation_size
(
input
)]
tmp_fwd
=
[]
tmp_bwd
=
[]
for
idx
,
node
in
enumerate
(
node_list
):
fwd_time
.
append
(
_fwd_time
(
node
))
bwd_time
.
append
(
_bwd_time
(
node
))
x_sizes
.
append
(
calculate_fwd_out
(
node
[
-
1
]))
xbar_sizes
.
append
(
max
(
x_sizes
[
-
1
],
_fwd_xbar
(
node
)))
tmp_fwd
.
append
(
_get_fwd_mem_tmp
(
node
))
tmp_bwd
.
append
(
_get_bwd_mem_tmp
(
node
))
bwd_time
.
append
(
0
)
# currently we view loss backward temp as zero
tmp_bwd
.
append
(
0
)
return
Chain
(
fwd_time
,
bwd_time
,
x_sizes
,
xbar_sizes
,
tmp_fwd
,
tmp_bwd
)
def
_annotate_from_sequence
(
sequence
:
Sequence
,
node_list
:
List
[
List
[
Node
]]):
op_list
=
sequence
.
list_operations
()
loss_op
=
next
(
op
for
op
in
op_list
if
isinstance
(
op
,
Loss
))
fwd_list
=
op_list
[:
op_list
.
index
(
loss_op
)]
bwd_list
=
op_list
[
op_list
.
index
(
loss_op
)
+
1
:]
ckpt_idx
=
0
in_ckpt
=
False
ckpt_region
=
[]
# forward annotation
for
idx
,
op
in
enumerate
(
fwd_list
,
0
):
if
in_ckpt
:
if
isinstance
(
op
,
ForwardNograd
):
ckpt_region
.
append
(
idx
)
elif
isinstance
(
op
,
ForwardEnable
):
in_ckpt
=
False
for
node_idx
in
ckpt_region
:
for
n
in
node_list
[
node_idx
]:
setattr
(
n
,
"activation_checkpoint"
,
[
ckpt_idx
])
ckpt_idx
+=
1
ckpt_region
=
[]
elif
isinstance
(
op
,
ForwardCheck
):
for
node_idx
in
ckpt_region
:
for
n
in
node_list
[
node_idx
]:
setattr
(
n
,
"activation_checkpoint"
,
[
ckpt_idx
])
ckpt_idx
+=
1
ckpt_region
=
[
idx
]
else
:
if
isinstance
(
op
,
ForwardCheck
):
in_ckpt
=
True
ckpt_region
.
append
(
idx
)
# annotate the backward if there is any nested activation checkpoint
in_recompute
=
False
for
op
in
bwd_list
:
if
in_recompute
:
if
isinstance
(
op
,
ForwardNograd
):
ckpt_region
.
append
(
op
.
index
)
elif
isinstance
(
op
,
ForwardEnable
):
for
node_idx
in
ckpt_region
:
for
n
in
node_list
[
node_idx
]:
n
.
activation_checkpoint
.
append
(
ckpt_idx
)
ckpt_idx
+=
1
ckpt_region
=
[]
elif
isinstance
(
op
,
ForwardCheck
):
for
node_idx
in
ckpt_region
:
for
n
in
node_list
[
node_idx
]:
n
.
activation_checkpoint
.
append
(
ckpt_idx
)
ckpt_idx
+=
1
ckpt_region
=
[
op
.
index
]
elif
isinstance
(
op
,
Backward
):
for
node_idx
in
ckpt_region
:
for
n
in
node_list
[
node_idx
]:
n
.
activation_checkpoint
.
append
(
ckpt_idx
)
in_recompute
=
False
else
:
if
not
isinstance
(
op
,
Backward
):
in_recompute
=
True
ckpt_idx
=
0
ckpt_region
=
[]
if
isinstance
(
op
,
ForwardCheck
):
ckpt_region
.
append
(
op
.
index
)
# postprocess, make sure every activation checkpoint label in the
# same activation checkpoint region (level = 0) has the same length
op_list
=
[]
for
node
in
node_list
:
op_list
+=
node
ckpt_regions
=
_find_nested_ckpt_regions
(
op_list
)
for
(
start_idx
,
end_idx
)
in
ckpt_regions
:
nested_length
=
max
(
len
(
op_list
[
idx
].
activation_checkpoint
)
for
idx
in
range
(
start_idx
,
end_idx
+
1
))
for
idx
in
range
(
start_idx
,
end_idx
+
1
):
op_list
[
idx
].
activation_checkpoint
+=
[
None
]
*
(
nested_length
-
len
(
op_list
[
idx
].
activation_checkpoint
))
def
solver_rotor
(
gm
:
ColoGraphModule
,
data
,
mem_limit
:
int
,
mem_slots
:
int
=
500
,
cnode
:
List
[
str
]
=
None
,
eps
:
float
=
0.0
,
force_python
:
bool
=
False
)
->
ColoGraphModule
:
"""solver that automatically find activation checkpoint in rotor's manner
Args:
gm (ColoGraphModule): ColoGraphModule generated by tracing model and MetaInfoProp.
data (torch.Tensor): input data.
mem_limit (int): memory budget in Byte.
mem_slots (int, optional): number of slots for discretizing memory budget. Defaults to 500.
cnode (List[Node], optional): common node list for linearize. Defaults to None.
eps (float): epsilon for memory decay. Defaults to 0.0
force_python (bool): force to use python version of dynamic programs
Returns:
ColoGraphModule: annotated ColoGraphModuled with __sequence__ attribute
"""
# try to import C version solver if force_python is not set
logger
=
get_dist_logger
()
if
not
force_python
:
try
:
from
.dynamic_programs_C_version
import
persistent_compute_table
CVERSION
=
True
# build module if module not found
except
ModuleNotFoundError
:
import
os
import
subprocess
logger
.
info
(
"dynamic_programs_C_version hasn't been built! Building library..."
,
ranks
=
[
0
])
this_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
result
=
subprocess
.
Popen
(
[
f
"
{
sys
.
executable
}
"
,
f
"
{
os
.
path
.
join
(
this_dir
,
'build_c_ext.py'
)
}
"
,
"build_ext"
,
f
"--build-lib=
{
this_dir
}
"
],
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
,
)
if
result
.
wait
()
==
0
:
logger
.
info
(
"dynamic_programs_C_version has been built!"
,
ranks
=
[
0
])
from
.dynamic_programs_C_version
import
persistent_compute_table
CVERSION
=
True
else
:
logger
.
info
(
"dynamic_programs_C_version built failed! Using python version!"
,
ranks
=
[
0
])
CVERSION
=
False
else
:
CVERSION
=
False
# check if metainfoprop is done
if
any
(
len
(
node
.
meta
)
==
0
for
node
in
gm
.
graph
.
nodes
):
raise
RuntimeError
(
"Nodes meta information hasn't been prepared! Please run MetaInfoProp before calling solver!"
)
# linearize the graph
node_list
=
linearize
(
gm
,
cnode
)
# construct chain
mem_unit
=
mem_limit
*
(
1.0
-
eps
)
//
mem_slots
chain
:
Chain
=
_construct_chain
(
node_list
,
data
)
chain
.
_discretize
(
mem_unit
)
# use C version if possible
if
CVERSION
and
not
force_python
:
logger
.
info
(
"Using C version rotor solver!"
,
ranks
=
[
0
])
opt_table
=
persistent_compute_table
(
chain
,
mem_slots
)
else
:
opt_table
=
_compute_table
(
chain
,
mem_slots
)
logger
.
info
(
"Using python version rotor solver!"
,
ranks
=
[
0
])
# found sequence
sequence
=
_rec
(
chain
,
0
,
chain
.
length
,
mem_slots
-
chain
.
cweight
[
0
],
opt_table
)
# if solver failed, we don't need to annotate the graph
if
not
SOLVER_FAILED
:
_annotate_from_sequence
(
sequence
,
node_list
)
# set __sequence__ attribute to GraphModule
if
SOLVER_FAILED
:
setattr
(
gm
,
"__sequence__"
,
None
)
else
:
setattr
(
gm
,
"__sequence__"
,
sequence
)
# set __opttable__ attribute to GraphModule
setattr
(
gm
,
"__opttable__"
,
opt_table
[
0
])
gm
.
recompile
()
return
gm
colossalai/fx/passes/algorithms/dynamic_programs.c
deleted
100644 → 0
View file @
55dcd305
#define PY_SSIZE_T_CLEAN
#include <Python.h>
long
*
PySequenceToLongArray
(
PyObject
*
pylist
)
{
if
(
!
(
pylist
&&
PySequence_Check
(
pylist
)))
return
NULL
;
Py_ssize_t
len
=
PySequence_Size
(
pylist
);
long
*
result
=
(
long
*
)
calloc
(
len
+
1
,
sizeof
(
long
));
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
++
i
)
{
PyObject
*
item
=
PySequence_GetItem
(
pylist
,
i
);
result
[
i
]
=
PyLong_AsLong
(
item
);
Py_DECREF
(
item
);
}
result
[
len
]
=
0
;
return
result
;
}
double
*
PySequenceToDoubleArray
(
PyObject
*
pylist
)
{
if
(
!
(
pylist
&&
PySequence_Check
(
pylist
)))
return
NULL
;
Py_ssize_t
len
=
PySequence_Size
(
pylist
);
double
*
result
=
(
double
*
)
calloc
(
len
+
1
,
sizeof
(
double
));
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
++
i
)
{
PyObject
*
item
=
PySequence_GetItem
(
pylist
,
i
);
result
[
i
]
=
PyFloat_AsDouble
(
item
);
Py_DECREF
(
item
);
}
result
[
len
]
=
0
;
return
result
;
}
long
*
getLongArray
(
PyObject
*
container
,
const
char
*
attributeName
)
{
PyObject
*
sequence
=
PyObject_GetAttrString
(
container
,
attributeName
);
long
*
result
=
PySequenceToLongArray
(
sequence
);
Py_DECREF
(
sequence
);
return
result
;
}
double
*
getDoubleArray
(
PyObject
*
container
,
const
char
*
attributeName
)
{
PyObject
*
sequence
=
PyObject_GetAttrString
(
container
,
attributeName
);
double
*
result
=
PySequenceToDoubleArray
(
sequence
);
Py_DECREF
(
sequence
);
return
result
;
}
static
PyObject
*
persistent_compute_table
(
PyObject
*
self
,
PyObject
*
args
)
{
PyObject
*
chain_param
;
int
mmax
;
if
(
!
PyArg_ParseTuple
(
args
,
"Oi"
,
&
chain_param
,
&
mmax
))
return
NULL
;
double
*
fw
=
getDoubleArray
(
chain_param
,
"fweight"
);
if
(
!
fw
)
return
NULL
;
double
*
bw
=
getDoubleArray
(
chain_param
,
"bweight"
);
if
(
!
bw
)
return
NULL
;
long
*
cw
=
getLongArray
(
chain_param
,
"cweight"
);
if
(
!
cw
)
return
NULL
;
long
*
cbw
=
getLongArray
(
chain_param
,
"cbweight"
);
if
(
!
cbw
)
return
NULL
;
long
*
fwd_tmp
=
getLongArray
(
chain_param
,
"fwd_mem_tmp"
);
if
(
!
cbw
)
return
NULL
;
long
*
bwd_tmp
=
getLongArray
(
chain_param
,
"bwd_mem_tmp"
);
if
(
!
cbw
)
return
NULL
;
PyObject
*
chain_length_param
=
PyObject_GetAttrString
(
chain_param
,
"length"
);
if
(
!
chain_length_param
)
return
NULL
;
long
chain_length
=
PyLong_AsLong
(
chain_length_param
);
Py_DECREF
(
chain_length_param
);
// TODO: Can be optimized by only allocating memory for l >= i
// TODO: float / int instead of double / long ?
#define OPT(m, i, l) \
opt[(m) * (chain_length + 1) * (chain_length + 1) + \
(i) * (chain_length + 1) + (l)]
double
*
opt
=
(
double
*
)
calloc
(
(
mmax
+
1
)
*
(
chain_length
+
1
)
*
(
chain_length
+
1
),
sizeof
(
double
));
#define WHAT(m, i, l) \
what[(m) * (chain_length + 1) * (chain_length + 1) + \
(i) * (chain_length + 1) + (l)]
long
*
what
=
(
long
*
)
calloc
(
(
mmax
+
1
)
*
(
chain_length
+
1
)
*
(
chain_length
+
1
),
sizeof
(
long
));
for
(
long
m
=
0
;
m
<=
mmax
;
++
m
)
for
(
long
i
=
0
;
i
<=
chain_length
;
++
i
)
// TODO: Can be optimized to remove the IF by reordering loops
if
((
m
>=
cw
[
i
+
1
]
+
cbw
[
i
+
1
]
+
bwd_tmp
[
i
])
&&
(
m
>=
cw
[
i
+
1
]
+
cbw
[
i
+
1
]
+
fwd_tmp
[
i
]))
OPT
(
m
,
i
,
i
)
=
fw
[
i
]
+
bw
[
i
];
else
OPT
(
m
,
i
,
i
)
=
INFINITY
;
for
(
long
m
=
0
;
m
<=
mmax
;
++
m
)
for
(
long
d
=
1
;
d
<=
chain_length
;
++
d
)
{
for
(
long
i
=
0
;
i
<=
chain_length
-
d
;
++
i
)
{
long
idx
=
i
+
d
;
long
mmin
=
cw
[
idx
+
1
]
+
cw
[
i
+
1
]
+
fwd_tmp
[
i
];
if
(
idx
>
i
+
1
)
{
long
maxCostFWD
=
0
;
for
(
long
j
=
i
+
1
;
j
<
idx
;
j
++
)
{
maxCostFWD
=
fmaxl
(
maxCostFWD
,
cw
[
j
]
+
cw
[
j
+
1
]
+
fwd_tmp
[
j
]);
}
mmin
=
fmaxl
(
mmin
,
cw
[
idx
+
1
]
+
maxCostFWD
);
}
if
((
m
>=
mmin
))
{
long
bestLeaf
=
-
1
;
double
sumFw
=
0
;
double
bestLeafCost
=
INFINITY
;
/// sumFw + OPT(m-cw[i+1], i+1, l) + OPT(m, i, i); // Value for j =
/// i+1
for
(
long
j
=
i
+
1
;
j
<=
idx
;
++
j
)
{
sumFw
+=
fw
[
j
-
1
];
if
(
m
>=
cw
[
j
])
{
double
cost
=
sumFw
+
OPT
(
m
-
cw
[
j
],
j
,
idx
)
+
OPT
(
m
,
i
,
j
-
1
);
if
(
cost
<
bestLeafCost
)
{
bestLeafCost
=
cost
;
bestLeaf
=
j
;
}
}
}
double
chainCost
=
INFINITY
;
if
(
m
>=
cbw
[
i
+
1
])
chainCost
=
OPT
(
m
,
i
,
i
)
+
OPT
(
m
-
cbw
[
i
+
1
],
i
+
1
,
idx
);
if
(
bestLeafCost
<=
chainCost
)
{
OPT
(
m
,
i
,
idx
)
=
bestLeafCost
;
WHAT
(
m
,
i
,
idx
)
=
bestLeaf
;
}
else
{
OPT
(
m
,
i
,
idx
)
=
chainCost
;
WHAT
(
m
,
i
,
idx
)
=
-
1
;
}
}
else
OPT
(
m
,
i
,
idx
)
=
INFINITY
;
}
}
free
(
fw
);
free
(
bw
);
free
(
cw
);
free
(
cbw
);
free
(
fwd_tmp
);
free
(
bwd_tmp
);
PyObject
*
res_opt
=
PyList_New
(
mmax
+
1
);
PyObject
*
res_what
=
PyList_New
(
mmax
+
1
);
// Convert the result into Python world
for
(
long
m
=
0
;
m
<=
mmax
;
++
m
)
{
PyObject
*
res_opt_m
=
PyList_New
(
chain_length
+
1
);
PyList_SET_ITEM
(
res_opt
,
m
,
res_opt_m
);
PyObject
*
res_what_m
=
PyList_New
(
chain_length
+
1
);
PyList_SET_ITEM
(
res_what
,
m
,
res_what_m
);
for
(
long
i
=
0
;
i
<=
chain_length
;
++
i
)
{
PyObject
*
res_opt_m_i
=
PyDict_New
();
PyList_SET_ITEM
(
res_opt_m
,
i
,
res_opt_m_i
);
PyObject
*
res_what_m_i
=
PyDict_New
();
PyList_SET_ITEM
(
res_what_m
,
i
,
res_what_m_i
);
for
(
long
l
=
i
;
l
<=
chain_length
;
++
l
)
{
PyObject
*
res_l
=
PyLong_FromLong
(
l
);
PyObject
*
res_opt_m_i_l
=
PyFloat_FromDouble
(
OPT
(
m
,
i
,
l
));
PyDict_SetItem
(
res_opt_m_i
,
res_l
,
res_opt_m_i_l
);
Py_DECREF
(
res_opt_m_i_l
);
PyObject
*
res_what_m_i_l
;
long
what_m_i_l
=
WHAT
(
m
,
i
,
l
);
if
(
what_m_i_l
<
0
)
res_what_m_i_l
=
Py_BuildValue
(
"(O)"
,
Py_True
);
else
res_what_m_i_l
=
Py_BuildValue
(
"(Ol)"
,
Py_False
,
what_m_i_l
);
PyDict_SetItem
(
res_what_m_i
,
res_l
,
res_what_m_i_l
);
Py_DECREF
(
res_what_m_i_l
);
Py_DECREF
(
res_l
);
}
}
}
free
(
opt
);
free
(
what
);
PyObject
*
result
=
PyTuple_Pack
(
2
,
res_opt
,
res_what
);
Py_DECREF
(
res_opt
);
Py_DECREF
(
res_what
);
return
result
;
}
// long i = L - s, j = t - s, k = l - t
inline
long
floating_index_in_array
(
long
m_factor
,
long
m
,
long
i
,
long
j
,
long
k
)
{
return
m
*
m_factor
+
(
i
*
(
i
+
1
)
*
(
2
*
i
+
4
))
/
12
+
(
i
+
1
)
*
j
-
(
j
*
(
j
-
1
))
/
2
+
k
;
}
typedef
struct
{
long
sp
;
long
r
;
long
tp
;
}
index_t
;
static
PyObject
*
floating_compute_table
(
PyObject
*
self
,
PyObject
*
args
)
{
PyObject
*
chain_param
;
int
mmax
;
if
(
!
PyArg_ParseTuple
(
args
,
"Oi"
,
&
chain_param
,
&
mmax
))
return
NULL
;
double
*
fw
=
getDoubleArray
(
chain_param
,
"fweigth"
);
if
(
!
fw
)
return
NULL
;
double
*
bw
=
getDoubleArray
(
chain_param
,
"bweigth"
);
if
(
!
bw
)
return
NULL
;
long
*
cw
=
getLongArray
(
chain_param
,
"cweigth"
);
if
(
!
cw
)
return
NULL
;
long
*
cbw
=
getLongArray
(
chain_param
,
"cbweigth"
);
if
(
!
cbw
)
return
NULL
;
long
*
fwd_tmp
=
getLongArray
(
chain_param
,
"fwd_tmp"
);
if
(
!
fwd_tmp
)
return
NULL
;
long
*
bwd_tmp
=
getLongArray
(
chain_param
,
"bwd_tmp"
);
if
(
!
bwd_tmp
)
return
NULL
;
PyObject
*
chain_length_param
=
PyObject_GetAttrString
(
chain_param
,
"length"
);
if
(
!
chain_length_param
)
return
NULL
;
long
chain_length
=
PyLong_AsLong
(
chain_length_param
);
Py_DECREF
(
chain_length_param
);
const
long
m_factor
=
(
chain_length
+
1
)
*
(
chain_length
+
2
)
*
(
2
*
chain_length
+
6
)
/
12
;
// Defined for 0 <= s <= t <= l <= chain_length, for all m
#undef OPT
#define OPT(m, s, t, l) \
opt[floating_index_in_array(m_factor, (m), chain_length - (s), (t) - (s), \
(l) - (t))]
double
*
opt
=
(
double
*
)
calloc
((
mmax
+
1
)
*
m_factor
,
sizeof
(
double
));
#undef WHAT
#define WHAT(m, s, t, l) \
what[floating_index_in_array(m_factor, (m), chain_length - (s), (t) - (s), \
(l) - (t))]
index_t
*
what
=
(
index_t
*
)
calloc
((
mmax
+
1
)
*
m_factor
,
sizeof
(
index_t
));
double
*
partialSumsFW
=
(
double
*
)
calloc
(
chain_length
+
1
,
sizeof
(
double
));
double
total
=
0
;
for
(
long
i
=
0
;
i
<
chain_length
;
++
i
)
{
partialSumsFW
[
i
]
=
total
;
total
+=
fw
[
i
];
}
partialSumsFW
[
chain_length
]
=
total
;
for
(
long
m
=
0
;
m
<=
mmax
;
++
m
)
for
(
long
i
=
0
;
i
<=
chain_length
;
++
i
)
{
// TODO: Can be optimized to remove the IF by reordering loops
if
((
m
>=
cw
[
i
]
+
cw
[
i
+
1
]
+
cbw
[
i
+
1
]
+
bwd_tmp
[
i
])
&&
(
m
>=
cw
[
i
+
1
]
+
cbw
[
i
+
1
]
+
fwd_tmp
[
i
]))
OPT
(
m
,
i
,
i
,
i
)
=
fw
[
i
]
+
bw
[
i
];
else
OPT
(
m
,
i
,
i
,
i
)
=
INFINITY
;
}
for
(
long
m
=
0
;
m
<=
mmax
;
++
m
)
for
(
long
d
=
1
;
d
<=
chain_length
;
++
d
)
{
// d = l - s
for
(
long
s
=
0
;
s
<=
chain_length
-
d
;
++
s
)
{
long
l
=
s
+
d
;
long
memNullFirst
=
cw
[
l
+
1
]
+
cw
[
s
+
1
]
+
fwd_tmp
[
s
];
long
memNullSecond
=
0
;
for
(
long
j
=
s
+
1
;
j
<
l
;
++
j
)
{
long
val
=
cw
[
j
]
+
cw
[
j
+
1
]
+
fwd_tmp
[
j
];
if
(
val
>
memNullSecond
)
memNullSecond
=
val
;
}
for
(
long
t
=
s
;
t
<=
l
;
++
t
)
{
double
chainCost
=
INFINITY
;
if
((
s
==
t
)
&&
(
m
>=
cw
[
l
+
1
]
+
cbw
[
s
+
1
]
+
fwd_tmp
[
s
])
&&
(
m
>=
cw
[
s
]
+
cw
[
s
+
1
]
+
cbw
[
s
+
1
]
+
bwd_tmp
[
s
]))
{
chainCost
=
OPT
(
m
,
s
,
s
,
s
)
+
OPT
(
m
-
cbw
[
s
+
1
],
s
+
1
,
s
+
1
,
l
);
}
double
bestLeafCost
=
INFINITY
;
index_t
bestLeaf
=
{.
sp
=
-
1
,
.
r
=
-
1
,
.
tp
=
-
1
};
if
(
m
>=
memNullFirst
&&
m
>=
cw
[
l
+
1
]
+
memNullSecond
)
{
for
(
long
r
=
s
;
r
<=
t
;
++
r
)
if
(
cw
[
s
]
<=
cw
[
r
])
for
(
long
tp
=
t
+
1
;
tp
<=
l
;
++
tp
)
for
(
long
sp
=
r
+
1
;
sp
<=
tp
;
++
sp
)
{
long
mp
=
m
-
cw
[
r
]
+
cw
[
s
];
assert
(
mp
>=
0
);
if
(
mp
>=
cw
[
sp
])
{
double
value
=
partialSumsFW
[
sp
]
-
partialSumsFW
[
s
]
+
OPT
(
mp
-
cw
[
sp
],
sp
,
tp
,
l
)
+
OPT
(
mp
,
r
,
t
,
tp
-
1
);
if
(
value
<
bestLeafCost
)
{
bestLeafCost
=
value
;
bestLeaf
.
sp
=
sp
;
bestLeaf
.
r
=
r
;
bestLeaf
.
tp
=
tp
;
}
}
}
}
if
(
bestLeaf
.
sp
>=
0
&&
bestLeafCost
<=
chainCost
)
{
OPT
(
m
,
s
,
t
,
l
)
=
bestLeafCost
;
WHAT
(
m
,
s
,
t
,
l
).
sp
=
bestLeaf
.
sp
;
WHAT
(
m
,
s
,
t
,
l
).
r
=
bestLeaf
.
r
;
WHAT
(
m
,
s
,
t
,
l
).
tp
=
bestLeaf
.
tp
;
}
else
{
OPT
(
m
,
s
,
t
,
l
)
=
chainCost
;
WHAT
(
m
,
s
,
t
,
l
).
sp
=
-
1
;
}
}
}
}
free
(
fw
);
free
(
bw
);
free
(
cw
);
free
(
cbw
);
free
(
fwd_tmp
);
free
(
bwd_tmp
);
PyObject
*
res_opt
=
PyList_New
(
mmax
+
1
);
PyObject
*
res_what
=
PyList_New
(
mmax
+
1
);
// Convert the result into Python world
PyObject
*
true_tuple
=
Py_BuildValue
(
"(O)"
,
Py_True
);
for
(
long
m
=
0
;
m
<=
mmax
;
++
m
)
{
PyObject
*
res_opt_m
=
PyDict_New
();
PyList_SET_ITEM
(
res_opt
,
m
,
res_opt_m
);
PyObject
*
res_what_m
=
PyDict_New
();
PyList_SET_ITEM
(
res_what
,
m
,
res_what_m
);
for
(
long
s
=
0
;
s
<=
chain_length
;
++
s
)
for
(
long
t
=
s
;
t
<=
chain_length
;
++
t
)
for
(
long
l
=
t
;
l
<=
chain_length
;
++
l
)
{
PyObject
*
key
=
Py_BuildValue
(
"(lll)"
,
s
,
t
,
l
);
PyObject
*
value_opt
=
PyFloat_FromDouble
(
OPT
(
m
,
s
,
t
,
l
));
PyDict_SetItem
(
res_opt_m
,
key
,
value_opt
);
PyObject
*
value_what
=
true_tuple
;
index_t
*
idx_what
=
&
WHAT
(
m
,
s
,
t
,
l
);
if
(
idx_what
->
sp
>=
0
)
value_what
=
Py_BuildValue
(
"(O(lll))"
,
Py_False
,
idx_what
->
sp
,
idx_what
->
r
,
idx_what
->
tp
);
PyDict_SetItem
(
res_what_m
,
key
,
value_what
);
if
(
value_what
!=
true_tuple
)
Py_DECREF
(
value_what
);
Py_DECREF
(
key
);
Py_DECREF
(
value_opt
);
}
}
Py_DECREF
(
true_tuple
);
free
(
opt
);
free
(
what
);
PyObject
*
result
=
PyTuple_Pack
(
2
,
res_opt
,
res_what
);
Py_DECREF
(
res_opt
);
Py_DECREF
(
res_what
);
return
result
;
}
static
PyObject
*
griewank_heterogeneous_compute_table
(
PyObject
*
self
,
PyObject
*
args
)
{
PyObject
*
chain_param
;
int
mmax
;
if
(
!
PyArg_ParseTuple
(
args
,
"Oi"
,
&
chain_param
,
&
mmax
))
return
NULL
;
double
*
fw
=
getDoubleArray
(
chain_param
,
"fweigth"
);
if
(
!
fw
)
return
NULL
;
double
*
bw
=
getDoubleArray
(
chain_param
,
"bweigth"
);
if
(
!
bw
)
return
NULL
;
long
*
cw
=
getLongArray
(
chain_param
,
"cweigth"
);
if
(
!
cw
)
return
NULL
;
long
*
cbw
=
getLongArray
(
chain_param
,
"cbweigth"
);
if
(
!
cbw
)
return
NULL
;
PyObject
*
chain_length_param
=
PyObject_GetAttrString
(
chain_param
,
"length"
);
if
(
!
chain_length_param
)
return
NULL
;
long
chain_length
=
PyLong_AsLong
(
chain_length_param
);
Py_DECREF
(
chain_length_param
);
// TODO: Can be optimized by only allocating memory for l >= i
// TODO: float / int instead of double / long ?
#undef OPT
#define OPT(m, i, l) \
opt[(m) * (chain_length + 1) * (chain_length + 1) + \
(i) * (chain_length + 1) + (l)]
double
*
opt
=
(
double
*
)
calloc
(
(
mmax
+
1
)
*
(
chain_length
+
1
)
*
(
chain_length
+
1
),
sizeof
(
double
));
// Compute partial sums
double
*
sumfw
=
(
double
*
)
calloc
(
chain_length
,
sizeof
(
double
));
double
*
sumbw
=
(
double
*
)
calloc
(
chain_length
+
1
,
sizeof
(
double
));
double
*
sumsumfw
=
(
double
*
)
calloc
(
chain_length
,
sizeof
(
double
));
double
total
=
0
;
for
(
long
i
=
0
;
i
<
chain_length
;
++
i
)
{
total
+=
fw
[
i
];
sumfw
[
i
]
=
total
;
}
total
=
0
;
for
(
long
i
=
0
;
i
<
chain_length
+
1
;
++
i
)
{
total
+=
bw
[
i
];
sumbw
[
i
]
=
total
;
}
total
=
0
;
for
(
long
i
=
0
;
i
<
chain_length
;
++
i
)
{
total
+=
sumfw
[
i
];
sumsumfw
[
i
]
=
total
;
}
for
(
long
m
=
0
;
m
<=
mmax
;
++
m
)
for
(
long
i
=
0
;
i
<=
chain_length
;
++
i
)
{
// TODO: Can be optimized to remove the IF by reordering loops
if
((
m
>=
cbw
[
i
])
&&
(
m
>=
cw
[
i
]
+
cbw
[
i
+
1
]))
OPT
(
m
,
i
,
i
)
=
bw
[
i
];
else
OPT
(
m
,
i
,
i
)
=
INFINITY
;
if
(
i
<
chain_length
)
{
long
maxC
=
fmaxl
(
cw
[
i
],
cw
[
i
+
1
]);
long
maxCB
=
fmaxl
(
cbw
[
i
+
1
],
cbw
[
i
+
2
]
+
maxC
);
if
((
m
>=
cbw
[
i
])
&&
(
m
>=
cw
[
i
]
+
maxCB
))
OPT
(
m
,
i
,
i
+
1
)
=
fw
[
i
]
+
bw
[
i
]
+
bw
[
i
+
1
];
else
OPT
(
m
,
i
,
i
+
1
)
=
INFINITY
;
}
}
for
(
long
m
=
0
;
m
<=
mmax
;
++
m
)
for
(
long
i
=
0
;
i
+
2
<=
chain_length
;
++
i
)
{
long
mminCst
=
fmaxl
(
cbw
[
i
],
cbw
[
i
+
1
]
+
cw
[
i
]);
long
maxCW_il
=
fmax
(
fmax
(
cw
[
i
],
cw
[
i
+
1
]),
cw
[
i
+
2
]);
long
maxCostFWD
=
cw
[
i
]
+
cbw
[
i
+
2
]
+
maxCW_il
;
for
(
long
l
=
i
+
2
;
l
<=
chain_length
;
++
l
)
{
maxCW_il
=
fmax
(
maxCW_il
,
cw
[
l
+
1
]);
maxCostFWD
=
fmaxl
(
maxCostFWD
,
cw
[
i
]
+
cw
[
l
+
1
]
+
maxCW_il
);
long
mmin
=
fmaxl
(
mminCst
,
maxCostFWD
);
if
((
m
>=
mmin
))
{
double
noCheckpointCost
=
sumbw
[
l
]
-
(
i
>
0
?
sumbw
[
i
-
1
]
:
0
);
noCheckpointCost
+=
sumsumfw
[
l
-
1
]
-
(
i
>
0
?
sumsumfw
[
i
-
1
]
+
(
l
-
i
)
*
sumfw
[
i
-
1
]
:
0
);
double
valueCost
=
INFINITY
;
if
(
m
>=
cw
[
i
])
{
double
sumFwds
=
0
;
for
(
long
j
=
i
+
1
;
j
<
l
;
++
j
)
{
sumFwds
+=
fw
[
j
-
1
];
valueCost
=
fmin
(
valueCost
,
sumFwds
+
OPT
(
m
-
cw
[
i
],
j
,
l
)
+
OPT
(
m
,
i
,
j
-
1
));
}
}
OPT
(
m
,
i
,
l
)
=
fmin
(
noCheckpointCost
,
valueCost
);
}
else
OPT
(
m
,
i
,
l
)
=
INFINITY
;
}
}
free
(
sumfw
);
free
(
sumbw
);
free
(
sumsumfw
);
free
(
fw
);
free
(
bw
);
free
(
cw
);
free
(
cbw
);
PyObject
*
res_opt
=
PyList_New
(
mmax
+
1
);
// Convert the result into Python world
for
(
long
m
=
0
;
m
<=
mmax
;
++
m
)
{
PyObject
*
res_opt_m
=
PyList_New
(
chain_length
+
1
);
PyList_SET_ITEM
(
res_opt
,
m
,
res_opt_m
);
for
(
long
i
=
0
;
i
<=
chain_length
;
++
i
)
{
PyObject
*
res_opt_m_i
=
PyDict_New
();
PyList_SET_ITEM
(
res_opt_m
,
i
,
res_opt_m_i
);
for
(
long
l
=
i
;
l
<=
chain_length
;
++
l
)
{
PyObject
*
res_l
=
PyLong_FromLong
(
l
-
i
);
PyObject
*
res_opt_m_i_l
=
PyFloat_FromDouble
(
OPT
(
m
,
i
,
l
));
PyDict_SetItem
(
res_opt_m_i
,
res_l
,
res_opt_m_i_l
);
Py_DECREF
(
res_opt_m_i_l
);
Py_DECREF
(
res_l
);
}
}
}
free
(
opt
);
return
res_opt
;
}
static
PyMethodDef
dynamic_programs_methods
[]
=
{
{
"persistent_compute_table"
,
persistent_compute_table
,
METH_VARARGS
,
"Compute the optimal table with the persistent algorithm."
},
{
"floating_compute_table"
,
floating_compute_table
,
METH_VARARGS
,
"Compute the optimal table with the floating algorithm."
},
{
"griewank_heterogeneous_compute_table"
,
griewank_heterogeneous_compute_table
,
METH_VARARGS
,
"Compute the optimal table for the Griewank Heterogeneous Model."
},
{
NULL
,
NULL
,
0
,
NULL
}
/* Sentinel */
};
static
struct
PyModuleDef
dynamic_programs_module
=
{
PyModuleDef_HEAD_INIT
,
"dynamic_programs_C_version"
,
/* name of module */
NULL
,
/* module documentation, may be NULL */
-
1
,
/* size of per-interpreter state of the module,
or -1 if the module keeps state in global variables. */
dynamic_programs_methods
};
PyMODINIT_FUNC
PyInit_dynamic_programs_C_version
(
void
)
{
return
PyModule_Create
(
&
dynamic_programs_module
);
}
colossalai/fx/passes/algorithms/linearize.py
deleted
100644 → 0
View file @
55dcd305
from
typing
import
List
,
Any
from
torch.fx
import
GraphModule
,
Node
from
colossalai.fx.profiler
import
is_inplace
# Common nodes are type of nodes that could be seen as attributes and remain
# unchanged throughout the whole model, it will be used several times by
# different blocks of model, so that it is hard for us to linearize the graph
# when we encounter those kinds of nodes. We let users to annotate some of the
# input as common node, such as attention mask, and the followings are some of
# the ops that could actually be seen as common nodes. With our common node prop,
# we could find some of the "real" common nodes (e.g. the real attention mask
# used in BERT and GPT), the rule is simple, for node who's parents are all common
# nodes or it's op belongs to the following operations, we view this node as a
# newly born common node.
# List of target name that could be seen as common node
COPS
=
[
"getattr"
,
"getitem"
,
"size"
]
def
_is_cop
(
target
:
Any
)
->
bool
:
"""Check if an op could be seen as common node
Args:
target (Any): node target
Returns:
bool
"""
if
isinstance
(
target
,
str
):
return
target
in
COPS
else
:
return
target
.
__name__
in
COPS
def
linearize
(
gm
:
GraphModule
,
cnode
:
List
[
str
]
=
None
)
->
List
[
List
[
Node
]]:
"""Linearizing the graph
Args:
gm (GraphModule): GraphModule derived by tracing
cnode (List[str], optional): common node List, should be the subset of input. Default to None.
Returns:
List[List[Node]]: List of list, each inside list of Node presents
the actual 'node' in linearized manner.
Remarks:
We merge the inplace ops into the previous node.
"""
def
_is_sink
()
->
bool
:
"""Check if we can free all dependencies
Returns:
bool
"""
return
not
sum
([
v
for
_
,
v
in
deps
.
items
()])
and
not
any
(
map
(
is_inplace
,
n
.
users
))
# make sure that item in cnode is valid
if
cnode
:
for
name
in
cnode
:
try
:
assert
next
(
node
for
node
in
gm
.
graph
.
nodes
if
node
.
name
==
name
).
op
==
"placeholder"
,
\
f
"common node
{
name
}
is not an input of the model"
except
StopIteration
:
raise
ValueError
(
f
"common node name
{
name
}
not in graph"
)
else
:
cnode
=
[]
deps
=
{}
linearized_nodes
=
[]
region
=
[]
for
n
in
gm
.
graph
.
nodes
:
if
n
.
op
!=
"placeholder"
and
n
.
op
!=
"output"
:
for
n_par
in
n
.
_input_nodes
:
if
n_par
.
op
!=
"placeholder"
and
n_par
.
name
not
in
cnode
:
deps
[
n_par
]
-=
1
region
.
append
(
n
)
# if the node could free all dependencies in graph
# we could begin a new node
if
_is_sink
():
linearized_nodes
.
append
(
region
)
region
=
[]
# propagate common node attr if possible
if
len
(
n
.
_input_nodes
)
==
len
([
node
for
node
in
n
.
_input_nodes
if
node
.
name
in
cnode
])
or
_is_cop
(
n
.
target
):
cnode
.
append
(
n
.
name
)
else
:
deps
[
n
]
=
len
([
user
for
user
in
n
.
users
if
user
.
op
!=
"output"
])
return
linearized_nodes
colossalai/fx/passes/algorithms/operation.py
deleted
100644 → 0
View file @
55dcd305
import
math
def
_discretize
(
mem_unit
,
values
):
return
[
math
.
ceil
(
value
/
mem_unit
)
for
value
in
values
]
class
Chain
:
def
__init__
(
self
,
fw
,
bw
,
cw
,
cbw
,
ftmp
,
btmp
,
check
=
True
):
self
.
fweight
=
fw
self
.
bweight
=
bw
self
.
cweight
=
cw
self
.
cbweight
=
cbw
self
.
fwd_mem_tmp
=
ftmp
self
.
bwd_mem_tmp
=
btmp
self
.
length
=
len
(
fw
)
if
check
and
not
self
.
check_lengths
():
raise
AttributeError
(
"In Chain, input lists do not have consistent lengths"
)
def
check_lengths
(
self
):
return
((
len
(
self
.
fweight
)
==
self
.
length
)
and
(
len
(
self
.
bweight
)
==
self
.
length
+
1
)
and
(
len
(
self
.
cweight
)
==
self
.
length
+
1
)
and
(
len
(
self
.
fwd_mem_tmp
)
==
self
.
length
)
and
(
len
(
self
.
bwd_mem_tmp
)
==
self
.
length
+
1
)
and
(
len
(
self
.
cbweight
)
==
self
.
length
+
1
))
def
__repr__
(
self
):
chain_list
=
[]
for
i
in
range
(
self
.
length
):
chain_list
.
append
((
self
.
fweight
[
i
],
self
.
bweight
[
i
],
self
.
cweight
[
i
],
self
.
cbweight
[
i
],
self
.
fwd_mem_tmp
[
i
],
self
.
bwd_mem_tmp
[
i
]))
i
=
self
.
length
chain_list
.
append
((
None
,
self
.
bweight
[
i
],
self
.
cweight
[
i
],
self
.
cbweight
[
i
],
None
,
self
.
bwd_mem_tmp
[
i
]))
return
chain_list
.
__repr__
()
def
_discretize
(
self
,
mem_unit
):
self
.
cweight
=
_discretize
(
mem_unit
,
self
.
cweight
)
self
.
cbweight
=
_discretize
(
mem_unit
,
self
.
cbweight
)
self
.
fwd_mem_tmp
=
_discretize
(
mem_unit
,
self
.
fwd_mem_tmp
)
self
.
bwd_mem_tmp
=
_discretize
(
mem_unit
,
self
.
bwd_mem_tmp
)
class
Operation
:
def
shift
(
self
,
value
):
if
type
(
self
.
index
)
is
tuple
:
self
.
index
=
tuple
(
x
+
value
for
x
in
self
.
index
)
else
:
self
.
index
+=
value
class
Offload
(
Operation
):
def
__init__
(
self
,
index
,
has_bar
=
False
)
->
None
:
super
().
__init__
()
self
.
index
=
index
self
.
name
=
"Off"
self
.
has_bar
=
has_bar
if
self
.
has_bar
:
self
.
name
+=
"wBar"
def
__repr__
(
self
):
return
f
"
{
self
.
name
}
_
{
self
.
index
}
"
class
Prefetch
(
Operation
):
def
__init__
(
self
,
index
,
has_bar
=
False
)
->
None
:
super
().
__init__
()
self
.
index
=
index
self
.
name
=
"Pre"
self
.
has_bar
=
has_bar
if
self
.
has_bar
:
self
.
name
+=
"wBar"
def
__repr__
(
self
):
return
f
"
{
self
.
name
}
_
{
self
.
index
}
"
class
Forward
(
Operation
):
def
__init__
(
self
,
index
):
self
.
index
=
index
self
.
name
=
"F"
def
__repr__
(
self
):
return
"{n}_{i}"
.
format
(
n
=
self
.
name
,
i
=
self
.
index
)
def
cost
(
self
,
chain
:
Chain
):
if
chain
is
not
None
:
return
chain
.
fweight
[
self
.
index
]
else
:
return
1
class
ForwardEnable
(
Forward
):
def
__init__
(
self
,
index
):
super
().
__init__
(
index
)
self
.
name
=
"Fe"
class
ForwardNograd
(
Forward
):
def
__init__
(
self
,
index
):
super
().
__init__
(
index
)
self
.
name
=
"Fn"
class
ForwardCheck
(
Forward
):
def
__init__
(
self
,
index
):
super
().
__init__
(
index
)
self
.
name
=
"CF"
class
Forwards
(
Operation
):
def
__init__
(
self
,
start
,
end
):
self
.
index
=
(
start
,
end
)
def
__repr__
(
self
):
return
"F_{i}->{j}"
.
format
(
i
=
self
.
index
[
0
],
j
=
self
.
index
[
1
])
def
cost
(
self
,
chain
:
Chain
):
if
chain
is
not
None
:
return
sum
(
chain
.
fweight
[
self
.
index
[
0
]:
self
.
index
[
1
]
+
1
])
else
:
return
(
self
.
index
[
1
]
-
self
.
index
[
0
]
+
1
)
def
isForward
(
op
):
return
type
(
op
)
is
Forward
or
type
(
op
)
is
Forwards
class
Backward
(
Operation
):
def
__init__
(
self
,
index
):
self
.
index
=
index
def
__repr__
(
self
):
return
"B_{i}"
.
format
(
i
=
self
.
index
)
def
cost
(
self
,
chain
:
Chain
):
if
chain
is
not
None
:
return
chain
.
bweight
[
self
.
index
]
else
:
return
1
class
Loss
(
Operation
):
def
__init__
(
self
):
pass
def
__repr__
(
self
):
return
"L"
def
cost
(
self
,
chain
):
return
0
class
MemoryAccess
(
Operation
):
def
__init__
(
self
,
index
):
self
.
index
=
index
def
__repr__
(
self
):
return
"{n}_{i}"
.
format
(
n
=
self
.
name
,
i
=
self
.
index
)
def
cost
(
self
,
chain
:
Chain
):
return
0
class
WriteMemory
(
MemoryAccess
):
def
__init__
(
self
,
index
):
super
().
__init__
(
index
)
self
.
name
=
"WM"
class
ReadMemory
(
MemoryAccess
):
def
__init__
(
self
,
index
):
super
().
__init__
(
index
)
self
.
name
=
"RM"
class
DiscardMemory
(
MemoryAccess
):
def
__init__
(
self
,
index
):
super
().
__init__
(
index
)
self
.
name
=
"DM"
class
Function
:
def
__init__
(
self
,
name
,
*
args
):
self
.
name
=
name
self
.
args
=
args
self
.
str_args
=
','
.
join
(
str
(
v
)
for
v
in
self
.
args
)
def
__repr__
(
self
):
return
"{n}({args})"
.
format
(
n
=
self
.
name
,
args
=
self
.
str_args
)
class
Sequence
:
def
__init__
(
self
,
function
):
self
.
sequence
=
[]
#List of Operation and Sequence
self
.
function
=
function
#Description the function (name and parameters)
def
__repr__
(
self
):
return
repr
(
self
.
list_operations
())
def
list_operations
(
self
):
op_list
=
[]
for
x
in
self
.
sequence
:
if
isinstance
(
x
,
Operation
):
op_list
.
append
(
x
)
else
:
assert
isinstance
(
x
,
Sequence
)
op_list
+=
x
.
list_operations
()
return
op_list
def
insert
(
self
,
operation
):
self
.
sequence
.
append
(
operation
)
def
remove
(
self
,
operation_index
):
del
self
.
sequence
[
operation_index
]
def
insert_sequence
(
self
,
sequence
):
self
.
sequence
.
append
(
sequence
)
def
shift
(
self
,
value
):
for
x
in
self
.
sequence
:
x
.
shift
(
value
)
return
self
def
remove_useless_write
(
self
):
if
self
.
sequence
:
if
isinstance
(
self
.
sequence
[
0
],
WriteMemory
):
self
.
remove
(
0
)
return
self
def
get_makespan
(
self
,
chain
):
return
sum
(
op
.
cost
(
chain
)
for
op
in
self
.
list_operations
())
def
without_suffix
(
self
):
ops
=
self
.
list_operations
()
end_of_first_phase
=
[
i
for
i
in
range
(
len
(
ops
))
if
type
(
ops
[
i
])
is
Loss
][
0
]
try
:
last_idx
=
max
(
i
for
i
in
range
(
end_of_first_phase
)
if
not
type
(
ops
[
i
])
is
ForwardEnable
)
except
ValueError
:
last_idx
=
-
1
if
last_idx
==
end_of_first_phase
-
1
:
return
(
self
,
None
)
chain_length
=
ops
[
end_of_first_phase
-
1
].
index
## Some assumption here about the sequence (finishes with Forward_L
start_of_fwd_enable_chain
=
ops
[
last_idx
+
1
].
index
## And starts with B_L), but should be fine in practice
result
=
Sequence
(
Function
(
"Strip"
,
self
.
function
.
name
,
*
self
.
function
.
args
,
start_of_fwd_enable_chain
))
for
i
in
range
(
last_idx
+
1
):
result
.
insert
(
ops
[
i
])
result
.
insert
(
Loss
())
for
i
in
range
(
chain_length
,
start_of_fwd_enable_chain
-
1
,
-
1
):
position
=
end_of_first_phase
+
1
+
(
chain_length
-
i
)
assert
type
(
ops
[
position
])
is
Backward
assert
ops
[
position
].
index
==
i
for
i
in
range
(
end_of_first_phase
+
1
+
1
+
chain_length
-
start_of_fwd_enable_chain
,
len
(
ops
)):
result
.
insert
(
ops
[
i
])
return
(
result
,
start_of_fwd_enable_chain
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment