Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
a6cdbf91
Commit
a6cdbf91
authored
Jan 06, 2023
by
oahzxl
Browse files
seperate trace flow
parent
4748967f
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
447 additions
and
424 deletions
+447
-424
colossalai/autochunk/autochunk_codegen.py
colossalai/autochunk/autochunk_codegen.py
+1
-1
colossalai/autochunk/search_chunk.py
colossalai/autochunk/search_chunk.py
+28
-25
colossalai/autochunk/select_chunk.py
colossalai/autochunk/select_chunk.py
+2
-1
colossalai/autochunk/trace_flow.py
colossalai/autochunk/trace_flow.py
+414
-0
colossalai/autochunk/trace_index.py
colossalai/autochunk/trace_index.py
+0
-395
tests/test_autochunk/benchmark_autochunk.py
tests/test_autochunk/benchmark_autochunk.py
+2
-2
No files found.
colossalai/autochunk/autochunk_codegen.py
View file @
a6cdbf91
...
...
@@ -167,7 +167,7 @@ def emit_code_with_chunk(
)
# ones like
if
"ones_like"
in
node
.
name
:
meta_node
=
chunk_region_search
.
index_tracer
.
node_list
[
node_idx
]
meta_node
=
chunk_region_search
.
trace_index
.
node_list
[
node_idx
]
chunk_dim
=
chunk_infos
[
region_idx
][
"node_chunk_dim"
][
meta_node
][
"chunk_dim"
]
...
...
colossalai/autochunk/search_chunk.py
View file @
a6cdbf91
import
copy
from
.select_chunk
import
SelectChunk
from
.trace_index
import
TraceIndex
,
ReorderGraph
from
.trace_index
import
TraceIndex
from
.reorder_graph
import
ReorderGraph
from
.estiamte_memory
import
EstimateMemory
from
.trace_flow
import
TraceFlow
from
.utils
import
(
get_node_shape
,
is_non_compute_node
,
...
...
@@ -14,12 +16,13 @@ class SearchChunk(object):
def
__init__
(
self
,
gm
,
max_memory
=
None
,
print_mem
=
False
)
->
None
:
self
.
gm
=
gm
self
.
print_mem
=
print_mem
self
.
index_tracer
=
TraceIndex
(
list
(
gm
.
graph
.
nodes
))
self
.
index_tracer
.
trace_index
()
self
.
reorder_graph
=
ReorderGraph
(
self
.
index_tracer
)
self
.
memory_estimator
=
EstimateMemory
()
self
.
chunk_selector
=
SelectChunk
(
self
.
index_tracer
,
self
.
memory_estimator
,
self
.
reorder_graph
,
max_memory
=
max_memory
self
.
trace_index
=
TraceIndex
(
list
(
gm
.
graph
.
nodes
))
self
.
trace_index
.
trace_index
()
self
.
trace_flow
=
TraceFlow
(
self
.
trace_index
)
self
.
reorder_graph
=
ReorderGraph
(
self
.
trace_index
)
self
.
estimate_memory
=
EstimateMemory
()
self
.
select_chunk
=
SelectChunk
(
self
.
trace_index
,
self
.
estimate_memory
,
self
.
reorder_graph
,
max_memory
=
max_memory
)
def
_find_peak_node
(
self
,
mem_peak
):
...
...
@@ -29,7 +32,7 @@ class SearchChunk(object):
def
_get_free_var
(
self
):
free_var_idx
=
[]
for
idx
,
n
in
enumerate
(
self
.
index_tracer
.
node_list
):
for
idx
,
n
in
enumerate
(
self
.
trace_index
.
node_list
):
if
n
.
op
==
"placeholder"
:
free_var_idx
.
append
(
idx
)
return
free_var_idx
...
...
@@ -99,7 +102,7 @@ class SearchChunk(object):
def
_find_free_dim
(
self
,
input_trace
,
output_trace
,
start_idx
,
end_idx
):
start_traces
=
input_trace
[
start_idx
]
end_trace
=
output_trace
[
end_idx
]
end_node
=
self
.
index_tracer
.
node_list
[
end_idx
]
end_node
=
self
.
trace_index
.
node_list
[
end_idx
]
chunk_infos
=
[]
for
end_dim
,
_
in
enumerate
(
end_trace
[
"idx"
]):
if
len
(
start_traces
)
>
1
:
...
...
@@ -113,46 +116,46 @@ class SearchChunk(object):
):
continue
# check index source align
if
not
self
.
index_tracer
.
check_index_source
(
if
not
self
.
trace_flow
.
check_index_source
(
start_dim
,
start_node
,
start_idx
,
end_dim
,
end_node
):
continue
# check index copmute
if
not
self
.
index_tracer
.
check_index_compute
(
if
not
self
.
trace_flow
.
check_index_compute
(
start_idx
,
end_dim
,
end_node
,
end_idx
):
continue
# flow search
chunk_info
=
self
.
index_tracer
.
flow_search
(
chunk_info
=
self
.
trace_flow
.
flow_search
(
start_idx
,
start_dim
,
end_idx
,
end_dim
)
if
chunk_info
is
None
:
continue
# check index copmute
if
not
self
.
index_tracer
.
check_index_duplicate
(
chunk_info
):
if
not
self
.
trace_flow
.
check_index_duplicate
(
chunk_info
):
continue
chunk_infos
.
append
(
chunk_info
)
return
chunk_infos
def
_search_possible_chunk_regions
(
self
,
max_chunk_region
,
peak_node
):
possible_chunk_region
=
[]
output_trace
=
copy
.
deepcopy
(
self
.
index_tracer
.
idx_trace_list
)
output_trace
=
copy
.
deepcopy
(
self
.
trace_index
.
idx_trace_list
)
input_trace
=
[]
# trace of a node's input nodes
for
_
,
n
in
enumerate
(
self
.
index_tracer
.
node_list
):
for
_
,
n
in
enumerate
(
self
.
trace_index
.
node_list
):
cur_trace
=
{}
for
arg
in
n
.
args
:
if
type
(
arg
)
==
type
(
n
)
and
not
is_non_compute_node_except_placeholder
(
arg
):
cur_trace
[
arg
]
=
self
.
index_tracer
.
_find_trace_from_node
(
arg
)
cur_trace
[
arg
]
=
self
.
trace_index
.
_find_trace_from_node
(
arg
)
input_trace
.
append
(
cur_trace
)
for
start_idx
in
range
(
max_chunk_region
[
0
],
peak_node
+
1
):
for
end_idx
in
range
(
peak_node
,
max_chunk_region
[
1
]
+
1
):
# skip non compute nodes
if
is_non_compute_node
(
self
.
index_tracer
.
node_list
[
start_idx
]
)
or
is_non_compute_node
(
self
.
index_tracer
.
node_list
[
end_idx
]):
self
.
trace_index
.
node_list
[
start_idx
]
)
or
is_non_compute_node
(
self
.
trace_index
.
node_list
[
end_idx
]):
continue
# select free dim
...
...
@@ -173,7 +176,7 @@ class SearchChunk(object):
possible_chunk_regions
=
self
.
_search_possible_chunk_regions
(
max_chunk_region
,
peak_node
)
best_chunk_region
=
self
.
chunk_selector
.
_select_best_chunk_region
(
best_chunk_region
=
self
.
select_chunk
.
_select_best_chunk_region
(
possible_chunk_regions
,
chunk_regions
,
peak_node
,
max_chunk_region
,
mem_peak
)
best_chunk_region
=
self
.
reorder_graph
.
reorder_all
(
best_chunk_region
)
...
...
@@ -191,8 +194,8 @@ class SearchChunk(object):
init_mem_peak
,
_
,
active_node
,
)
=
self
.
memory_
estimator
.
estimate_chunk_inference_mem
(
self
.
index_tracer
.
node_list
)
=
self
.
estimat
e_mem
or
y
.
estimate_chunk_inference_mem
(
self
.
trace_index
.
node_list
)
mem_peak
=
init_mem_peak
...
...
@@ -206,14 +209,14 @@ class SearchChunk(object):
mem_peak
,
_
,
active_node
,
)
=
self
.
memory_
estimator
.
estimate_chunk_inference_mem
(
self
.
index_tracer
.
node_list
,
chunk_infos
)
=
self
.
estimat
e_mem
or
y
.
estimate_chunk_inference_mem
(
self
.
trace_index
.
node_list
,
chunk_infos
)
if
self
.
_stop_search
(
init_mem_peak
,
mem_peak
):
break
if
self
.
print_mem
:
self
.
print_mem
=
False
self
.
memory_
estimator
.
estimate_chunk_inference_mem
(
self
.
index_tracer
.
node_list
,
chunk_infos
,
print_mem
=
True
self
.
estimat
e_mem
or
y
.
estimate_chunk_inference_mem
(
self
.
trace_index
.
node_list
,
chunk_infos
,
print_mem
=
True
)
return
chunk_infos
colossalai/autochunk/select_chunk.py
View file @
a6cdbf91
from
.trace_index
import
TraceIndex
,
ReorderGraph
from
.trace_index
import
TraceIndex
from
.reorder_graph
import
ReorderGraph
from
.estiamte_memory
import
EstimateMemory
from
.utils
import
is_non_compute_node
...
...
colossalai/autochunk/trace_flow.py
0 → 100644
View file @
a6cdbf91
from
.trace_index
import
TraceIndex
from
.utils
import
(
find_chunk_all_input_nodes
,
find_chunk_compute_input_and_output_nodes
,
find_idx_by_name
,
get_node_shape
,
is_non_compute_node
,
is_non_compute_node_except_placeholder
,
)
class TraceFlow(object):
    """Chunk-flow analysis on top of a ``TraceIndex``.

    The methods validate a candidate chunk region ``[start_idx, end_idx]`` of
    ``trace_index.node_list`` and, via :meth:`flow_search`, assemble the
    ``chunk_info`` dict describing that region.
    """

    def __init__(self, trace_index: "TraceIndex") -> None:
        # Every lookup (node list, per-node source/compute traces, view log)
        # is delegated to this object.
        self.trace_index = trace_index

    def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node):
        """
        Check 2 given index: one index should be source of the other

        Args:
            start_dim(int): start node chunk dim
            start_node(node): start node
            start_idx(int): index of the first node of the chunk region
            end_dim(int): end node chunk dim
            end_node(node): end node

        Returns:
            bool: True if check pass
        """
        start_node_idx = find_idx_by_name(start_node.name, self.trace_index.node_list)
        end_node_trace = self.trace_index._find_trace_from_node(end_node)
        end_node_trace_source = end_node_trace["source"][end_dim]
        # Walk the recorded sources from the highest node index downwards.
        sorted_source = sorted(
            end_node_trace_source.items(), key=lambda d: d[0], reverse=True
        )
        for node_idx, node_dim in sorted_source:
            if node_idx == start_node_idx and start_dim in node_dim:
                return True
            # it means we meet a node outside the loop, and the node is not input node
            if node_idx < start_idx:
                return False
        return False

    def check_index_compute(self, start_idx, end_dim, end_node, end_idx):
        """
        Check 2 given index: check they haven't been computed in the source trace.

        Args:
            start_idx(int): index of the first node of the chunk region
            end_dim(int): end node chunk dim
            end_node(node): end node
            end_idx(int): index of the last node of the chunk region

        Returns:
            bool: True if check pass, i.e. no index recorded in the end node's
                "compute" trace for ``end_dim`` lies in ``[start_idx, end_idx]``
        """
        end_node_trace = self.trace_index._find_trace_from_node(end_node)
        end_node_compute = end_node_trace["compute"][end_dim]
        if any(start_idx <= i <= end_idx for i in end_node_compute):
            return False
        return True

    def get_node_chunk_dim(self, node_from, node_from_dim, node_to):
        """
        Look up what the source trace of ``node_from`` records for ``node_to``.

        Args:
            node_from(node): node whose source trace is read
            node_from_dim(int): dim of ``node_from`` to inspect
            node_to(node): node to look for among the sources of that dim

        Returns:
            The value stored for ``node_to``'s index in the source trace of
            ``node_from_dim``, or None if ``node_to`` is not a source.
        """
        node_from_source = self.trace_index._find_source_trace_from_node(node_from)
        dim_source = node_from_source[node_from_dim]
        node_to_idx = find_idx_by_name(node_to.name, self.trace_index.node_list)
        return dim_source.get(node_to_idx)

    def _find_inherit_dim(self, input_node, input_dim, node):
        """
        Find the first dim of ``node`` whose source trace contains
        ``input_dim[0]`` of ``input_node``.

        Returns:
            int or None: the matching dim of ``node``, None if there is none
        """
        input_node_idx = find_idx_by_name(input_node.name, self.trace_index.node_list)
        node_trace_source = self.trace_index._find_source_trace_from_node(node)
        for node_dim in range(len(get_node_shape(node))):
            if (
                input_node_idx in node_trace_source[node_dim]
                and input_dim[0] in node_trace_source[node_dim][input_node_idx]
            ):
                return node_dim
        return None

    def check_index_duplicate(self, chunk_infos, return_dim=False):
        """
        Check that no node in the region has more than one dim flagged as
        tracing back to a recorded input chunk dim.

        Args:
            chunk_infos(dict): chunk info holding "inputs", "inputs_dim" and "region"
            return_dim(bool): if True, return a ``(flag, duplicate_dims)`` tuple

        Returns:
            bool, or (bool, list | None) when ``return_dim`` is True.
            The flag is True if check pass.
        """
        input_dim_after_node = {}
        for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
            for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
                inherit_dim = self._find_inherit_dim(
                    input_node, v, self.trace_index.node_list[k]
                )
                # NOTE(review): this truthiness test also skips an inherited
                # dim of 0 (``_find_inherit_dim`` signals "not found" with
                # None). Kept as is because changing it alters which regions
                # are accepted -- confirm whether ``is not None`` is intended.
                if inherit_dim:
                    input_dim_after_node[k] = inherit_dim

        region_start = chunk_infos["region"][0]
        region_end = chunk_infos["region"][1]
        for node in self.trace_index.node_list[region_start : region_end + 1]:
            if is_non_compute_node_except_placeholder(node):
                continue
            count = 0
            duplicate_dims = []
            node_trace_source = self.trace_index._find_source_trace_from_node(node)
            for node_dim in range(len(get_node_shape(node))):
                duplicate_dim = []
                duplicate_flag = False
                dim_source = node_trace_source[node_dim]
                for k, v in dim_source.items():
                    if region_start <= k <= region_end:
                        if k in input_dim_after_node and input_dim_after_node[k] in v:
                            duplicate_flag = True
                            duplicate_dim.append((k, v))
                duplicate_dims.append(duplicate_dim)
                if duplicate_flag:
                    count += 1

            if count > 1:
                if return_dim:
                    return False, duplicate_dims
                return False
        if return_dim:
            return True, None
        return True

    def _assgin_single_node_flow(
        self,
        arg_node,
        start_idx,
        end_idx,
        cur_node_dim,
        cur_node_compute,
        cur_node_source,
        cur_node_fix_dim,
        all_node_info,
        next_node_list,
    ):
        """
        Derive chunk dim and fix dims of ``arg_node`` from the node using it.

        ``all_node_info`` and ``next_node_list`` are updated in place.

        Args:
            arg_node(node): argument of the current node
            start_idx(int): index of the first node of the chunk region
            end_idx(int): index of the last node of the chunk region
            cur_node_dim(int | None): chunk dim of the current node
            cur_node_compute: compute trace of the current node
                (only read when ``cur_node_dim`` is not None)
            cur_node_source: source trace of the current node
                (only read when ``cur_node_dim`` is not None)
            cur_node_fix_dim(list): fix dims of the current node
            all_node_info(dict): node -> {"chunk_dim", "fix_dim"}
            next_node_list(list): receives ``arg_node`` when it is newly recorded

        Returns:
            bool: False if the flow is illegal, True otherwise
        """
        arg_idx = find_idx_by_name(arg_node.name, self.trace_index.node_list)
        # arg in chunk range or be inputs
        if not (start_idx <= arg_idx < end_idx):
            return True

        # find arg dim
        if cur_node_dim is not None:
            # dim is computed
            if arg_idx in cur_node_compute[cur_node_dim]:
                return False
            if arg_idx not in cur_node_source[cur_node_dim]:
                arg_dim = None
            else:
                arg_dim = cur_node_source[cur_node_dim][arg_idx][0]
        else:
            arg_dim = None

        # get fix dim
        arg_fix_dim = []
        if cur_node_dim is not None:
            for i in cur_node_fix_dim:
                fix_dim_source = cur_node_source[i]
                if arg_idx in fix_dim_source:
                    arg_fix_dim.append(fix_dim_source[arg_idx][0])

        # if already in node_info, arg dim must be same
        if arg_node in all_node_info:
            if all_node_info[arg_node]["chunk_dim"] != arg_dim:
                return False
            all_node_info[arg_node]["fix_dim"] = list(
                set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim)
            )
        # else add it to list
        else:
            all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim}
            next_node_list.append(arg_node)
        return True

    def _get_all_node_info(self, end_dim, start_idx, end_idx):
        """
        Walk backwards from the end node and record chunk dim and fix dims
        of every node reached through ``args``.

        Args:
            end_dim(int): chunk dim of the end node
            start_idx(int): index of the first node of the chunk region
            end_idx(int): index of the last node of the chunk region

        Returns:
            dict or None: node -> {"chunk_dim", "fix_dim"}, None if the flow
                is illegal

        Raises:
            NotImplementedError: a node with exactly two valid args whose name
                contains none of "add", "mul", "einsum", "matmul"
        """
        cur_node_list = [self.trace_index.node_list[end_idx]]  # start from the last node
        all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}

        while len(cur_node_list) > 0:
            next_node_list = []

            for cur_node in cur_node_list:
                # get cur node info
                cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
                cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
                # Compare against None explicitly: chunk dim 0 is a valid dim
                # and _assgin_single_node_flow subscripts both traces whenever
                # the dim is not None, so they must be loaded for dim 0 too.
                if cur_node_chunk_dim is not None:
                    cur_node_compute = self.trace_index._find_compute_trace_from_node(
                        cur_node
                    )
                    cur_node_source = self.trace_index._find_source_trace_from_node(
                        cur_node
                    )
                else:
                    cur_node_compute = cur_node_source = None

                # get all valid args
                arg_list = []
                for arg in cur_node.args:
                    if type(arg) != type(cur_node):
                        continue
                    if is_non_compute_node(arg):
                        continue
                    arg_list.append(arg)
                    flow_flag = self._assgin_single_node_flow(
                        arg,
                        start_idx,
                        end_idx,
                        cur_node_chunk_dim,
                        cur_node_compute,
                        cur_node_source,
                        cur_node_fix_dim,
                        all_node_info,
                        next_node_list,
                    )
                    if not flow_flag:
                        return None

                if len(arg_list) == 2:
                    if any(i in cur_node.name for i in ["add", "mul"]):
                        for arg in arg_list:
                            if not (
                                start_idx
                                <= find_idx_by_name(arg.name, self.trace_index.node_list)
                                < end_idx
                            ):
                                continue
                            arg_chunk_dim = all_node_info[arg]["chunk_dim"]
                            arg_fix_dim = all_node_info[arg]["fix_dim"]
                            arg_shape = get_node_shape(arg)
                            # add all dim as fix dim except chunk dim
                            for i, shape in enumerate(arg_shape):
                                if shape != 1 and i != cur_node_chunk_dim:
                                    if i == arg_chunk_dim:
                                        return None
                                    if i not in arg_fix_dim:
                                        arg_fix_dim.append(i)
                    elif "einsum" in cur_node.name:
                        pass
                    elif "matmul" in cur_node.name:
                        pass
                    else:
                        raise NotImplementedError()
            cur_node_list = next_node_list
        return all_node_info

    def _get_input_nodes_dim(self, inputs, start_idx, end_idx, all_node_info):
        """
        Collect, for every input node, the source entry its in-region users
        record for it on their chunk dim.

        Inputs for which nothing was collected are removed from ``inputs``
        in place.

        Args:
            inputs(list): input nodes of the chunk region
            start_idx(int): index of the first node of the chunk region
            end_idx(int): index of the last node of the chunk region
            all_node_info(dict): node -> {"chunk_dim", "fix_dim"}

        Returns:
            (list, list) or (None, None): remaining inputs and one
                ``{user_idx: source entry}`` dict per remaining input;
                ``(None, None)`` if a chunked user does not list the input
                as a source of its chunk dim
        """
        inputs_dim = []
        remove_inputs = []
        for input_node in inputs:
            input_dict = {}
            input_node_idx = find_idx_by_name(
                input_node.name, self.trace_index.node_list
            )
            for user in input_node.users.keys():
                if is_non_compute_node(user):
                    continue
                user_idx = find_idx_by_name(user.name, self.trace_index.node_list)
                if start_idx <= user_idx <= end_idx:
                    chunk_dim = all_node_info[user]["chunk_dim"]
                    if chunk_dim is not None:
                        user_source = self.trace_index._find_source_trace_from_node(
                            user
                        )[chunk_dim]
                        if input_node_idx in user_source:
                            input_dict[user_idx] = user_source[input_node_idx]
                        else:
                            return None, None
            if not input_dict:
                remove_inputs.append(input_node)
            else:
                inputs_dim.append(input_dict)
        for i in remove_inputs:
            if i in inputs:
                inputs.remove(i)
        return inputs, inputs_dim

    def _get_prepose_nodes(self, all_node_info, start_idx, end_idx):
        """
        Find the nodes that can be moved ahead of the chunk loop.

        Candidates are the nodes whose chunk dim is None. A candidate
        qualifies only if no in-region argument reached from it has a
        chunk dim.

        Returns:
            list: prepose nodes sorted by their index in the node list
        """
        # get all possible prepose nodes
        maybe_prepose_nodes = []
        for node, node_info in all_node_info.items():
            if node_info["chunk_dim"] is None:
                maybe_prepose_nodes.append(node)
        maybe_prepose_nodes.sort(
            key=lambda x: find_idx_by_name(x.name, self.trace_index.node_list),
            reverse=True,
        )  # from last node to first node
        prepose_nodes = []
        # set every node as root, search its args, if all legal, turn root and args as prepose nodes
        while len(maybe_prepose_nodes) > 0:
            tmp_cur_prepose_nodes = [maybe_prepose_nodes[0]]
            tmp_cur_related_prepose_nodes = []
            prepose_flag = True

            # loop cur node's all arg until out of chunk
            while len(tmp_cur_prepose_nodes) > 0:
                if not prepose_flag:
                    break
                tmp_next_prepose_nodes = []
                tmp_cur_related_prepose_nodes.extend(tmp_cur_prepose_nodes)
                for cur_prepose_node in tmp_cur_prepose_nodes:
                    if not prepose_flag:
                        break
                    for cur_prepose_node_arg in cur_prepose_node.args:
                        if type(cur_prepose_node_arg) != type(cur_prepose_node):
                            continue
                        # out of loop
                        if not (
                            start_idx
                            <= find_idx_by_name(
                                cur_prepose_node_arg.name, self.trace_index.node_list
                            )
                            < end_idx
                        ):
                            continue
                        # compute op in loop
                        elif cur_prepose_node_arg in all_node_info:
                            if all_node_info[cur_prepose_node_arg]["chunk_dim"] is None:
                                tmp_next_prepose_nodes.append(cur_prepose_node_arg)
                            else:
                                prepose_flag = False
                                break
                        # non compute op
                        else:
                            tmp_next_prepose_nodes.append(cur_prepose_node_arg)
                tmp_cur_prepose_nodes = tmp_next_prepose_nodes

            if not prepose_flag:
                # the root cannot be preposed, drop it and try the next one
                maybe_prepose_nodes.pop(0)
                continue
            for n in tmp_cur_related_prepose_nodes:
                if n not in prepose_nodes:
                    prepose_nodes.append(n)
                if n in maybe_prepose_nodes:
                    maybe_prepose_nodes.remove(n)
        # sort by index
        prepose_nodes.sort(
            key=lambda x: find_idx_by_name(x.name, self.trace_index.node_list)
        )

        return prepose_nodes

    def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
        """
        Append to ``chunk_info["inputs_non_chunk"]`` every input of the
        region (prepose nodes excluded) that is not already in
        ``chunk_info["inputs"]``.

        Returns:
            dict: the same ``chunk_info`` object, updated in place
        """
        # we need to log input nodes to avoid deleting them in the loop
        chunk_node_list = self.trace_index.node_list[start_idx : end_idx + 1]
        # also need to get some prepose node's arg out of non_chunk_inputs
        for n in chunk_info["args"]["prepose_nodes"]:
            chunk_node_list.remove(n)
        non_chunk_inputs = find_chunk_all_input_nodes(chunk_node_list)
        for i in non_chunk_inputs:
            if i not in chunk_info["inputs"]:
                chunk_info["inputs_non_chunk"].append(i)
        return chunk_info

    def flow_search(self, start_idx, start_dim, end_idx, end_dim):
        """
        Build the chunk info of the region ``[start_idx, end_idx]``.

        Args:
            start_idx(int): index of the first node of the chunk region
            start_dim(int): start node chunk dim (not read by this method)
            end_idx(int): index of the last node of the chunk region
            end_dim(int): end node chunk dim

        Returns:
            dict or None: chunk info with keys "region", "inputs",
                "inputs_non_chunk", "inputs_dim", "outputs", "outputs_dim",
                "node_chunk_dim", "args" and "reshape_size"; None if the
                region has more than one output or the flow is illegal
        """
        inputs, outputs = find_chunk_compute_input_and_output_nodes(
            self.trace_index.node_list[start_idx : end_idx + 1]
        )
        # only single output
        if len(outputs) > 1:
            return None

        # get every node's chunk dim and fix dim
        all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
        if all_node_info is None:
            return None

        # get input nodes' chunk dim
        inputs, inputs_dim = self._get_input_nodes_dim(
            inputs, start_idx, end_idx, all_node_info
        )
        if inputs is None:
            return None

        chunk_info = {
            "region": (start_idx, end_idx),
            "inputs": inputs,
            "inputs_non_chunk": [],
            "inputs_dim": inputs_dim,
            "outputs": outputs,
            "outputs_dim": end_dim,
            "node_chunk_dim": all_node_info,
            "args": {},
        }

        # move useless nodes ahead of loop
        chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes(
            all_node_info, start_idx, end_idx
        )

        # find non chunk inputs
        chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx)

        # reassign reshape size, some size may have changed due to chunk
        chunk_info = self._reassgin_reshape_size(chunk_info)

        return chunk_info

    def _reassgin_reshape_size(self, chunk_info):
        """
        Record, for each node in the region whose name contains "reshape" or
        "view", a replacement size expression for the argument on its chunk dim.

        The result is stored in ``chunk_info["reshape_size"]`` as
        ``{node name: {arg name: expression string}}``.

        Returns:
            dict: the same ``chunk_info`` object, updated in place
        """
        chunk_region = chunk_info["region"]
        reshape_size = {}
        chunk_shape = get_node_shape(chunk_info["outputs"][0])[
            chunk_info["outputs_dim"]
        ]
        for node in self.trace_index.node_list[chunk_region[0] : chunk_region[1] + 1]:
            if any(i in node.name for i in ["reshape", "view"]):
                # the first arg is skipped; the remaining ones are enumerated as dims
                reshape_args = node.args[1:]
                reshape_log = self.trace_index.idx_view_list[node]
                chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
                reshape_size[node.name] = {}
                for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
                    if reshape_arg_dim in reshape_log["dim_to"]:
                        continue
                    if reshape_arg_dim == chunk_dim:
                        reshape_size[node.name][reshape_arg.name] = (
                            "min(chunk_size, %d - chunk_idx)" % chunk_shape
                        )
        chunk_info["reshape_size"] = reshape_size
        return chunk_info
colossalai/autochunk/trace_index.py
View file @
a6cdbf91
import
copy
from
.utils
import
(
find_chunk_all_input_nodes
,
find_chunk_compute_input_and_output_nodes
,
find_idx_by_name
,
get_node_shape
,
is_non_compute_node
,
is_non_compute_node_except_placeholder
,
)
...
...
@@ -588,394 +584,3 @@ class TraceIndex(object):
continue
else
:
raise
NotImplementedError
(
node
.
op
,
"op not implemented yet!"
)
# self._merge_equal_idx()
def
check_index_source
(
self
,
start_dim
,
start_node
,
start_idx
,
end_dim
,
end_node
):
"""
Check 2 given index: one index should be source of the other
Args:
start_idx(int): start node chunk dim
start_node(node): start node
end_idx(int): end node chunk dim
end_node(node): end node
Returns:
bool: True if check pass
"""
start_node_idx
=
find_idx_by_name
(
start_node
.
name
,
self
.
node_list
)
end_node_trace
=
self
.
_find_trace_from_node
(
end_node
)
end_node_trace_source
=
end_node_trace
[
"source"
][
end_dim
]
sorted_source
=
sorted
(
end_node_trace_source
.
items
(),
key
=
lambda
d
:
d
[
0
],
reverse
=
True
)
for
node_idx
,
node_dim
in
sorted_source
:
if
node_idx
==
start_node_idx
and
start_dim
in
node_dim
:
return
True
# it means we meet a node outside the loop, and the node is not input node
if
node_idx
<
start_idx
:
return
False
return
False
def
check_index_compute
(
self
,
start_idx
,
end_dim
,
end_node
,
end_idx
):
"""
Check 2 given index: check they haven't been computed in the source trace.
Args:
start_idx(int): start node chunk dim
start_node(node): start node
end_idx(int): end node chunk dim
end_node(node): end node
Returns:
bool: True if check pass
"""
end_node_trace
=
self
.
_find_trace_from_node
(
end_node
)
end_node_compute
=
end_node_trace
[
"compute"
][
end_dim
]
if
any
(
start_idx
<=
i
<=
end_idx
for
i
in
end_node_compute
):
return
False
return
True
def
get_node_chunk_dim
(
self
,
node_from
,
node_from_dim
,
node_to
):
node_from_source
=
self
.
_find_source_trace_from_node
(
node_from
)
dim_source
=
node_from_source
[
node_from_dim
]
node_to_idx
=
find_idx_by_name
(
node_to
.
name
,
self
.
node_list
)
for
k
,
v
in
dim_source
.
items
():
if
k
==
node_to_idx
:
return
v
return
None
def
_find_inherit_dim
(
self
,
input_node
,
input_dim
,
node
):
input_node_idx
=
find_idx_by_name
(
input_node
.
name
,
self
.
node_list
)
node_trace_source
=
self
.
_find_source_trace_from_node
(
node
)
for
node_dim
in
range
(
len
(
get_node_shape
(
node
))):
if
(
input_node_idx
in
node_trace_source
[
node_dim
]
and
input_dim
[
0
]
in
node_trace_source
[
node_dim
][
input_node_idx
]
):
return
node_dim
return
None
def
check_index_duplicate
(
self
,
chunk_infos
,
return_dim
=
False
):
input_dim_after_node
=
{}
for
input_node_idx
,
input_node
in
enumerate
(
chunk_infos
[
"inputs"
]):
for
k
,
v
in
chunk_infos
[
"inputs_dim"
][
input_node_idx
].
items
():
inherit_dim
=
self
.
_find_inherit_dim
(
input_node
,
v
,
self
.
node_list
[
k
])
if
inherit_dim
:
input_dim_after_node
[
k
]
=
inherit_dim
for
node
in
self
.
node_list
[
chunk_infos
[
"region"
][
0
]
:
chunk_infos
[
"region"
][
1
]
+
1
]:
if
is_non_compute_node_except_placeholder
(
node
):
continue
count
=
0
duplicate_dims
=
[]
node_trace_source
=
self
.
_find_source_trace_from_node
(
node
)
for
node_dim
in
range
(
len
(
get_node_shape
(
node
))):
duplicate_dim
=
[]
duplicate_flag
=
False
dim_source
=
node_trace_source
[
node_dim
]
for
k
,
v
in
dim_source
.
items
():
if
chunk_infos
[
"region"
][
0
]
<=
k
<=
chunk_infos
[
"region"
][
1
]:
if
k
in
input_dim_after_node
and
input_dim_after_node
[
k
]
in
v
:
duplicate_flag
=
True
duplicate_dim
.
append
((
k
,
v
))
duplicate_dims
.
append
(
duplicate_dim
)
if
duplicate_flag
:
count
+=
1
if
count
>
1
:
if
return_dim
:
return
False
,
duplicate_dims
else
:
return
False
if
return_dim
:
return
True
,
None
else
:
return
True
def
_assgin_single_node_flow
(
self
,
arg_node
,
start_idx
,
end_idx
,
cur_node_dim
,
cur_node_compute
,
cur_node_source
,
cur_node_fix_dim
,
all_node_info
,
next_node_list
,
):
arg_idx
=
find_idx_by_name
(
arg_node
.
name
,
self
.
node_list
)
# arg in chunk range or be inputs
if
not
(
start_idx
<=
arg_idx
<
end_idx
):
return
True
# find arg dim
if
cur_node_dim
is
not
None
:
# dim is computed
if
arg_idx
in
cur_node_compute
[
cur_node_dim
]:
return
False
if
arg_idx
not
in
cur_node_source
[
cur_node_dim
]:
arg_dim
=
None
else
:
arg_dim
=
cur_node_source
[
cur_node_dim
][
arg_idx
][
0
]
else
:
arg_dim
=
None
# get fix dim
arg_fix_dim
=
[]
if
cur_node_dim
is
not
None
:
for
i
in
cur_node_fix_dim
:
fix_dim_source
=
cur_node_source
[
i
]
if
arg_idx
in
fix_dim_source
:
arg_fix_dim
.
append
(
fix_dim_source
[
arg_idx
][
0
])
# if already in node_info, arg dim must be same
if
arg_node
in
all_node_info
:
if
all_node_info
[
arg_node
][
"chunk_dim"
]
!=
arg_dim
:
return
False
all_node_info
[
arg_node
][
"fix_dim"
]
=
list
(
set
(
all_node_info
[
arg_node
][
"fix_dim"
]
+
arg_fix_dim
)
)
# else add it to list
else
:
all_node_info
[
arg_node
]
=
{
"chunk_dim"
:
arg_dim
,
"fix_dim"
:
arg_fix_dim
}
next_node_list
.
append
(
arg_node
)
return
True
def
_get_all_node_info
(
self
,
end_dim
,
start_idx
,
end_idx
):
cur_node_list
=
[
self
.
node_list
[
end_idx
]]
# start from the last node
all_node_info
=
{
cur_node_list
[
0
]:
{
"chunk_dim"
:
end_dim
,
"fix_dim"
:
[]}}
while
len
(
cur_node_list
)
>
0
:
next_node_list
=
[]
for
cur_node
in
cur_node_list
:
# get cur node info
cur_node_chunk_dim
=
all_node_info
[
cur_node
][
"chunk_dim"
]
cur_node_fix_dim
=
all_node_info
[
cur_node
][
"fix_dim"
]
if
cur_node_chunk_dim
:
cur_node_compute
=
self
.
_find_compute_trace_from_node
(
cur_node
)
cur_node_source
=
self
.
_find_source_trace_from_node
(
cur_node
)
else
:
cur_node_compute
=
cur_node_source
=
None
# get all valid args
arg_list
=
[]
for
arg
in
cur_node
.
args
:
if
type
(
arg
)
!=
type
(
cur_node
):
continue
if
is_non_compute_node
(
arg
):
continue
arg_list
.
append
(
arg
)
flow_flag
=
self
.
_assgin_single_node_flow
(
arg
,
start_idx
,
end_idx
,
cur_node_chunk_dim
,
cur_node_compute
,
cur_node_source
,
cur_node_fix_dim
,
all_node_info
,
next_node_list
,
)
if
flow_flag
==
False
:
return
None
if
len
(
arg_list
)
==
2
:
if
any
(
i
in
cur_node
.
name
for
i
in
[
"add"
,
"mul"
]):
for
arg
in
arg_list
:
if
not
(
start_idx
<=
find_idx_by_name
(
arg
.
name
,
self
.
node_list
)
<
end_idx
):
continue
arg_chunk_dim
=
all_node_info
[
arg
][
"chunk_dim"
]
arg_fix_dim
=
all_node_info
[
arg
][
"fix_dim"
]
arg_shape
=
get_node_shape
(
arg
)
# add all dim as fix dim except chunk dim
for
i
,
shape
in
enumerate
(
arg_shape
):
if
shape
!=
1
and
i
!=
cur_node_chunk_dim
:
if
i
==
arg_chunk_dim
:
return
None
if
i
not
in
arg_fix_dim
:
arg_fix_dim
.
append
(
i
)
elif
"einsum"
in
cur_node
.
name
:
pass
elif
"matmul"
in
cur_node
.
name
:
pass
else
:
raise
NotImplementedError
()
cur_node_list
=
next_node_list
return
all_node_info
def
_get_input_nodes_dim
(
self
,
inputs
,
start_idx
,
end_idx
,
all_node_info
):
inputs_dim
=
[]
remove_inputs
=
[]
for
input_node
in
inputs
:
input_dict
=
{}
input_node_idx
=
find_idx_by_name
(
input_node
.
name
,
self
.
node_list
)
for
user
in
input_node
.
users
.
keys
():
if
is_non_compute_node
(
user
):
continue
user_idx
=
find_idx_by_name
(
user
.
name
,
self
.
node_list
)
if
start_idx
<=
user_idx
<=
end_idx
:
chunk_dim
=
all_node_info
[
user
][
"chunk_dim"
]
if
chunk_dim
is
not
None
:
user_source
=
self
.
_find_source_trace_from_node
(
user
)[
chunk_dim
]
if
input_node_idx
in
user_source
:
input_dict
[
user_idx
]
=
user_source
[
input_node_idx
]
else
:
return
None
,
None
if
len
(
input_dict
)
==
0
:
remove_inputs
.
append
(
input_node
)
else
:
inputs_dim
.
append
(
input_dict
)
for
i
in
remove_inputs
:
if
i
in
inputs
:
inputs
.
remove
(
i
)
return
inputs
,
inputs_dim
def
_get_prepose_nodes
(
self
,
all_node_info
,
start_idx
,
end_idx
):
# get all possible prepose nodes
maybe_prepose_nodes
=
[]
for
node
,
node_info
in
all_node_info
.
items
():
if
node_info
[
"chunk_dim"
]
is
None
:
maybe_prepose_nodes
.
append
(
node
)
maybe_prepose_nodes
.
sort
(
key
=
lambda
x
:
find_idx_by_name
(
x
.
name
,
self
.
node_list
),
reverse
=
True
,
)
# from last node to first node
prepose_nodes
=
[]
# set every node as root, search its args, if all legal, turn root and args as prepose nodes
while
len
(
maybe_prepose_nodes
)
>
0
:
tmp_cur_prepose_nodes
=
[
maybe_prepose_nodes
[
0
]]
tmp_cur_related_prepose_nodes
=
[]
prepose_flag
=
True
# loop cur node's all arg until out of chunk
while
len
(
tmp_cur_prepose_nodes
)
>
0
:
if
prepose_flag
==
False
:
break
tmp_next_prepose_nodes
=
[]
tmp_cur_related_prepose_nodes
.
extend
(
tmp_cur_prepose_nodes
)
for
cur_prepose_node
in
tmp_cur_prepose_nodes
:
if
prepose_flag
==
False
:
break
for
cur_prepose_node_arg
in
cur_prepose_node
.
args
:
if
type
(
cur_prepose_node_arg
)
!=
type
(
cur_prepose_node
):
continue
# out of loop
if
not
(
start_idx
<=
find_idx_by_name
(
cur_prepose_node_arg
.
name
,
self
.
node_list
)
<
end_idx
):
continue
# compute op in loop
elif
cur_prepose_node_arg
in
all_node_info
:
if
all_node_info
[
cur_prepose_node_arg
][
"chunk_dim"
]
is
None
:
tmp_next_prepose_nodes
.
append
(
cur_prepose_node_arg
)
else
:
prepose_flag
=
False
break
# non compute op
else
:
tmp_next_prepose_nodes
.
append
(
cur_prepose_node_arg
)
tmp_cur_prepose_nodes
=
tmp_next_prepose_nodes
if
prepose_flag
==
False
:
maybe_prepose_nodes
.
remove
(
maybe_prepose_nodes
[
0
])
continue
else
:
for
n
in
tmp_cur_related_prepose_nodes
:
if
n
not
in
prepose_nodes
:
prepose_nodes
.
append
(
n
)
if
n
in
maybe_prepose_nodes
:
maybe_prepose_nodes
.
remove
(
n
)
# sort by index
prepose_nodes
.
sort
(
key
=
lambda
x
:
find_idx_by_name
(
x
.
name
,
self
.
node_list
))
return
prepose_nodes
def
_get_non_chunk_inputs
(
self
,
chunk_info
,
start_idx
,
end_idx
):
# we need to log input nodes to avoid deleteing them in the loop
chunk_node_list
=
self
.
node_list
[
start_idx
:
end_idx
+
1
]
# also need to get some prepose node's arg out of non_chunk_inputs
for
n
in
chunk_info
[
"args"
][
"prepose_nodes"
]:
chunk_node_list
.
remove
(
n
)
non_chunk_inputs
=
find_chunk_all_input_nodes
(
chunk_node_list
)
for
i
in
non_chunk_inputs
:
if
i
not
in
chunk_info
[
"inputs"
]:
chunk_info
[
"inputs_non_chunk"
].
append
(
i
)
return
chunk_info
def
flow_search
(
self
,
start_idx
,
start_dim
,
end_idx
,
end_dim
):
inputs
,
outputs
=
find_chunk_compute_input_and_output_nodes
(
self
.
node_list
[
start_idx
:
end_idx
+
1
]
)
# only single ouput
if
len
(
outputs
)
>
1
:
return
None
# get every node's chunk dim and fix dim
all_node_info
=
self
.
_get_all_node_info
(
end_dim
,
start_idx
,
end_idx
)
if
all_node_info
is
None
:
return
None
# get input nodes' chunk dim
inputs
,
inputs_dim
=
self
.
_get_input_nodes_dim
(
inputs
,
start_idx
,
end_idx
,
all_node_info
)
if
inputs
is
None
:
return
None
chunk_info
=
{
"region"
:
(
start_idx
,
end_idx
),
"inputs"
:
inputs
,
"inputs_non_chunk"
:
[],
"inputs_dim"
:
inputs_dim
,
"outputs"
:
outputs
,
"outputs_dim"
:
end_dim
,
"node_chunk_dim"
:
all_node_info
,
"args"
:
{},
}
# move useless nodes ahead of loop
chunk_info
[
"args"
][
"prepose_nodes"
]
=
self
.
_get_prepose_nodes
(
all_node_info
,
start_idx
,
end_idx
)
# find non chunk inputs
chunk_info
=
self
.
_get_non_chunk_inputs
(
chunk_info
,
start_idx
,
end_idx
)
# reassgin reshape size, some size may have changed due to chunk
chunk_info
=
self
.
_reassgin_reshape_size
(
chunk_info
)
return
chunk_info
def
_reassgin_reshape_size
(
self
,
chunk_info
):
chunk_region
=
chunk_info
[
"region"
]
reshape_size
=
{}
chunk_shape
=
get_node_shape
(
chunk_info
[
"outputs"
][
0
])[
chunk_info
[
"outputs_dim"
]
]
for
node
in
self
.
node_list
[
chunk_region
[
0
]
:
chunk_region
[
1
]
+
1
]:
if
any
(
i
in
node
.
name
for
i
in
[
"reshape"
,
"view"
]):
reshape_args
=
node
.
args
[
1
:]
reshape_log
=
self
.
idx_view_list
[
node
]
chunk_dim
=
chunk_info
[
"node_chunk_dim"
][
node
][
"chunk_dim"
]
reshape_size
[
node
.
name
]
=
{}
for
reshape_arg_dim
,
reshape_arg
in
enumerate
(
reshape_args
):
if
reshape_arg_dim
in
reshape_log
[
"dim_to"
]:
continue
if
reshape_arg_dim
==
chunk_dim
:
reshape_size
[
node
.
name
][
reshape_arg
.
name
]
=
(
"min(chunk_size, %d - chunk_idx)"
%
chunk_shape
)
chunk_info
[
"reshape_size"
]
=
reshape_size
return
chunk_info
tests/test_autochunk/benchmark_autochunk.py
View file @
a6cdbf91
...
...
@@ -104,8 +104,8 @@ def benchmark_evoformer():
model
=
evoformer_base
().
cuda
()
# build autochunk model
#
max_memory = 1000 # MB fit memory mode
max_memory
=
None
# min memory mode
max_memory
=
1000
# MB fit memory mode
#
max_memory = None # min memory mode
autochunk
=
_build_autochunk
(
evoformer_base
().
cuda
(),
max_memory
,
node
,
pair
)
# build openfold
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment