Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
774d34f1
Commit
774d34f1
authored
Dec 23, 2022
by
oahzxl
Browse files
refactor flow search
parent
ded10056
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
58 additions
and
20 deletions
+58
-20
chunk_codegen.py
chunk_codegen.py
+58
-20
No files found.
chunk_codegen.py
View file @
774d34f1
...
@@ -1004,7 +1004,7 @@ class FlowTracer(object):
...
@@ -1004,7 +1004,7 @@ class FlowTracer(object):
# if already in node_info, arg dim must be same
# if already in node_info, arg dim must be same
if
arg_node
in
all_node_info
:
if
arg_node
in
all_node_info
:
if
all_node_info
[
arg_node
]
!=
arg_dim
:
if
all_node_info
[
arg_node
]
[
'chunk_dim'
]
!=
arg_dim
:
return
False
return
False
all_node_info
[
arg_node
][
"fix_dim"
]
=
list
(
all_node_info
[
arg_node
][
"fix_dim"
]
=
list
(
set
(
all_node_info
[
arg_node
][
"fix_dim"
]
+
arg_fix_dim
)
set
(
all_node_info
[
arg_node
][
"fix_dim"
]
+
arg_fix_dim
)
...
@@ -1128,14 +1128,68 @@ class FlowTracer(object):
...
@@ -1128,14 +1128,68 @@ class FlowTracer(object):
"args"
:
{},
"args"
:
{},
}
}
# move useless nodes ahead of loop
# get all possible prepose nodes
maybe_prepose_nodes
=
[]
for
node
,
node_info
in
all_node_info
.
items
():
if
node_info
[
'chunk_dim'
]
is
None
:
maybe_prepose_nodes
.
append
(
node
)
maybe_prepose_nodes
.
sort
(
key
=
lambda
x
:
_find_idx_by_name
(
x
.
name
,
index_tracer
.
nodes_list
),
reverse
=
True
)
# from last node to first node
prepose_nodes
=
[]
# take each candidate node as a root and search its args; if all are legal, mark the root and its args as prepose nodes
while
len
(
maybe_prepose_nodes
)
>
0
:
tmp_cur_prepose_nodes
=
[
maybe_prepose_nodes
[
0
]]
tmp_cur_related_prepose_nodes
=
[]
prepose_flag
=
True
# loop over all args of the current node until out of the chunk
while
len
(
tmp_cur_prepose_nodes
)
>
0
:
tmp_next_prepose_nodes
=
[]
tmp_cur_related_prepose_nodes
.
extend
(
tmp_cur_prepose_nodes
)
for
cur_prepose_node
in
tmp_cur_prepose_nodes
:
for
cur_prepose_node_arg
in
cur_prepose_node
.
args
:
if
type
(
cur_prepose_node_arg
)
!=
type
(
cur_prepose_node
):
continue
# out of loop
if
not
(
start_idx
<=
_find_idx_by_name
(
cur_prepose_node_arg
.
name
,
self
.
node_list
)
<
end_idx
):
continue
# compute op in loop
elif
cur_prepose_node_arg
in
all_node_info
:
if
all_node_info
[
cur_prepose_node_arg
][
'chunk_dim'
]
is
None
:
tmp_next_prepose_nodes
.
append
(
cur_prepose_node_arg
)
else
:
prepose_flag
=
False
break
;
break
;
break
# non compute op
else
:
tmp_next_prepose_nodes
.
append
(
cur_prepose_node_arg
)
tmp_cur_prepose_nodes
=
tmp_next_prepose_nodes
if
prepose_flag
==
False
:
maybe_prepose_nodes
.
remove
(
maybe_prepose_nodes
[
0
])
continue
else
:
for
n
in
tmp_cur_related_prepose_nodes
:
if
n
not
in
prepose_nodes
:
prepose_nodes
.
append
(
n
)
if
n
in
maybe_prepose_nodes
:
maybe_prepose_nodes
.
remove
(
n
)
# sort by index
prepose_nodes
.
sort
(
key
=
lambda
x
:
_find_idx_by_name
(
x
.
name
,
index_tracer
.
nodes_list
))
chunk_info
[
"args"
][
"prepose_nodes"
]
=
prepose_nodes
# we need to log input nodes to avoid deleting them in the loop
# we need to log input nodes to avoid deleting them in the loop
chunk_node_list
=
self
.
node_list
[
start_idx
:
end_idx
+
1
]
# also need to get some prepose node's arg out of non_chunk_inputs
for
n
in
prepose_nodes
:
chunk_node_list
.
remove
(
n
)
non_chunk_inputs
=
_find_chunk_all_input_nodes
(
non_chunk_inputs
=
_find_chunk_all_input_nodes
(
self
.
node_list
[
start_idx
:
end_idx
+
1
]
chunk_
node_list
)
)
for
i
in
non_chunk_inputs
:
for
i
in
non_chunk_inputs
:
if
i
not
in
chunk_info
[
"inputs"
]:
if
i
not
in
chunk_info
[
"inputs"
]
and
i
not
in
prepose_nodes
:
chunk_info
[
"inputs_non_chunk"
].
append
(
i
)
chunk_info
[
"inputs_non_chunk"
].
append
(
i
)
return
chunk_info
return
chunk_info
...
@@ -1541,16 +1595,6 @@ class ChunkRegionSearch(object):
...
@@ -1541,16 +1595,6 @@ class ChunkRegionSearch(object):
continue
continue
for
start_node
,
start_trace
in
start_traces
.
items
():
for
start_node
,
start_trace
in
start_traces
.
items
():
for
start_dim
,
start_trace_idx
in
enumerate
(
start_trace
[
"idx"
]):
for
start_dim
,
start_trace_idx
in
enumerate
(
start_trace
[
"idx"
]):
if
(
start_idx
==
199
and
end_idx
==
229
and
start_dim
==
2
and
end_dim
==
2
):
print
(
1
)
self
.
flow_tracer
.
flow_search
(
start_idx
,
start_dim
,
end_idx
,
end_dim
,
self
.
index_tracer
)
# dim size cannot be 1
# dim size cannot be 1
if
(
if
(
_get_node_shape
(
end_node
)[
end_dim
]
==
1
_get_node_shape
(
end_node
)[
end_dim
]
==
1
...
@@ -1567,12 +1611,6 @@ class ChunkRegionSearch(object):
...
@@ -1567,12 +1611,6 @@ class ChunkRegionSearch(object):
start_idx
,
end_dim
,
end_node
,
end_idx
start_idx
,
end_dim
,
end_node
,
end_idx
):
):
continue
continue
# detect flow meet
# flow_block, chunk_info = self.flow_tracer._detect_flow(
# start_idx, start_dim, end_idx, end_dim, self.index_tracer
# )
# if flow_block:
# continue
# flow search
# flow search
chunk_info
=
self
.
flow_tracer
.
flow_search
(
chunk_info
=
self
.
flow_tracer
.
flow_search
(
start_idx
,
start_dim
,
end_idx
,
end_dim
,
self
.
index_tracer
start_idx
,
start_dim
,
end_idx
,
end_dim
,
self
.
index_tracer
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment