Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
ae27a8b2
"runtime/rust/src/vscode:/vscode.git/clone" did not exist on "9d6643b7a59220fc4f3ef599c002241dd0bf9965"
Commit
ae27a8b2
authored
Jan 06, 2023
by
oahzxl
Browse files
seperate flow tracer
parent
fd87d78a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
9 deletions
+15
-9
colossalai/autochunk/index_tracer.py
colossalai/autochunk/index_tracer.py
+15
-9
No files found.
colossalai/autochunk/index_tracer.py
View file @
ae27a8b2
...
...
@@ -745,14 +745,7 @@ class IndexTracer(object):
next_node_list
.
append
(
arg_node
)
return
True
def
flow_search
(
self
,
start_idx
,
start_dim
,
end_idx
,
end_dim
):
inputs
,
outputs
=
find_chunk_compute_input_and_output_nodes
(
self
.
node_list
[
start_idx
:
end_idx
+
1
]
)
# only single ouput
if
len
(
outputs
)
>
1
:
return
None
def
_get_all_node_info
(
self
,
end_dim
,
start_idx
,
end_idx
):
cur_node_list
=
[
self
.
node_list
[
end_idx
]]
# start from the last node
all_node_info
=
{
cur_node_list
[
0
]:
{
"chunk_dim"
:
end_dim
,
"fix_dim"
:
[]}}
...
...
@@ -763,7 +756,6 @@ class IndexTracer(object):
# get cur node info
cur_node_chunk_dim
=
all_node_info
[
cur_node
][
"chunk_dim"
]
cur_node_fix_dim
=
all_node_info
[
cur_node
][
"fix_dim"
]
cur_node_idx
=
find_idx_by_name
(
cur_node
.
name
,
self
.
node_list
)
if
cur_node_chunk_dim
:
cur_node_compute
=
self
.
_find_compute_trace_from_node
(
cur_node
)
cur_node_source
=
self
.
_find_source_trace_from_node
(
cur_node
)
...
...
@@ -818,6 +810,20 @@ class IndexTracer(object):
else
:
raise
NotImplementedError
()
cur_node_list
=
next_node_list
return
all_node_info
def
flow_search
(
self
,
start_idx
,
start_dim
,
end_idx
,
end_dim
):
inputs
,
outputs
=
find_chunk_compute_input_and_output_nodes
(
self
.
node_list
[
start_idx
:
end_idx
+
1
]
)
# only single ouput
if
len
(
outputs
)
>
1
:
return
None
# get every node's chunk dim and fix dim
all_node_info
=
self
.
_get_all_node_info
(
end_dim
,
start_idx
,
end_idx
)
if
all_node_info
is
None
:
return
None
inputs_dim
=
[]
remove_inputs
=
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment