Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
884a228e
Commit
884a228e
authored
Dec 23, 2022
by
oahzxl
Browse files
reorder nodes
parent
e0ae68e7
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
101 additions
and
26 deletions
+101
-26
chunk_codegen.py
chunk_codegen.py
+101
-26
No files found.
chunk_codegen.py
View file @
884a228e
...
...
@@ -71,6 +71,7 @@ class IndexTracer(object):
self
.
idx_trace_equal
=
[]
self
.
idx_view_list
=
[]
self
.
idx_count
=
-
1
self
.
all_reorder_map
=
{
i
:
i
for
i
in
range
(
len
(
self
.
idx_trace_list
))}
def
_init_idx_trace_list
(
self
):
idx_trace_list
=
[]
...
...
@@ -973,6 +974,91 @@ class IndexTracer(object):
return
chunk_info
def
_get_reorder_map
(
self
,
chunk_info
):
reorder_map
=
{
i
:
i
for
i
in
range
(
len
(
self
.
node_list
))}
chunk_region_start
=
chunk_info
[
"region"
][
0
]
chunk_region_end
=
chunk_info
[
"region"
][
1
]
chunk_prepose_nodes
=
chunk_info
[
"args"
][
"prepose_nodes"
]
chunk_prepose_nodes_idx
=
[
_find_idx_by_name
(
i
.
name
,
self
.
node_list
)
for
i
in
chunk_prepose_nodes
]
# put prepose nodes ahead
for
idx
,
n
in
enumerate
(
chunk_prepose_nodes
):
n_idx
=
chunk_prepose_nodes_idx
[
idx
]
reorder_map
[
n_idx
]
=
chunk_region_start
+
idx
# put other nodes after prepose nodes
for
n
in
self
.
node_list
[
chunk_region_start
:
chunk_region_end
+
1
]:
if
n
in
chunk_prepose_nodes
:
continue
n_idx
=
_find_idx_by_name
(
n
.
name
,
self
.
node_list
)
pos
=
sum
([
n_idx
<
i
for
i
in
chunk_prepose_nodes_idx
])
reorder_map
[
n_idx
]
=
n_idx
+
pos
return
reorder_map
def
_reorder_chunk_info
(
self
,
chunk_info
,
reorder_map
):
# update chunk info
chunk_info
[
"region"
]
=
(
chunk_info
[
"region"
][
0
]
+
len
(
chunk_info
[
"args"
][
"prepose_nodes"
]),
chunk_info
[
"region"
][
1
],
)
for
idx
,
input_dim
in
enumerate
(
chunk_info
[
"inputs_dim"
]):
new_input_dim
=
{}
for
k
,
v
in
input_dim
.
items
():
new_input_dim
[
reorder_map
[
k
]]
=
v
chunk_info
[
"inputs_dim"
][
idx
]
=
new_input_dim
return
chunk_info
def
_update_all_reorder_map
(
self
,
reorder_map
):
for
origin_idx
,
map_idx
in
self
.
all_reorder_map
.
items
():
self
.
all_reorder_map
[
origin_idx
]
=
reorder_map
[
map_idx
]
def
_reorder_self_node_list
(
self
,
reorder_map
):
new_node_list
=
[
None
for
_
in
range
(
len
(
self
.
node_list
))]
for
old_idx
,
new_idx
in
reorder_map
.
items
():
new_node_list
[
new_idx
]
=
self
.
node_list
[
old_idx
]
self
.
node_list
=
new_node_list
def
_reorder_idx_trace
(
self
,
reorder_map
):
# reorder list
new_idx_trace_list
=
[
None
for
_
in
range
(
len
(
self
.
idx_trace_list
))]
for
old_idx
,
new_idx
in
reorder_map
.
items
():
new_idx_trace_list
[
new_idx
]
=
self
.
idx_trace_list
[
old_idx
]
self
.
idx_trace_list
=
new_idx_trace_list
# update compute
for
idx_trace
in
self
.
idx_trace_list
:
compute
=
idx_trace
[
"compute"
]
for
dim_compute
in
compute
:
for
idx
,
i
in
enumerate
(
dim_compute
):
dim_compute
[
idx
]
=
reorder_map
[
i
]
# update source
for
idx_trace
in
self
.
idx_trace_list
:
source
=
idx_trace
[
"source"
]
for
dim_idx
,
dim_source
in
enumerate
(
source
):
new_dim_source
=
{}
for
k
,
v
in
dim_source
.
items
():
new_dim_source
[
reorder_map
[
k
]]
=
v
source
[
dim_idx
]
=
new_dim_source
def reorder_all(self, chunk_info):
    """Apply the prepose-node reordering for one chunk to all tracer state.

    Args:
        chunk_info: chunk description dict, or ``None``.

    Returns:
        ``chunk_info`` unchanged when it is ``None`` or lists no prepose
        nodes; otherwise the value returned by ``_reorder_chunk_info``.
    """
    # nothing to reorder without a chunk or without prepose nodes
    if chunk_info is None or len(chunk_info["args"]["prepose_nodes"]) == 0:
        return chunk_info

    # the map is computed from the current node order, so it must be built
    # before node_list itself is permuted
    node_remap = self._get_reorder_map(chunk_info)
    self._update_all_reorder_map(node_remap)
    self._reorder_idx_trace(node_remap)
    self._reorder_self_node_list(node_remap)
    return self._reorder_chunk_info(chunk_info, node_remap)
def reorder_node_list(self, node_list):
    """Return a new list with ``node_list`` permuted by ``self.all_reorder_map``.

    Args:
        node_list: sequence of nodes indexed by their original position.

    Returns:
        A list of the same length where the element at each old index is
        placed at its mapped new index; unmapped slots stay ``None``. The
        input sequence is not modified.
    """
    permuted = [None] * len(node_list)
    for src, dst in self.all_reorder_map.items():
        permuted[dst] = node_list[src]
    return permuted
class
MemoryEstimator
(
object
):
def
__init__
(
self
,
index_tracer
:
IndexTracer
)
->
None
:
...
...
@@ -1476,6 +1562,7 @@ class ChunkRegionSearch(object):
best_chunk_region
=
self
.
_search_best_chunk_region
(
possible_chunk_regions
,
chunk_regions
)
best_chunk_region
=
self
.
index_tracer
.
reorder_all
(
best_chunk_region
)
return
best_chunk_region
def
_stop_search
(
self
,
init_mem_peak
,
mem_peak
):
...
...
@@ -1670,8 +1757,7 @@ def emit_code_with_chunk(
chunk_outputs
=
[
i
[
"outputs"
][
0
]
for
i
in
chunk_search
]
chunk_outputs_dim
=
[
i
[
"outputs_dim"
]
for
i
in
chunk_search
]
chunk_prepose_nodes
=
[
i
[
"args"
][
"prepose_nodes"
]
for
i
in
chunk_search
]
node_list
=
chunk_region_search
.
index_tracer
.
reorder_node_list
(
node_list
)
node_idx
=
0
region_idx
=
0
within_chunk_region
=
False
...
...
@@ -1682,12 +1768,6 @@ def emit_code_with_chunk(
if
node_idx
in
chunk_starts
:
within_chunk_region
=
True
region_idx
=
chunk_starts
.
index
(
node_idx
)
# add prepose nodes
for
i
in
chunk_prepose_nodes
[
region_idx
]:
prepose_node
=
node_list
[
_find_idx_by_name
(
i
.
name
,
node_list
)]
emit_node_func
(
prepose_node
,
body
)
delete_unused_value_func
(
prepose_node
,
body
,
chunk_inputs_names
)
# add for loop
body
.
append
(
_gen_loop_start
(
chunk_inputs
[
region_idx
],
...
...
@@ -1697,24 +1777,19 @@ def emit_code_with_chunk(
)
if
within_chunk_region
:
if
any
(
node
.
name
==
i
.
name
for
i
in
chunk_prepose_nodes
[
region_idx
]):
pass
else
:
emit_node_func
(
node
,
body
)
# replace input var with chunk var
for
input_node_idx
,
input_node
in
enumerate
(
chunk_inputs
[
region_idx
]):
for
idx
,
dim
in
chunk_inputs_dim
[
region_idx
][
input_node_idx
].
items
():
if
idx
==
node_idx
:
chunk_slice
=
_gen_chunk_slice_dim
(
dim
,
"chunk_idx"
,
_get_node_shape
(
input_node
)
)
body
[
-
1
]
=
_replace_name
(
body
[
-
1
],
input_node
.
name
,
input_node
.
name
+
chunk_slice
)
body
[
-
1
]
=
" "
+
body
[
-
1
]
delete_unused_value_func
(
node
,
body
,
chunk_inputs_names
)
emit_node_func
(
node
,
body
)
# replace input var with chunk var
for
input_node_idx
,
input_node
in
enumerate
(
chunk_inputs
[
region_idx
]):
for
idx
,
dim
in
chunk_inputs_dim
[
region_idx
][
input_node_idx
].
items
():
if
idx
==
node_idx
:
chunk_slice
=
_gen_chunk_slice_dim
(
dim
,
"chunk_idx"
,
_get_node_shape
(
input_node
)
)
body
[
-
1
]
=
_replace_name
(
body
[
-
1
],
input_node
.
name
,
input_node
.
name
+
chunk_slice
)
body
[
-
1
]
=
" "
+
body
[
-
1
]
delete_unused_value_func
(
node
,
body
,
chunk_inputs_names
)
else
:
emit_node_func
(
node
,
body
)
if
node_idx
not
in
chunk_inputs
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment