Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
0ea903b9
Commit
0ea903b9
authored
Jan 09, 2023
by
oahzxl
Browse files
rename trace_index to trace_indice
parent
065f0b4c
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
74 additions
and
74 deletions
+74
-74
colossalai/autochunk/autochunk_codegen.py
colossalai/autochunk/autochunk_codegen.py
+2
-2
colossalai/autochunk/reorder_graph.py
colossalai/autochunk/reorder_graph.py
+16
-16
colossalai/autochunk/search_chunk.py
colossalai/autochunk/search_chunk.py
+16
-16
colossalai/autochunk/select_chunk.py
colossalai/autochunk/select_chunk.py
+11
-11
colossalai/autochunk/trace_flow.py
colossalai/autochunk/trace_flow.py
+28
-28
colossalai/autochunk/trace_indice.py
colossalai/autochunk/trace_indice.py
+1
-1
No files found.
colossalai/autochunk/autochunk_codegen.py
View file @
0ea903b9
...
...
@@ -94,9 +94,9 @@ def _replace_reshape_size(context, node_name, reshape_size_dict):
return
context
def
_replace_ones_like
(
search_chunk
,
chunk_infos
,
region_idx
,
node_idx
,
node
,
body
):
def
_replace_ones_like
(
search_chunk
:
SearchChunk
,
chunk_infos
,
region_idx
,
node_idx
,
node
,
body
):
if
"ones_like"
in
node
.
name
:
meta_node
=
search_chunk
.
trace_inde
x
.
node_list
[
node_idx
]
meta_node
=
search_chunk
.
trace_ind
ic
e
.
node_list
[
node_idx
]
chunk_dim
=
chunk_infos
[
region_idx
][
"node_chunk_dim"
][
meta_node
][
"chunk_dim"
]
if
get_node_shape
(
meta_node
)[
chunk_dim
]
!=
1
:
source_node
=
meta_node
.
args
[
0
].
args
[
0
]
...
...
colossalai/autochunk/reorder_graph.py
View file @
0ea903b9
from
.trace_inde
x
import
TraceInde
x
from
.trace_ind
ic
e
import
TraceInd
ic
e
from
.utils
import
find_idx_by_name
class
ReorderGraph
(
object
):
def
__init__
(
self
,
trace_inde
x
:
TraceInde
x
)
->
None
:
self
.
trace_inde
x
=
trace_inde
x
def
__init__
(
self
,
trace_ind
ic
e
:
TraceInd
ic
e
)
->
None
:
self
.
trace_ind
ic
e
=
trace_ind
ic
e
self
.
all_reorder_map
=
{
i
:
i
for
i
in
range
(
len
(
self
.
trace_inde
x
.
idx_trace_list
))
i
:
i
for
i
in
range
(
len
(
self
.
trace_ind
ic
e
.
idx_trace_list
))
}
def
_get_reorder_map
(
self
,
chunk_info
):
reorder_map
=
{
i
:
i
for
i
in
range
(
len
(
self
.
trace_inde
x
.
node_list
))}
reorder_map
=
{
i
:
i
for
i
in
range
(
len
(
self
.
trace_ind
ic
e
.
node_list
))}
chunk_region_start
=
chunk_info
[
"region"
][
0
]
chunk_region_end
=
chunk_info
[
"region"
][
1
]
chunk_prepose_nodes
=
chunk_info
[
"args"
][
"prepose_nodes"
]
chunk_prepose_nodes_idx
=
[
find_idx_by_name
(
i
.
name
,
self
.
trace_inde
x
.
node_list
)
find_idx_by_name
(
i
.
name
,
self
.
trace_ind
ic
e
.
node_list
)
for
i
in
chunk_prepose_nodes
]
# put prepose nodes ahead
...
...
@@ -24,10 +24,10 @@ class ReorderGraph(object):
n_idx
=
chunk_prepose_nodes_idx
[
idx
]
reorder_map
[
n_idx
]
=
chunk_region_start
+
idx
# put other nodes after prepose nodes
for
n
in
self
.
trace_inde
x
.
node_list
[
chunk_region_start
:
chunk_region_end
+
1
]:
for
n
in
self
.
trace_ind
ic
e
.
node_list
[
chunk_region_start
:
chunk_region_end
+
1
]:
if
n
in
chunk_prepose_nodes
:
continue
n_idx
=
find_idx_by_name
(
n
.
name
,
self
.
trace_inde
x
.
node_list
)
n_idx
=
find_idx_by_name
(
n
.
name
,
self
.
trace_ind
ic
e
.
node_list
)
pos
=
sum
([
n_idx
<
i
for
i
in
chunk_prepose_nodes_idx
])
reorder_map
[
n_idx
]
=
n_idx
+
pos
...
...
@@ -53,25 +53,25 @@ class ReorderGraph(object):
self
.
all_reorder_map
[
origin_idx
]
=
reorder_map
[
map_idx
]
def
_reorder_self_node_list
(
self
,
reorder_map
):
new_node_list
=
[
None
for
_
in
range
(
len
(
self
.
trace_inde
x
.
node_list
))]
new_node_list
=
[
None
for
_
in
range
(
len
(
self
.
trace_ind
ic
e
.
node_list
))]
for
old_idx
,
new_idx
in
reorder_map
.
items
():
new_node_list
[
new_idx
]
=
self
.
trace_inde
x
.
node_list
[
old_idx
]
self
.
trace_inde
x
.
node_list
=
new_node_list
new_node_list
[
new_idx
]
=
self
.
trace_ind
ic
e
.
node_list
[
old_idx
]
self
.
trace_ind
ic
e
.
node_list
=
new_node_list
def
_reorder_idx_trace
(
self
,
reorder_map
):
# reorder list
new_idx_trace_list
=
[
None
for
_
in
range
(
len
(
self
.
trace_inde
x
.
idx_trace_list
))]
new_idx_trace_list
=
[
None
for
_
in
range
(
len
(
self
.
trace_ind
ic
e
.
idx_trace_list
))]
for
old_idx
,
new_idx
in
reorder_map
.
items
():
new_idx_trace_list
[
new_idx
]
=
self
.
trace_inde
x
.
idx_trace_list
[
old_idx
]
self
.
trace_inde
x
.
idx_trace_list
=
new_idx_trace_list
new_idx_trace_list
[
new_idx
]
=
self
.
trace_ind
ic
e
.
idx_trace_list
[
old_idx
]
self
.
trace_ind
ic
e
.
idx_trace_list
=
new_idx_trace_list
# update compute
for
idx_trace
in
self
.
trace_inde
x
.
idx_trace_list
:
for
idx_trace
in
self
.
trace_ind
ic
e
.
idx_trace_list
:
compute
=
idx_trace
[
"compute"
]
for
dim_compute
in
compute
:
for
idx
,
i
in
enumerate
(
dim_compute
):
dim_compute
[
idx
]
=
reorder_map
[
i
]
# update source
for
idx_trace
in
self
.
trace_inde
x
.
idx_trace_list
:
for
idx_trace
in
self
.
trace_ind
ic
e
.
idx_trace_list
:
source
=
idx_trace
[
"source"
]
for
dim_idx
,
dim_source
in
enumerate
(
source
):
new_dim_source
=
{}
...
...
colossalai/autochunk/search_chunk.py
View file @
0ea903b9
...
...
@@ -7,7 +7,7 @@ from .estimate_memory import EstimateMemory
from
.reorder_graph
import
ReorderGraph
from
.select_chunk
import
SelectChunk
from
.trace_flow
import
TraceFlow
from
.trace_inde
x
import
TraceInde
x
from
.trace_ind
ic
e
import
TraceInd
ic
e
from
.utils
import
(
get_node_shape
,
is_non_compute_node
,
...
...
@@ -47,13 +47,13 @@ class SearchChunk(object):
def
__init__
(
self
,
gm
,
max_memory
=
None
,
print_mem
=
False
)
->
None
:
self
.
gm
=
gm
self
.
print_mem
=
print_mem
self
.
trace_inde
x
=
TraceInde
x
(
list
(
gm
.
graph
.
nodes
))
self
.
trace_inde
x
.
trace_index
()
self
.
trace_flow
=
TraceFlow
(
self
.
trace_inde
x
)
self
.
reorder_graph
=
ReorderGraph
(
self
.
trace_inde
x
)
self
.
trace_ind
ic
e
=
TraceInd
ic
e
(
list
(
gm
.
graph
.
nodes
))
self
.
trace_ind
ic
e
.
trace_index
()
self
.
trace_flow
=
TraceFlow
(
self
.
trace_ind
ic
e
)
self
.
reorder_graph
=
ReorderGraph
(
self
.
trace_ind
ic
e
)
self
.
estimate_memory
=
EstimateMemory
()
self
.
select_chunk
=
SelectChunk
(
self
.
trace_inde
x
,
self
.
trace_ind
ic
e
,
self
.
estimate_memory
,
self
.
reorder_graph
,
max_memory
=
max_memory
,
...
...
@@ -72,7 +72,7 @@ class SearchChunk(object):
free_var_idx (List): all indexs of free vars
"""
free_var_idx
=
[]
for
idx
,
n
in
enumerate
(
self
.
trace_inde
x
.
node_list
):
for
idx
,
n
in
enumerate
(
self
.
trace_ind
ic
e
.
node_list
):
if
n
.
op
==
"placeholder"
:
free_var_idx
.
append
(
idx
)
return
free_var_idx
...
...
@@ -156,7 +156,7 @@ class SearchChunk(object):
"""
start_traces
=
input_trace
[
start_idx
]
end_trace
=
output_trace
[
end_idx
]
end_node
=
self
.
trace_inde
x
.
node_list
[
end_idx
]
end_node
=
self
.
trace_ind
ic
e
.
node_list
[
end_idx
]
chunk_infos
=
[]
for
end_dim
,
_
in
enumerate
(
end_trace
[
"idx"
]):
if
len
(
start_traces
)
>
1
:
...
...
@@ -205,23 +205,23 @@ class SearchChunk(object):
possible_chunk_region (List)
"""
possible_chunk_region
=
[]
output_trace
=
copy
.
deepcopy
(
self
.
trace_inde
x
.
idx_trace_list
)
output_trace
=
copy
.
deepcopy
(
self
.
trace_ind
ic
e
.
idx_trace_list
)
input_trace
=
[]
# trace of a node's input nodes
for
_
,
n
in
enumerate
(
self
.
trace_inde
x
.
node_list
):
for
_
,
n
in
enumerate
(
self
.
trace_ind
ic
e
.
node_list
):
cur_trace
=
{}
for
arg
in
n
.
args
:
if
type
(
arg
)
==
type
(
n
)
and
not
is_non_compute_node_except_placeholder
(
arg
):
cur_trace
[
arg
]
=
self
.
trace_inde
x
.
_find_trace_from_node
(
arg
)
cur_trace
[
arg
]
=
self
.
trace_ind
ic
e
.
_find_trace_from_node
(
arg
)
input_trace
.
append
(
cur_trace
)
for
start_idx
in
range
(
max_chunk_region
[
0
],
peak_node
+
1
):
for
end_idx
in
range
(
peak_node
,
max_chunk_region
[
1
]
+
1
):
# skip non compute nodes
if
is_non_compute_node
(
self
.
trace_inde
x
.
node_list
[
start_idx
]
)
or
is_non_compute_node
(
self
.
trace_inde
x
.
node_list
[
end_idx
]):
self
.
trace_ind
ic
e
.
node_list
[
start_idx
]
)
or
is_non_compute_node
(
self
.
trace_ind
ic
e
.
node_list
[
end_idx
]):
continue
# select free dim
...
...
@@ -292,7 +292,7 @@ class SearchChunk(object):
_
,
active_node
,
)
=
self
.
estimate_memory
.
estimate_chunk_inference_mem
(
self
.
trace_inde
x
.
node_list
self
.
trace_ind
ic
e
.
node_list
)
mem_peak
=
init_mem_peak
...
...
@@ -307,13 +307,13 @@ class SearchChunk(object):
_
,
active_node
,
)
=
self
.
estimate_memory
.
estimate_chunk_inference_mem
(
self
.
trace_inde
x
.
node_list
,
chunk_infos
self
.
trace_ind
ic
e
.
node_list
,
chunk_infos
)
if
self
.
_stop_search
(
init_mem_peak
,
mem_peak
):
break
if
self
.
print_mem
:
self
.
print_mem
=
False
self
.
estimate_memory
.
estimate_chunk_inference_mem
(
self
.
trace_inde
x
.
node_list
,
chunk_infos
,
print_mem
=
True
self
.
trace_ind
ic
e
.
node_list
,
chunk_infos
,
print_mem
=
True
)
return
chunk_infos
colossalai/autochunk/select_chunk.py
View file @
0ea903b9
from
.estimate_memory
import
EstimateMemory
from
.reorder_graph
import
ReorderGraph
from
.trace_inde
x
import
TraceInde
x
from
.trace_ind
ic
e
import
TraceInd
ic
e
from
.utils
import
is_non_compute_node
class
SelectChunk
(
object
):
def
__init__
(
self
,
trace_inde
x
:
TraceInde
x
,
trace_ind
ic
e
:
TraceInd
ic
e
,
estimate_memory
:
EstimateMemory
,
reorder_graph
:
ReorderGraph
,
max_memory
=
None
,
):
self
.
index_tra
ce
r
=
trace_inde
x
self
.
memory_
estimator
=
estimate_memory
self
.
trace_indi
ce
=
trace_ind
ic
e
self
.
estimat
e_mem
or
y
=
estimate_memory
self
.
reorder_graph
=
reorder_graph
if
max_memory
is
not
None
:
self
.
stratge
=
"fit_memory"
...
...
@@ -68,10 +68,10 @@ class SelectChunk(object):
for
region
in
possible_chunk_regions
:
cur_region
=
region
.
copy
()
cur_node_list
,
cur_region
=
self
.
reorder_graph
.
tmp_reorder
(
self
.
index_tracer
.
node_list
,
cur_region
self
.
trace_indice
.
node_list
,
cur_region
)
cur_chunk_infos
=
chunk_infos
+
[
cur_region
]
cur_mem_peak
=
self
.
memory_
estimator
.
estimate_chunk_inference_mem
(
cur_mem_peak
=
self
.
estimat
e_mem
or
y
.
estimate_chunk_inference_mem
(
cur_node_list
,
cur_chunk_infos
)[
0
]
cur_chunk_region_peak
=
cur_mem_peak
[
...
...
@@ -113,7 +113,7 @@ class SelectChunk(object):
chunk_size
*=
2
reorder_chunk_info
[
"chunk_size"
]
=
chunk_size
cur_chunk_infos
=
chunk_infos
+
[
reorder_chunk_info
]
cur_mem_peak
=
self
.
memory_
estimator
.
estimate_chunk_inference_mem
(
cur_mem_peak
=
self
.
estimat
e_mem
or
y
.
estimate_chunk_inference_mem
(
chunk_region_dict
[
"reorder_node_list"
],
cur_chunk_infos
)[
0
]
cur_chunk_max_mem
=
max
(
...
...
@@ -139,7 +139,7 @@ class SelectChunk(object):
mid
=
int
((
left
+
right
)
/
2
+
0.5
)
chunk_info
[
"chunk_size"
]
=
mid
cur_chunk_infos
=
chunk_infos
+
[
chunk_info
]
cur_mem_peak
=
self
.
memory_
estimator
.
estimate_chunk_inference_mem
(
cur_mem_peak
=
self
.
estimat
e_mem
or
y
.
estimate_chunk_inference_mem
(
chunk_region_dict
[
"reorder_node_list"
],
cur_chunk_infos
)[
0
]
cur_chunk_max_mem
=
max
(
...
...
@@ -153,7 +153,7 @@ class SelectChunk(object):
def
_get_compute_node_num
(
self
,
start
,
end
):
count
=
0
for
i
in
self
.
index_tracer
.
node_list
[
start
:
end
+
1
]:
for
i
in
self
.
trace_indice
.
node_list
[
start
:
end
+
1
]:
if
not
is_non_compute_node
(
i
):
count
+=
1
return
count
...
...
@@ -178,10 +178,10 @@ class SelectChunk(object):
for
region
in
possible_chunk_regions
:
cur_region
=
region
.
copy
()
cur_node_list
,
cur_region
=
self
.
reorder_graph
.
tmp_reorder
(
self
.
index_tracer
.
node_list
,
cur_region
self
.
trace_indice
.
node_list
,
cur_region
)
cur_chunk_infos
=
chunk_infos
+
[
cur_region
]
cur_mem_peak
=
self
.
memory_
estimator
.
estimate_chunk_inference_mem
(
cur_mem_peak
=
self
.
estimat
e_mem
or
y
.
estimate_chunk_inference_mem
(
cur_node_list
,
cur_chunk_infos
)[
0
]
cur_chunk_region_peak
=
cur_mem_peak
[
...
...
colossalai/autochunk/trace_flow.py
View file @
0ea903b9
from
.trace_inde
x
import
TraceInde
x
from
.trace_ind
ic
e
import
TraceInd
ic
e
from
.utils
import
(
find_chunk_all_input_nodes
,
find_chunk_compute_input_and_output_nodes
,
...
...
@@ -10,8 +10,8 @@ from .utils import (
class
TraceFlow
(
object
):
def
__init__
(
self
,
trace_inde
x
:
TraceInde
x
)
->
None
:
self
.
trace_inde
x
=
trace_inde
x
def
__init__
(
self
,
trace_ind
ic
e
:
TraceInd
ic
e
)
->
None
:
self
.
trace_ind
ic
e
=
trace_ind
ic
e
def
check_index_source
(
self
,
start_dim
,
start_node
,
start_idx
,
end_dim
,
end_node
):
"""
...
...
@@ -25,8 +25,8 @@ class TraceFlow(object):
Returns:
bool: True if check pass
"""
start_node_idx
=
find_idx_by_name
(
start_node
.
name
,
self
.
trace_inde
x
.
node_list
)
end_node_trace
=
self
.
trace_inde
x
.
_find_trace_from_node
(
end_node
)
start_node_idx
=
find_idx_by_name
(
start_node
.
name
,
self
.
trace_ind
ic
e
.
node_list
)
end_node_trace
=
self
.
trace_ind
ic
e
.
_find_trace_from_node
(
end_node
)
end_node_trace_source
=
end_node_trace
[
"source"
][
end_dim
]
sorted_source
=
sorted
(
end_node_trace_source
.
items
(),
key
=
lambda
d
:
d
[
0
],
reverse
=
True
...
...
@@ -51,24 +51,24 @@ class TraceFlow(object):
Returns:
bool: True if check pass
"""
end_node_trace
=
self
.
trace_inde
x
.
_find_trace_from_node
(
end_node
)
end_node_trace
=
self
.
trace_ind
ic
e
.
_find_trace_from_node
(
end_node
)
end_node_compute
=
end_node_trace
[
"compute"
][
end_dim
]
if
any
(
start_idx
<=
i
<=
end_idx
for
i
in
end_node_compute
):
return
False
return
True
def
get_node_chunk_dim
(
self
,
node_from
,
node_from_dim
,
node_to
):
node_from_source
=
self
.
trace_inde
x
.
_find_source_trace_from_node
(
node_from
)
node_from_source
=
self
.
trace_ind
ic
e
.
_find_source_trace_from_node
(
node_from
)
dim_source
=
node_from_source
[
node_from_dim
]
node_to_idx
=
find_idx_by_name
(
node_to
.
name
,
self
.
trace_inde
x
.
node_list
)
node_to_idx
=
find_idx_by_name
(
node_to
.
name
,
self
.
trace_ind
ic
e
.
node_list
)
for
k
,
v
in
dim_source
.
items
():
if
k
==
node_to_idx
:
return
v
return
None
def
_find_inherit_dim
(
self
,
input_node
,
input_dim
,
node
):
input_node_idx
=
find_idx_by_name
(
input_node
.
name
,
self
.
trace_inde
x
.
node_list
)
node_trace_source
=
self
.
trace_inde
x
.
_find_source_trace_from_node
(
node
)
input_node_idx
=
find_idx_by_name
(
input_node
.
name
,
self
.
trace_ind
ic
e
.
node_list
)
node_trace_source
=
self
.
trace_ind
ic
e
.
_find_source_trace_from_node
(
node
)
for
node_dim
in
range
(
len
(
get_node_shape
(
node
))):
if
(
input_node_idx
in
node_trace_source
[
node_dim
]
...
...
@@ -82,19 +82,19 @@ class TraceFlow(object):
for
input_node_idx
,
input_node
in
enumerate
(
chunk_infos
[
"inputs"
]):
for
k
,
v
in
chunk_infos
[
"inputs_dim"
][
input_node_idx
].
items
():
inherit_dim
=
self
.
_find_inherit_dim
(
input_node
,
v
,
self
.
trace_inde
x
.
node_list
[
k
]
input_node
,
v
,
self
.
trace_ind
ic
e
.
node_list
[
k
]
)
if
inherit_dim
:
input_dim_after_node
[
k
]
=
inherit_dim
for
node
in
self
.
trace_inde
x
.
node_list
[
for
node
in
self
.
trace_ind
ic
e
.
node_list
[
chunk_infos
[
"region"
][
0
]
:
chunk_infos
[
"region"
][
1
]
+
1
]:
if
is_non_compute_node_except_placeholder
(
node
):
continue
count
=
0
duplicate_dims
=
[]
node_trace_source
=
self
.
trace_inde
x
.
_find_source_trace_from_node
(
node
)
node_trace_source
=
self
.
trace_ind
ic
e
.
_find_source_trace_from_node
(
node
)
for
node_dim
in
range
(
len
(
get_node_shape
(
node
))):
duplicate_dim
=
[]
duplicate_flag
=
False
...
...
@@ -130,7 +130,7 @@ class TraceFlow(object):
all_node_info
,
next_node_list
,
):
arg_idx
=
find_idx_by_name
(
arg_node
.
name
,
self
.
trace_inde
x
.
node_list
)
arg_idx
=
find_idx_by_name
(
arg_node
.
name
,
self
.
trace_ind
ic
e
.
node_list
)
# arg in chunk range or be inputs
if
not
(
start_idx
<=
arg_idx
<
end_idx
):
return
True
...
...
@@ -171,7 +171,7 @@ class TraceFlow(object):
def
_get_all_node_info
(
self
,
end_dim
,
start_idx
,
end_idx
):
cur_node_list
=
[
self
.
trace_inde
x
.
node_list
[
end_idx
]
self
.
trace_ind
ic
e
.
node_list
[
end_idx
]
]
# start from the last node
all_node_info
=
{
cur_node_list
[
0
]:
{
"chunk_dim"
:
end_dim
,
"fix_dim"
:
[]}}
...
...
@@ -183,10 +183,10 @@ class TraceFlow(object):
cur_node_chunk_dim
=
all_node_info
[
cur_node
][
"chunk_dim"
]
cur_node_fix_dim
=
all_node_info
[
cur_node
][
"fix_dim"
]
if
cur_node_chunk_dim
:
cur_node_compute
=
self
.
trace_inde
x
.
_find_compute_trace_from_node
(
cur_node_compute
=
self
.
trace_ind
ic
e
.
_find_compute_trace_from_node
(
cur_node
)
cur_node_source
=
self
.
trace_inde
x
.
_find_source_trace_from_node
(
cur_node_source
=
self
.
trace_ind
ic
e
.
_find_source_trace_from_node
(
cur_node
)
else
:
...
...
@@ -220,7 +220,7 @@ class TraceFlow(object):
if
not
(
start_idx
<=
find_idx_by_name
(
arg
.
name
,
self
.
trace_inde
x
.
node_list
arg
.
name
,
self
.
trace_ind
ic
e
.
node_list
)
<
end_idx
):
...
...
@@ -250,16 +250,16 @@ class TraceFlow(object):
for
input_node
in
inputs
:
input_dict
=
{}
input_node_idx
=
find_idx_by_name
(
input_node
.
name
,
self
.
trace_inde
x
.
node_list
input_node
.
name
,
self
.
trace_ind
ic
e
.
node_list
)
for
user
in
input_node
.
users
.
keys
():
if
is_non_compute_node
(
user
):
continue
user_idx
=
find_idx_by_name
(
user
.
name
,
self
.
trace_inde
x
.
node_list
)
user_idx
=
find_idx_by_name
(
user
.
name
,
self
.
trace_ind
ic
e
.
node_list
)
if
start_idx
<=
user_idx
<=
end_idx
:
chunk_dim
=
all_node_info
[
user
][
"chunk_dim"
]
if
chunk_dim
is
not
None
:
user_source
=
self
.
trace_inde
x
.
_find_source_trace_from_node
(
user_source
=
self
.
trace_ind
ic
e
.
_find_source_trace_from_node
(
user
)[
chunk_dim
]
if
input_node_idx
in
user_source
:
...
...
@@ -282,7 +282,7 @@ class TraceFlow(object):
if
node_info
[
"chunk_dim"
]
is
None
:
maybe_prepose_nodes
.
append
(
node
)
maybe_prepose_nodes
.
sort
(
key
=
lambda
x
:
find_idx_by_name
(
x
.
name
,
self
.
trace_inde
x
.
node_list
),
key
=
lambda
x
:
find_idx_by_name
(
x
.
name
,
self
.
trace_ind
ic
e
.
node_list
),
reverse
=
True
,
)
# from last node to first node
prepose_nodes
=
[]
...
...
@@ -308,7 +308,7 @@ class TraceFlow(object):
if
not
(
start_idx
<=
find_idx_by_name
(
cur_prepose_node_arg
.
name
,
self
.
trace_inde
x
.
node_list
cur_prepose_node_arg
.
name
,
self
.
trace_ind
ic
e
.
node_list
)
<
end_idx
):
...
...
@@ -336,14 +336,14 @@ class TraceFlow(object):
maybe_prepose_nodes
.
remove
(
n
)
# sort by index
prepose_nodes
.
sort
(
key
=
lambda
x
:
find_idx_by_name
(
x
.
name
,
self
.
trace_inde
x
.
node_list
)
key
=
lambda
x
:
find_idx_by_name
(
x
.
name
,
self
.
trace_ind
ic
e
.
node_list
)
)
return
prepose_nodes
def
_get_non_chunk_inputs
(
self
,
chunk_info
,
start_idx
,
end_idx
):
# we need to log input nodes to avoid deleteing them in the loop
chunk_node_list
=
self
.
trace_inde
x
.
node_list
[
start_idx
:
end_idx
+
1
]
chunk_node_list
=
self
.
trace_ind
ic
e
.
node_list
[
start_idx
:
end_idx
+
1
]
# also need to get some prepose node's arg out of non_chunk_inputs
for
n
in
chunk_info
[
"args"
][
"prepose_nodes"
]:
chunk_node_list
.
remove
(
n
)
...
...
@@ -355,7 +355,7 @@ class TraceFlow(object):
def
flow_search
(
self
,
start_idx
,
start_dim
,
end_idx
,
end_dim
):
inputs
,
outputs
=
find_chunk_compute_input_and_output_nodes
(
self
.
trace_inde
x
.
node_list
[
start_idx
:
end_idx
+
1
]
self
.
trace_ind
ic
e
.
node_list
[
start_idx
:
end_idx
+
1
]
)
# only single ouput
if
len
(
outputs
)
>
1
:
...
...
@@ -403,10 +403,10 @@ class TraceFlow(object):
chunk_shape
=
get_node_shape
(
chunk_info
[
"outputs"
][
0
])[
chunk_info
[
"outputs_dim"
]
]
for
node
in
self
.
trace_inde
x
.
node_list
[
chunk_region
[
0
]
:
chunk_region
[
1
]
+
1
]:
for
node
in
self
.
trace_ind
ic
e
.
node_list
[
chunk_region
[
0
]
:
chunk_region
[
1
]
+
1
]:
if
any
(
i
in
node
.
name
for
i
in
[
"reshape"
,
"view"
]):
reshape_args
=
node
.
args
[
1
:]
reshape_log
=
self
.
trace_inde
x
.
idx_view_list
[
node
]
reshape_log
=
self
.
trace_ind
ic
e
.
idx_view_list
[
node
]
chunk_dim
=
chunk_info
[
"node_chunk_dim"
][
node
][
"chunk_dim"
]
reshape_size
[
node
.
name
]
=
{}
for
reshape_arg_dim
,
reshape_arg
in
enumerate
(
reshape_args
):
...
...
colossalai/autochunk/trace_inde
x
.py
→
colossalai/autochunk/trace_ind
ic
e.py
View file @
0ea903b9
...
...
@@ -6,7 +6,7 @@ from .utils import (
)
class
TraceInde
x
(
object
):
class
TraceInd
ic
e
(
object
):
def
__init__
(
self
,
node_list
)
->
None
:
self
.
node_list
=
node_list
self
.
idx_trace_list
=
self
.
_init_idx_trace_list
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment