Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
f4a1607e
Commit
f4a1607e
authored
Jan 06, 2023
by
oahzxl
Browse files
seperate input node dim search
parent
ae27a8b2
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
14 deletions
+21
-14
colossalai/autochunk/index_tracer.py
colossalai/autochunk/index_tracer.py
+21
-14
No files found.
colossalai/autochunk/index_tracer.py
View file @
f4a1607e
...
...
@@ -812,19 +812,7 @@ class IndexTracer(object):
cur_node_list
=
next_node_list
return
all_node_info
def
flow_search
(
self
,
start_idx
,
start_dim
,
end_idx
,
end_dim
):
inputs
,
outputs
=
find_chunk_compute_input_and_output_nodes
(
self
.
node_list
[
start_idx
:
end_idx
+
1
]
)
# only single ouput
if
len
(
outputs
)
>
1
:
return
None
# get every node's chunk dim and fix dim
all_node_info
=
self
.
_get_all_node_info
(
end_dim
,
start_idx
,
end_idx
)
if
all_node_info
is
None
:
return
None
def
_get_input_nodes_dim
(
self
,
inputs
,
start_idx
,
end_idx
,
all_node_info
):
inputs_dim
=
[]
remove_inputs
=
[]
for
input_node
in
inputs
:
...
...
@@ -841,7 +829,7 @@ class IndexTracer(object):
if
input_node_idx
in
user_source
:
input_dict
[
user_idx
]
=
user_source
[
input_node_idx
]
else
:
return
None
return
None
,
None
if
len
(
input_dict
)
==
0
:
remove_inputs
.
append
(
input_node
)
else
:
...
...
@@ -849,6 +837,25 @@ class IndexTracer(object):
for
i
in
remove_inputs
:
if
i
in
inputs
:
inputs
.
remove
(
i
)
return
inputs
,
inputs_dim
def
flow_search
(
self
,
start_idx
,
start_dim
,
end_idx
,
end_dim
):
inputs
,
outputs
=
find_chunk_compute_input_and_output_nodes
(
self
.
node_list
[
start_idx
:
end_idx
+
1
]
)
# only single ouput
if
len
(
outputs
)
>
1
:
return
None
# get every node's chunk dim and fix dim
all_node_info
=
self
.
_get_all_node_info
(
end_dim
,
start_idx
,
end_idx
)
if
all_node_info
is
None
:
return
None
# get input nodes' chunk dim
inputs
,
inputs_dim
=
self
.
_get_input_nodes_dim
(
inputs
,
start_idx
,
end_idx
,
all_node_info
)
if
inputs
is
None
:
return
None
chunk_info
=
{
"region"
:
(
start_idx
,
end_idx
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment