Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
d309e933
"git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "bbbcac26e80601790728b1e8b8a7595d4d89a7b4"
Commit
d309e933
authored
Dec 23, 2022
by
oahzxl
Browse files
adapt codegen to prepose node
parent
522f0174
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
18 deletions
+25
-18
chunk_codegen.py
chunk_codegen.py
+25
-18
No files found.
chunk_codegen.py
View file @
d309e933
...
@@ -1198,7 +1198,7 @@ class FlowTracer(object):
...
@@ -1198,7 +1198,7 @@ class FlowTracer(object):
chunk_node_list
.
remove
(
n
)
chunk_node_list
.
remove
(
n
)
non_chunk_inputs
=
_find_chunk_all_input_nodes
(
chunk_node_list
)
non_chunk_inputs
=
_find_chunk_all_input_nodes
(
chunk_node_list
)
for
i
in
non_chunk_inputs
:
for
i
in
non_chunk_inputs
:
if
i
not
in
chunk_info
[
"inputs"
]
and
i
not
in
prepose_nodes
:
if
i
not
in
chunk_info
[
"inputs"
]:
chunk_info
[
"inputs_non_chunk"
].
append
(
i
)
chunk_info
[
"inputs_non_chunk"
].
append
(
i
)
return
chunk_info
return
chunk_info
...
@@ -1425,6 +1425,7 @@ class MemoryEstimator(object):
...
@@ -1425,6 +1425,7 @@ class MemoryEstimator(object):
)
/
(
1024
**
2
)
)
/
(
1024
**
2
)
# determine chunk ratio for current node
# determine chunk ratio for current node
# TODO: adapt to prepose node memory
if
chunk_within
:
if
chunk_within
:
chunk_ratio
=
self
.
_get_chunk_ratio
(
chunk_ratio
=
self
.
_get_chunk_ratio
(
node
,
node
,
...
@@ -1602,7 +1603,6 @@ class ChunkRegionSearch(object):
...
@@ -1602,7 +1603,6 @@ class ChunkRegionSearch(object):
chunk_infos
=
[]
chunk_infos
=
[]
for
end_dim
,
end_trace_idx
in
enumerate
(
end_trace
[
"idx"
]):
for
end_dim
,
end_trace_idx
in
enumerate
(
end_trace
[
"idx"
]):
if
len
(
start_traces
)
>
1
:
if
len
(
start_traces
)
>
1
:
# TODO: implement multi input chunk
continue
continue
for
start_node
,
start_trace
in
start_traces
.
items
():
for
start_node
,
start_trace
in
start_traces
.
items
():
for
start_dim
,
start_trace_idx
in
enumerate
(
start_trace
[
"idx"
]):
for
start_dim
,
start_trace_idx
in
enumerate
(
start_trace
[
"idx"
]):
...
@@ -1831,7 +1831,6 @@ def _find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
...
@@ -1831,7 +1831,6 @@ def _find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
# if a node has a user node which is not in the node list
# if a node has a user node which is not in the node list
# we treat that user node as the node receiving the current node output
# we treat that user node as the node receiving the current node output
# TODO: it is unsafe to remove non compute node here
for
node
in
nodes
:
for
node
in
nodes
:
for
output_node
in
node
.
users
.
keys
():
for
output_node
in
node
.
users
.
keys
():
if
(
if
(
...
@@ -1900,6 +1899,8 @@ def emit_code_with_chunk(
...
@@ -1900,6 +1899,8 @@ def emit_code_with_chunk(
chunk_outputs
=
[
i
[
"outputs"
][
0
]
for
i
in
chunk_search
]
chunk_outputs
=
[
i
[
"outputs"
][
0
]
for
i
in
chunk_search
]
chunk_outputs_dim
=
[
i
[
"outputs_dim"
]
for
i
in
chunk_search
]
chunk_outputs_dim
=
[
i
[
"outputs_dim"
]
for
i
in
chunk_search
]
chunk_prepose_nodes
=
[
i
[
"args"
][
"prepose_nodes"
]
for
i
in
chunk_search
]
node_idx
=
0
node_idx
=
0
region_idx
=
0
region_idx
=
0
...
@@ -1911,7 +1912,11 @@ def emit_code_with_chunk(
...
@@ -1911,7 +1912,11 @@ def emit_code_with_chunk(
if
node_idx
in
chunk_starts
:
if
node_idx
in
chunk_starts
:
within_chunk_region
=
True
within_chunk_region
=
True
region_idx
=
chunk_starts
.
index
(
node_idx
)
region_idx
=
chunk_starts
.
index
(
node_idx
)
# add prepose nodes
for
i
in
chunk_prepose_nodes
[
region_idx
]:
prepose_node
=
node_list
[
_find_idx_by_name
(
i
.
name
,
node_list
)]
emit_node_func
(
prepose_node
,
body
)
delete_unused_value_func
(
prepose_node
,
body
,
chunk_inputs_names
)
# add for loop
# add for loop
body
.
append
(
body
.
append
(
_gen_loop_start
(
_gen_loop_start
(
...
@@ -1922,20 +1927,22 @@ def emit_code_with_chunk(
...
@@ -1922,20 +1927,22 @@ def emit_code_with_chunk(
)
)
if
within_chunk_region
:
if
within_chunk_region
:
emit_node_func
(
node
,
body
)
if
any
(
node
.
name
==
i
.
name
for
i
in
chunk_prepose_nodes
[
region_idx
]):
# replace input var with chunk var
pass
for
input_node_idx
,
input_node
in
enumerate
(
chunk_inputs
[
region_idx
]):
else
:
for
idx
,
dim
in
chunk_inputs_dim
[
region_idx
][
input_node_idx
].
items
():
emit_node_func
(
node
,
body
)
if
idx
==
node_idx
:
# replace input var with chunk var
chunk_slice
=
_gen_chunk_slice_dim
(
for
input_node_idx
,
input_node
in
enumerate
(
chunk_inputs
[
region_idx
]):
dim
,
"chunk_idx"
,
_get_node_shape
(
input_node
)
for
idx
,
dim
in
chunk_inputs_dim
[
region_idx
][
input_node_idx
].
items
():
)
if
idx
==
node_idx
:
body
[
-
1
]
=
_replace_name
(
chunk_slice
=
_gen_chunk_slice_dim
(
body
[
-
1
],
input_node
.
name
,
input_node
.
name
+
chunk_slice
dim
,
"chunk_idx"
,
_get_node_shape
(
input_node
)
)
)
body
[
-
1
]
=
" "
+
body
[
-
1
]
body
[
-
1
]
=
_replace_name
(
delete_unused_value_func
(
node
,
body
,
chunk_inputs_names
)
body
[
-
1
],
input_node
.
name
,
input_node
.
name
+
chunk_slice
)
body
[
-
1
]
=
" "
+
body
[
-
1
]
delete_unused_value_func
(
node
,
body
,
chunk_inputs_names
)
else
:
else
:
emit_node_func
(
node
,
body
)
emit_node_func
(
node
,
body
)
if
node_idx
not
in
chunk_inputs
:
if
node_idx
not
in
chunk_inputs
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment