Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
ded10056
Commit
ded10056
authored
Dec 21, 2022
by
oahzxl
Browse files
format code
parent
d361d533
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
122 additions
and
62 deletions
+122
-62
chunk_codegen.py
chunk_codegen.py
+122
-62
No files found.
chunk_codegen.py
View file @
ded10056
...
@@ -144,7 +144,9 @@ class IndexTracer(object):
...
@@ -144,7 +144,9 @@ class IndexTracer(object):
node_to_trace
[
"source"
][
node_to_dim
][
node_from_idx
]
=
[
node_from_dim
]
node_to_trace
[
"source"
][
node_to_dim
][
node_from_idx
]
=
[
node_from_dim
]
else
:
else
:
if
node_from_dim
not
in
node_to_trace
[
"source"
][
node_to_dim
][
node_from_idx
]:
if
node_from_dim
not
in
node_to_trace
[
"source"
][
node_to_dim
][
node_from_idx
]:
node_to_trace
[
"source"
][
node_to_dim
][
node_from_idx
].
append
(
node_from_dim
)
node_to_trace
[
"source"
][
node_to_dim
][
node_from_idx
].
append
(
node_from_dim
)
# update inputs source
# update inputs source
node_to_trace
[
"source"
][
node_to_dim
].
update
(
node_to_trace
[
"source"
][
node_to_dim
].
update
(
node_from_trace
[
"source"
][
node_from_dim
]
node_from_trace
[
"source"
][
node_from_dim
]
...
@@ -745,7 +747,6 @@ class IndexTracer(object):
...
@@ -745,7 +747,6 @@ class IndexTracer(object):
return
True
return
True
class
FlowTracer
(
object
):
class
FlowTracer
(
object
):
def
__init__
(
self
,
gm
)
->
None
:
def
__init__
(
self
,
gm
)
->
None
:
self
.
gm
=
gm
self
.
gm
=
gm
...
@@ -856,7 +857,9 @@ class FlowTracer(object):
...
@@ -856,7 +857,9 @@ class FlowTracer(object):
)
)
return
self
.
flow_trace
return
self
.
flow_trace
def
_detect_flow
(
self
,
start_idx
,
start_dim
,
end_idx
,
end_dim
,
index_tracer
:
IndexTracer
):
def
_detect_flow
(
self
,
start_idx
,
start_dim
,
end_idx
,
end_dim
,
index_tracer
:
IndexTracer
):
inputs
,
outputs
=
_find_chunk_compute_input_and_output_nodes
(
inputs
,
outputs
=
_find_chunk_compute_input_and_output_nodes
(
self
.
node_list
[
start_idx
:
end_idx
+
1
]
self
.
node_list
[
start_idx
:
end_idx
+
1
]
)
)
...
@@ -945,8 +948,10 @@ class FlowTracer(object):
...
@@ -945,8 +948,10 @@ class FlowTracer(object):
for
i
in
remove_inputs
:
for
i
in
remove_inputs
:
if
i
in
chunk_info
[
"inputs"
]:
if
i
in
chunk_info
[
"inputs"
]:
chunk_info
[
"inputs"
].
remove
(
i
)
chunk_info
[
"inputs"
].
remove
(
i
)
duplicate_result
,
duplicate_dim
=
index_tracer
.
check_index_duplicate
(
chunk_info
,
return_dim
=
True
)
duplicate_result
,
duplicate_dim
=
index_tracer
.
check_index_duplicate
(
chunk_info
,
return_dim
=
True
)
# we need to log input nodes to avoid deleteing them in the loop
# we need to log input nodes to avoid deleteing them in the loop
non_chunk_inputs
=
_find_chunk_all_input_nodes
(
non_chunk_inputs
=
_find_chunk_all_input_nodes
(
...
@@ -958,15 +963,25 @@ class FlowTracer(object):
...
@@ -958,15 +963,25 @@ class FlowTracer(object):
return
flow_block
,
chunk_info
return
flow_block
,
chunk_info
def
_assgin_single_node_flow
(
self
,
arg_node
,
start_idx
,
end_idx
,
def
_assgin_single_node_flow
(
inputs
,
index_tracer
,
cur_node_dim
,
self
,
cur_node_compute
,
cur_node_source
,
cur_node_fix_dim
,
all_node_info
,
arg_node
,
next_node_list
):
start_idx
,
end_idx
,
inputs
,
index_tracer
,
cur_node_dim
,
cur_node_compute
,
cur_node_source
,
cur_node_fix_dim
,
all_node_info
,
next_node_list
,
):
arg_idx
=
_find_idx_by_name
(
arg_node
.
name
,
index_tracer
.
nodes_list
)
arg_idx
=
_find_idx_by_name
(
arg_node
.
name
,
index_tracer
.
nodes_list
)
# arg in chunk range or be inputs
# arg in chunk range or be inputs
if
not
(
start_idx
<=
arg_idx
<
end_idx
):
if
not
(
start_idx
<=
arg_idx
<
end_idx
):
return
True
return
True
# find arg dim
# find arg dim
if
cur_node_dim
is
not
None
:
if
cur_node_dim
is
not
None
:
# dim is computed
# dim is computed
...
@@ -978,7 +993,7 @@ class FlowTracer(object):
...
@@ -978,7 +993,7 @@ class FlowTracer(object):
arg_dim
=
cur_node_source
[
cur_node_dim
][
arg_idx
][
0
]
arg_dim
=
cur_node_source
[
cur_node_dim
][
arg_idx
][
0
]
else
:
else
:
arg_dim
=
None
arg_dim
=
None
# get fix dim
# get fix dim
arg_fix_dim
=
[]
arg_fix_dim
=
[]
if
cur_node_dim
is
not
None
:
if
cur_node_dim
is
not
None
:
...
@@ -986,44 +1001,52 @@ class FlowTracer(object):
...
@@ -986,44 +1001,52 @@ class FlowTracer(object):
fix_dim_source
=
cur_node_source
[
i
]
fix_dim_source
=
cur_node_source
[
i
]
if
arg_idx
in
fix_dim_source
:
if
arg_idx
in
fix_dim_source
:
arg_fix_dim
.
append
(
fix_dim_source
[
arg_idx
][
0
])
arg_fix_dim
.
append
(
fix_dim_source
[
arg_idx
][
0
])
# if already in node_info, arg dim must be same
# if already in node_info, arg dim must be same
if
arg_node
in
all_node_info
:
if
arg_node
in
all_node_info
:
if
all_node_info
[
arg_node
]
!=
arg_dim
:
if
all_node_info
[
arg_node
]
!=
arg_dim
:
return
False
return
False
all_node_info
[
arg_node
][
'fix_dim'
]
=
list
(
set
(
all_node_info
[
arg_node
][
'fix_dim'
]
+
arg_fix_dim
))
all_node_info
[
arg_node
][
"fix_dim"
]
=
list
(
set
(
all_node_info
[
arg_node
][
"fix_dim"
]
+
arg_fix_dim
)
)
# else add it to list
# else add it to list
else
:
else
:
all_node_info
[
arg_node
]
=
{
'
chunk_dim
'
:
arg_dim
,
'
fix_dim
'
:
arg_fix_dim
}
all_node_info
[
arg_node
]
=
{
"
chunk_dim
"
:
arg_dim
,
"
fix_dim
"
:
arg_fix_dim
}
next_node_list
.
append
(
arg_node
)
next_node_list
.
append
(
arg_node
)
return
True
return
True
def
flow_search
(
self
,
start_idx
,
start_dim
,
end_idx
,
end_dim
,
index_tracer
:
IndexTracer
):
def
flow_search
(
self
,
start_idx
,
start_dim
,
end_idx
,
end_dim
,
index_tracer
:
IndexTracer
):
inputs
,
outputs
=
_find_chunk_compute_input_and_output_nodes
(
inputs
,
outputs
=
_find_chunk_compute_input_and_output_nodes
(
self
.
node_list
[
start_idx
:
end_idx
+
1
]
self
.
node_list
[
start_idx
:
end_idx
+
1
]
)
)
# only single ouput
# only single ouput
if
len
(
outputs
)
>
1
:
if
len
(
outputs
)
>
1
:
return
None
return
None
cur_node_list
=
[
index_tracer
.
nodes_list
[
end_idx
]]
# start from the last node
cur_node_list
=
[
index_tracer
.
nodes_list
[
end_idx
]]
# start from the last node
all_node_info
=
{
cur_node_list
[
0
]:
{
'
chunk_dim
'
:
end_dim
,
'
fix_dim
'
:
[]}}
all_node_info
=
{
cur_node_list
[
0
]:
{
"
chunk_dim
"
:
end_dim
,
"
fix_dim
"
:
[]}}
while
len
(
cur_node_list
)
>
0
:
while
len
(
cur_node_list
)
>
0
:
next_node_list
=
[]
next_node_list
=
[]
for
cur_node
in
cur_node_list
:
for
cur_node
in
cur_node_list
:
# get cur node info
# get cur node info
cur_node_chunk_dim
=
all_node_info
[
cur_node
][
'
chunk_dim
'
]
cur_node_chunk_dim
=
all_node_info
[
cur_node
][
"
chunk_dim
"
]
cur_node_fix_dim
=
all_node_info
[
cur_node
][
'
fix_dim
'
]
cur_node_fix_dim
=
all_node_info
[
cur_node
][
"
fix_dim
"
]
cur_node_idx
=
_find_idx_by_name
(
cur_node
.
name
,
index_tracer
.
nodes_list
)
cur_node_idx
=
_find_idx_by_name
(
cur_node
.
name
,
index_tracer
.
nodes_list
)
if
cur_node_chunk_dim
:
if
cur_node_chunk_dim
:
cur_node_compute
=
index_tracer
.
_find_compute_trace_from_node
(
cur_node
)
cur_node_compute
=
index_tracer
.
_find_compute_trace_from_node
(
cur_node_source
=
index_tracer
.
_find_source_trace_from_node
(
cur_node
)
cur_node
)
cur_node_source
=
index_tracer
.
_find_source_trace_from_node
(
cur_node
)
else
:
else
:
cur_node_compute
=
cur_node_source
=
None
cur_node_compute
=
cur_node_source
=
None
# get all valid args
# get all valid args
arg_list
=
[]
arg_list
=
[]
for
arg
in
cur_node
.
args
:
for
arg
in
cur_node
.
args
:
...
@@ -1032,20 +1055,33 @@ class FlowTracer(object):
...
@@ -1032,20 +1055,33 @@ class FlowTracer(object):
if
_is_non_compute_node
(
arg
):
if
_is_non_compute_node
(
arg
):
continue
continue
arg_list
.
append
(
arg
)
arg_list
.
append
(
arg
)
flow_flag
=
self
.
_assgin_single_node_flow
(
arg
,
start_idx
,
end_idx
,
flow_flag
=
self
.
_assgin_single_node_flow
(
inputs
,
index_tracer
,
cur_node_chunk_dim
,
arg
,
cur_node_compute
,
cur_node_source
,
cur_node_fix_dim
,
all_node_info
,
start_idx
,
next_node_list
)
end_idx
,
inputs
,
index_tracer
,
cur_node_chunk_dim
,
cur_node_compute
,
cur_node_source
,
cur_node_fix_dim
,
all_node_info
,
next_node_list
,
)
if
flow_flag
==
False
:
if
flow_flag
==
False
:
return
None
return
None
if
len
(
arg_list
)
==
2
:
if
len
(
arg_list
)
==
2
:
if
any
(
i
in
cur_node
.
name
for
i
in
[
"add"
,
"mul"
]):
if
any
(
i
in
cur_node
.
name
for
i
in
[
"add"
,
"mul"
]):
for
arg
in
arg_list
:
for
arg
in
arg_list
:
if
not
(
start_idx
<=
_find_idx_by_name
(
arg
.
name
,
index_tracer
.
nodes_list
)
<
end_idx
):
if
not
(
start_idx
<=
_find_idx_by_name
(
arg
.
name
,
index_tracer
.
nodes_list
)
<
end_idx
):
continue
continue
arg_chunk_dim
=
all_node_info
[
arg
][
'
chunk_dim
'
]
arg_chunk_dim
=
all_node_info
[
arg
][
"
chunk_dim
"
]
arg_fix_dim
=
all_node_info
[
arg
][
'
fix_dim
'
]
arg_fix_dim
=
all_node_info
[
arg
][
"
fix_dim
"
]
arg_shape
=
_get_node_shape
(
arg
)
arg_shape
=
_get_node_shape
(
arg
)
# add all dim as fix dim except chunk dim
# add all dim as fix dim except chunk dim
for
i
,
shape
in
enumerate
(
arg_shape
):
for
i
,
shape
in
enumerate
(
arg_shape
):
...
@@ -1061,7 +1097,7 @@ class FlowTracer(object):
...
@@ -1061,7 +1097,7 @@ class FlowTracer(object):
else
:
else
:
raise
NotImplementedError
()
raise
NotImplementedError
()
cur_node_list
=
next_node_list
cur_node_list
=
next_node_list
inputs_dim
=
[]
inputs_dim
=
[]
remove_inputs
=
[]
remove_inputs
=
[]
for
input_node
in
inputs
:
for
input_node
in
inputs
:
...
@@ -1071,7 +1107,7 @@ class FlowTracer(object):
...
@@ -1071,7 +1107,7 @@ class FlowTracer(object):
continue
continue
user_idx
=
_find_idx_by_name
(
user
.
name
,
self
.
node_list
)
user_idx
=
_find_idx_by_name
(
user
.
name
,
self
.
node_list
)
if
start_idx
<=
user_idx
<=
end_idx
:
if
start_idx
<=
user_idx
<=
end_idx
:
chunk_dim
=
all_node_info
[
user
][
'
chunk_dim
'
]
chunk_dim
=
all_node_info
[
user
][
"
chunk_dim
"
]
if
chunk_dim
is
not
None
:
if
chunk_dim
is
not
None
:
input_dict
[
user_idx
]
=
chunk_dim
input_dict
[
user_idx
]
=
chunk_dim
if
len
(
input_dict
)
==
0
:
if
len
(
input_dict
)
==
0
:
...
@@ -1081,7 +1117,7 @@ class FlowTracer(object):
...
@@ -1081,7 +1117,7 @@ class FlowTracer(object):
for
i
in
remove_inputs
:
for
i
in
remove_inputs
:
if
i
in
inputs
:
if
i
in
inputs
:
inputs
.
remove
(
i
)
inputs
.
remove
(
i
)
chunk_info
=
{
chunk_info
=
{
"region"
:
(
start_idx
,
end_idx
),
"region"
:
(
start_idx
,
end_idx
),
"inputs"
:
inputs
,
"inputs"
:
inputs
,
...
@@ -1091,7 +1127,7 @@ class FlowTracer(object):
...
@@ -1091,7 +1127,7 @@ class FlowTracer(object):
"outputs_dim"
:
end_dim
,
"outputs_dim"
:
end_dim
,
"args"
:
{},
"args"
:
{},
}
}
# we need to log input nodes to avoid deleteing them in the loop
# we need to log input nodes to avoid deleteing them in the loop
non_chunk_inputs
=
_find_chunk_all_input_nodes
(
non_chunk_inputs
=
_find_chunk_all_input_nodes
(
self
.
node_list
[
start_idx
:
end_idx
+
1
]
self
.
node_list
[
start_idx
:
end_idx
+
1
]
...
@@ -1129,7 +1165,7 @@ class MemoryEstimator(object):
...
@@ -1129,7 +1165,7 @@ class MemoryEstimator(object):
def
_add_active_node
(
self
,
n
,
active_list
):
def
_add_active_node
(
self
,
n
,
active_list
):
new_active
=
self
.
_get_output_node
(
n
)[
1
]
new_active
=
self
.
_get_output_node
(
n
)[
1
]
if
n
.
op
==
'
placeholder
'
:
if
n
.
op
==
"
placeholder
"
:
new_active
.
append
(
n
.
name
)
new_active
.
append
(
n
.
name
)
for
i
in
new_active
:
for
i
in
new_active
:
if
i
not
in
active_list
:
if
i
not
in
active_list
:
...
@@ -1168,12 +1204,16 @@ class MemoryEstimator(object):
...
@@ -1168,12 +1204,16 @@ class MemoryEstimator(object):
for
i
in
delete_node
:
for
i
in
delete_node
:
if
i
in
active_list
:
if
i
in
active_list
:
active_list
.
remove
(
i
)
active_list
.
remove
(
i
)
def
_get_chunk_inputs_size
(
self
,
chunk_inputs
,
chunk_inputs_non_chunk
,
node_list
,
chunk_end_idx
):
def
_get_chunk_inputs_size
(
self
,
chunk_inputs
,
chunk_inputs_non_chunk
,
node_list
,
chunk_end_idx
):
nodes_to_delete
=
[]
nodes_to_delete
=
[]
for
chunk_input
in
chunk_inputs
+
chunk_inputs_non_chunk
:
for
chunk_input
in
chunk_inputs
+
chunk_inputs_non_chunk
:
chunk_input_users
=
chunk_input
.
users
.
keys
()
chunk_input_users
=
chunk_input
.
users
.
keys
()
chunk_input_users_idx
=
[
_find_idx_by_name
(
i
.
name
,
node_list
)
for
i
in
chunk_input_users
]
chunk_input_users_idx
=
[
_find_idx_by_name
(
i
.
name
,
node_list
)
for
i
in
chunk_input_users
]
if
all
(
i
<=
chunk_end_idx
for
i
in
chunk_input_users_idx
):
if
all
(
i
<=
chunk_end_idx
for
i
in
chunk_input_users_idx
):
if
chunk_input
not
in
nodes_to_delete
:
if
chunk_input
not
in
nodes_to_delete
:
nodes_to_delete
.
append
(
chunk_input
)
nodes_to_delete
.
append
(
chunk_input
)
...
@@ -1226,7 +1266,9 @@ class MemoryEstimator(object):
...
@@ -1226,7 +1266,9 @@ class MemoryEstimator(object):
for
(
input_node
,
input_node_dim
)
in
zip
(
chunk_inputs
,
chunk_inputs_dim
):
for
(
input_node
,
input_node_dim
)
in
zip
(
chunk_inputs
,
chunk_inputs_dim
):
for
k
,
v
in
input_node_dim
.
items
():
for
k
,
v
in
input_node_dim
.
items
():
# TODO: inherit dim should be list too, int now
# TODO: inherit dim should be list too, int now
inherit_dim
=
self
.
index_tracer
.
_find_inherit_dim
(
input_node
,
v
,
self
.
index_tracer
.
nodes_list
[
k
])
inherit_dim
=
self
.
index_tracer
.
_find_inherit_dim
(
input_node
,
v
,
self
.
index_tracer
.
nodes_list
[
k
]
)
if
k
==
_find_idx_by_name
(
node
.
name
,
self
.
index_tracer
.
nodes_list
):
if
k
==
_find_idx_by_name
(
node
.
name
,
self
.
index_tracer
.
nodes_list
):
chunk_ratio
=
float
(
chunk_size
)
/
node_shape
[
inherit_dim
]
chunk_ratio
=
float
(
chunk_size
)
/
node_shape
[
inherit_dim
]
return
chunk_ratio
return
chunk_ratio
...
@@ -1234,7 +1276,7 @@ class MemoryEstimator(object):
...
@@ -1234,7 +1276,7 @@ class MemoryEstimator(object):
if
k
in
source
and
inherit_dim
in
source
[
k
]:
if
k
in
source
and
inherit_dim
in
source
[
k
]:
chunk_ratio
=
float
(
chunk_size
)
/
node_shape
[
dim
]
chunk_ratio
=
float
(
chunk_size
)
/
node_shape
[
dim
]
return
chunk_ratio
return
chunk_ratio
return
1.
return
1.
0
def
_get_chunk_delete_node_size
(
def
_get_chunk_delete_node_size
(
self
,
user
,
user_to_last_uses
,
chunk_ratio
,
chunk_inputs_names
self
,
user
,
user_to_last_uses
,
chunk_ratio
,
chunk_inputs_names
...
@@ -1295,7 +1337,7 @@ class MemoryEstimator(object):
...
@@ -1295,7 +1337,7 @@ class MemoryEstimator(object):
chunk_ratio
=
1
# use it to estimate chunk mem
chunk_ratio
=
1
# use it to estimate chunk mem
chunk_size
=
1
chunk_size
=
1
chunk_inputs_names
=
[]
chunk_inputs_names
=
[]
if
use_chunk
:
if
use_chunk
:
chunk_regions
=
[
i
[
"region"
]
for
i
in
chunk_infos
]
chunk_regions
=
[
i
[
"region"
]
for
i
in
chunk_infos
]
chunk_starts
=
[
i
[
0
]
for
i
in
chunk_regions
]
chunk_starts
=
[
i
[
0
]
for
i
in
chunk_regions
]
...
@@ -1313,12 +1355,17 @@ class MemoryEstimator(object):
...
@@ -1313,12 +1355,17 @@ class MemoryEstimator(object):
if
use_chunk
and
idx
in
chunk_starts
:
if
use_chunk
and
idx
in
chunk_starts
:
chunk_within
=
True
chunk_within
=
True
chunk_region_idx
=
chunk_starts
.
index
(
idx
)
chunk_region_idx
=
chunk_starts
.
index
(
idx
)
act_memory
+=
self
.
_get_output_node_size
(
chunk_outputs
[
chunk_region_idx
])
/
(
1024
**
2
)
act_memory
+=
self
.
_get_output_node_size
(
chunk_outputs
[
chunk_region_idx
]
)
/
(
1024
**
2
)
# determine chunk ratio for current node
# determine chunk ratio for current node
if
chunk_within
:
if
chunk_within
:
chunk_ratio
=
self
.
_get_chunk_ratio
(
chunk_ratio
=
self
.
_get_chunk_ratio
(
node
,
chunk_inputs
[
chunk_region_idx
],
chunk_inputs_dim
[
chunk_region_idx
],
chunk_size
node
,
chunk_inputs
[
chunk_region_idx
],
chunk_inputs_dim
[
chunk_region_idx
],
chunk_size
,
)
)
# if node is placeholder, just add the size of the node
# if node is placeholder, just add the size of the node
...
@@ -1353,18 +1400,18 @@ class MemoryEstimator(object):
...
@@ -1353,18 +1400,18 @@ class MemoryEstimator(object):
/
(
1024
**
2
)
/
(
1024
**
2
)
)
)
# delete unused vars not in chunk_input_list
# delete unused vars not in chunk_input_list
# we can't delete input nodes until chunk ends
# we can't delete input nodes until chunk ends
if
chunk_within
:
if
chunk_within
:
act_memory
-=
self
.
_get_chunk_delete_node_size
(
act_memory
-=
self
.
_get_chunk_delete_node_size
(
node
,
node
,
user_to_last_uses_no_free_var
,
user_to_last_uses_no_free_var
,
chunk_ratio
,
chunk_ratio
,
chunk_inputs_names
chunk_inputs_names
,
)
/
(
1024
**
2
)
)
/
(
1024
**
2
)
else
:
else
:
act_memory
-=
(
self
.
_get_delete_node_size
(
act_memory
-=
self
.
_get_delete_node_size
(
node
,
user_to_last_uses_no_free_var
,
chunk_inputs_names
node
,
user_to_last_uses_no_free_var
,
chunk_inputs_names
)
/
(
1024
**
2
)
)
)
/
(
1024
**
2
)
# log active node, only effective without chunk
# log active node, only effective without chunk
self
.
_add_active_node
(
node
,
active_node_list
)
self
.
_add_active_node
(
node
,
active_node_list
)
...
@@ -1376,11 +1423,11 @@ class MemoryEstimator(object):
...
@@ -1376,11 +1423,11 @@ class MemoryEstimator(object):
self
.
_get_output_node_size
(
node
)
*
chunk_ratio
/
(
1024
**
2
)
self
.
_get_output_node_size
(
node
)
*
chunk_ratio
/
(
1024
**
2
)
)
)
act_memory
-=
self
.
_get_chunk_inputs_size
(
act_memory
-=
self
.
_get_chunk_inputs_size
(
chunk_inputs
[
chunk_region_idx
],
chunk_inputs
[
chunk_region_idx
],
chunk_inputs_non_chunk
[
chunk_region_idx
],
chunk_inputs_non_chunk
[
chunk_region_idx
],
node_list
,
node_list
,
chunk_regions
[
chunk_region_idx
][
1
]
chunk_regions
[
chunk_region_idx
][
1
]
,
)
/
(
1024
**
2
)
)
/
(
1024
**
2
)
chunk_within
=
False
chunk_within
=
False
chunk_ratio
=
1
chunk_ratio
=
1
chunk_region_idx
=
None
chunk_region_idx
=
None
...
@@ -1436,7 +1483,7 @@ class ChunkRegionSearch(object):
...
@@ -1436,7 +1483,7 @@ class ChunkRegionSearch(object):
active_node_num
=
[
len
(
i
)
for
i
in
active_node
]
active_node_num
=
[
len
(
i
)
for
i
in
active_node
]
min_active_node_num
=
min
(
active_node_num
[
free_var_num
:])
min_active_node_num
=
min
(
active_node_num
[
free_var_num
:])
threshold
=
max
(
free_var_num
,
min_active_node_num
)
threshold
=
max
(
free_var_num
,
min_active_node_num
)
# from peak_node to free_var
# from peak_node to free_var
inside_flag
=
False
inside_flag
=
False
chunk_region_start
=
free_var_num
chunk_region_start
=
free_var_num
...
@@ -1494,7 +1541,12 @@ class ChunkRegionSearch(object):
...
@@ -1494,7 +1541,12 @@ class ChunkRegionSearch(object):
continue
continue
for
start_node
,
start_trace
in
start_traces
.
items
():
for
start_node
,
start_trace
in
start_traces
.
items
():
for
start_dim
,
start_trace_idx
in
enumerate
(
start_trace
[
"idx"
]):
for
start_dim
,
start_trace_idx
in
enumerate
(
start_trace
[
"idx"
]):
if
start_idx
==
199
and
end_idx
==
229
and
start_dim
==
2
and
end_dim
==
2
:
if
(
start_idx
==
199
and
end_idx
==
229
and
start_dim
==
2
and
end_dim
==
2
):
print
(
1
)
print
(
1
)
self
.
flow_tracer
.
flow_search
(
self
.
flow_tracer
.
flow_search
(
start_idx
,
start_dim
,
end_idx
,
end_dim
,
self
.
index_tracer
start_idx
,
start_dim
,
end_idx
,
end_dim
,
self
.
index_tracer
...
@@ -1576,7 +1628,7 @@ class ChunkRegionSearch(object):
...
@@ -1576,7 +1628,7 @@ class ChunkRegionSearch(object):
max_region_range
=
0
max_region_range
=
0
best_region
=
None
best_region
=
None
return
best_region
return
best_region
def
_is_legal_region
(
self
,
cur_chunk_info
,
chunk_infos
):
def
_is_legal_region
(
self
,
cur_chunk_info
,
chunk_infos
):
(
chunk_region_start
,
chunk_region_end
)
=
cur_chunk_info
[
"region"
]
(
chunk_region_start
,
chunk_region_end
)
=
cur_chunk_info
[
"region"
]
if
cur_chunk_info
in
chunk_infos
:
if
cur_chunk_info
in
chunk_infos
:
...
@@ -1585,11 +1637,13 @@ class ChunkRegionSearch(object):
...
@@ -1585,11 +1637,13 @@ class ChunkRegionSearch(object):
return
False
return
False
for
i
in
chunk_infos
:
for
i
in
chunk_infos
:
region
=
i
[
"region"
]
region
=
i
[
"region"
]
if
not
((
chunk_region_start
>
region
[
1
]
and
chunk_region_end
>
region
[
1
])
if
not
(
or
(
chunk_region_start
<
region
[
0
]
and
chunk_region_end
<
region
[
0
])):
(
chunk_region_start
>
region
[
1
]
and
chunk_region_end
>
region
[
1
])
or
(
chunk_region_start
<
region
[
0
]
and
chunk_region_end
<
region
[
0
])
):
return
False
return
False
return
True
return
True
def
_step_search
(
self
,
mem_peak
,
active_node
,
chunk_regions
):
def
_step_search
(
self
,
mem_peak
,
active_node
,
chunk_regions
):
peak_node
=
self
.
_find_peak_node
(
mem_peak
)
peak_node
=
self
.
_find_peak_node
(
mem_peak
)
max_chunk_region
=
self
.
_search_max_chunk_region
(
max_chunk_region
=
self
.
_search_max_chunk_region
(
...
@@ -1600,7 +1654,9 @@ class ChunkRegionSearch(object):
...
@@ -1600,7 +1654,9 @@ class ChunkRegionSearch(object):
possible_chunk_regions
=
self
.
_search_possible_chunk_regions
(
possible_chunk_regions
=
self
.
_search_possible_chunk_regions
(
max_chunk_region
,
peak_node
max_chunk_region
,
peak_node
)
)
best_chunk_region
=
self
.
_search_best_chunk_region
(
possible_chunk_regions
,
chunk_regions
)
best_chunk_region
=
self
.
_search_best_chunk_region
(
possible_chunk_regions
,
chunk_regions
)
return
best_chunk_region
return
best_chunk_region
def
_stop_search
(
self
,
init_mem_peak
,
mem_peak
):
def
_stop_search
(
self
,
init_mem_peak
,
mem_peak
):
...
@@ -1667,7 +1723,11 @@ def _gen_loop_end(
...
@@ -1667,7 +1723,11 @@ def _gen_loop_end(
chunk_slice
=
_gen_chunk_slice_dim
(
chunk_slice
=
_gen_chunk_slice_dim
(
chunk_outputs_dim
,
"chunk_idx"
,
chunk_output_shape
chunk_outputs_dim
,
"chunk_idx"
,
chunk_output_shape
)
)
context
=
" chunk_result%s = %s; %s = None
\n
"
%
(
chunk_slice
,
chunk_outputs_name
,
chunk_outputs_name
)
context
=
" chunk_result%s = %s; %s = None
\n
"
%
(
chunk_slice
,
chunk_outputs_name
,
chunk_outputs_name
,
)
context
+=
(
context
+=
(
chunk_outputs_name
+
" = chunk_result; chunk_result = None; chunk_size = None"
chunk_outputs_name
+
" = chunk_result; chunk_result = None; chunk_size = None"
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment