Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
e66a18a0
"...git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "30412866e0c6860de0787cd9c5e9a5bfffdc712c"
Commit
e66a18a0
authored
Dec 16, 2022
by
oahzxl
Browse files
optimise search
parent
e83e3c61
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
47 additions
and
20 deletions
+47
-20
chunk_codegen.py
chunk_codegen.py
+47
-20
No files found.
chunk_codegen.py
View file @
e66a18a0
...
@@ -958,6 +958,8 @@ class MemoryEstimator(object):
...
@@ -958,6 +958,8 @@ class MemoryEstimator(object):
def
_add_active_node
(
self
,
n
,
active_list
):
def
_add_active_node
(
self
,
n
,
active_list
):
new_active
=
self
.
_get_output_node
(
n
)[
1
]
new_active
=
self
.
_get_output_node
(
n
)[
1
]
if
n
.
op
==
'placeholder'
:
new_active
.
append
(
n
.
name
)
for
i
in
new_active
:
for
i
in
new_active
:
if
i
not
in
active_list
:
if
i
not
in
active_list
:
active_list
.
append
(
i
)
active_list
.
append
(
i
)
...
@@ -965,7 +967,7 @@ class MemoryEstimator(object):
...
@@ -965,7 +967,7 @@ class MemoryEstimator(object):
def
_get_delete_node
(
self
,
user
,
user_to_last_uses
,
to_keep
=
None
):
def
_get_delete_node
(
self
,
user
,
user_to_last_uses
,
to_keep
=
None
):
delete_size
=
0
delete_size
=
0
delete_node
=
[]
delete_node
=
[]
if
user
.
op
not
in
(
"placeholder"
,
"output"
):
if
user
.
op
not
in
(
"output"
,
):
nodes_to_delete
=
user_to_last_uses
.
get
(
user
,
[])
nodes_to_delete
=
user_to_last_uses
.
get
(
user
,
[])
if
to_keep
is
not
None
:
if
to_keep
is
not
None
:
keep_list
=
[]
keep_list
=
[]
...
@@ -1258,24 +1260,30 @@ class ChunkRegionSearch(object):
...
@@ -1258,24 +1260,30 @@ class ChunkRegionSearch(object):
def
_search_max_chunk_region
(
self
,
active_node
,
peak_node
,
chunk_regions
):
def
_search_max_chunk_region
(
self
,
active_node
,
peak_node
,
chunk_regions
):
free_vars
=
self
.
_get_free_var
()
free_vars
=
self
.
_get_free_var
()
min_var
=
self
.
_get_min_free_var
(
active_node
,
free_vars
)
free_var_num
=
len
(
free_vars
)
active_node_num
=
[
len
(
i
)
for
i
in
active_node
]
min_active_node_num
=
min
(
active_node_num
[
free_var_num
:])
threshold
=
max
(
free_var_num
,
min_active_node_num
)
# from peak_node to free_var
# from peak_node to free_var
chunk_region_start
=
len
(
free_vars
)
inside_flag
=
False
chunk_region_start
=
free_var_num
for
i
in
range
(
peak_node
,
-
1
,
-
1
):
for
i
in
range
(
peak_node
,
-
1
,
-
1
):
if
len
(
active_node
[
i
])
==
min_var
:
if
active_node_num
[
i
]
<=
threshold
:
inside_flag
=
True
if
inside_flag
and
active_node_num
[
i
]
>
threshold
:
chunk_region_start
=
i
+
1
chunk_region_start
=
i
+
1
break
break
if
i
in
free_vars
or
i
==
0
:
raise
RuntimeError
()
# from peak_node to len-2
# from peak_node to len-2
inside_flag
=
False
chunk_region_end
=
len
(
active_node
)
-
1
chunk_region_end
=
len
(
active_node
)
-
1
for
i
in
range
(
peak_node
,
len
(
active_node
)):
for
i
in
range
(
peak_node
,
len
(
active_node
)):
if
len
(
active_node
[
i
])
==
min_var
:
if
active_node_num
[
i
]
<=
threshold
:
inside_flag
=
True
if
inside_flag
and
active_node_num
[
i
]
>
threshold
:
chunk_region_end
=
i
chunk_region_end
=
i
break
break
if
i
in
free_vars
or
i
==
0
:
raise
RuntimeError
()
for
i
in
chunk_regions
:
for
i
in
chunk_regions
:
region
=
i
[
"region"
]
region
=
i
[
"region"
]
...
@@ -1374,15 +1382,34 @@ class ChunkRegionSearch(object):
...
@@ -1374,15 +1382,34 @@ class ChunkRegionSearch(object):
possible_chunk_region
.
extend
(
chunk_info
)
possible_chunk_region
.
extend
(
chunk_info
)
return
possible_chunk_region
return
possible_chunk_region
def
_search_best_chunk_region
(
self
,
possible_chunk_regions
):
def
_search_best_chunk_region
(
self
,
possible_chunk_regions
,
chunk_infos
):
max_region_range
=
0
max_region_range
=
0
best_regions
=
None
best_region
=
None
for
i
in
possible_chunk_regions
:
while
len
(
possible_chunk_regions
)
>
0
:
if
i
[
"region"
][
1
]
-
i
[
"region"
][
0
]
>
max_region_range
:
for
i
in
possible_chunk_regions
:
best_regions
=
i
if
i
[
"region"
][
1
]
-
i
[
"region"
][
0
]
>
max_region_range
:
max_region_range
=
i
[
"region"
][
1
]
-
i
[
"region"
][
0
]
best_region
=
i
return
best_regions
max_region_range
=
i
[
"region"
][
1
]
-
i
[
"region"
][
0
]
if
self
.
_is_legal_region
(
best_region
,
chunk_infos
):
break
possible_chunk_regions
.
remove
(
i
)
max_region_range
=
0
best_region
=
None
return
best_region
def
_is_legal_region
(
self
,
cur_chunk_info
,
chunk_infos
):
(
chunk_region_start
,
chunk_region_end
)
=
cur_chunk_info
[
"region"
]
if
cur_chunk_info
in
chunk_infos
:
return
False
if
chunk_region_end
<
chunk_region_start
:
return
False
for
i
in
chunk_infos
:
region
=
i
[
"region"
]
if
not
((
chunk_region_start
>
region
[
1
]
and
chunk_region_end
>
region
[
1
])
or
(
chunk_region_start
<
region
[
0
]
and
chunk_region_end
<
region
[
0
])):
return
False
return
True
def
_step_search
(
self
,
mem_peak
,
active_node
,
chunk_regions
):
def
_step_search
(
self
,
mem_peak
,
active_node
,
chunk_regions
):
peak_node
=
self
.
_find_peak_node
(
mem_peak
)
peak_node
=
self
.
_find_peak_node
(
mem_peak
)
max_chunk_region
=
self
.
_search_max_chunk_region
(
max_chunk_region
=
self
.
_search_max_chunk_region
(
...
@@ -1393,7 +1420,7 @@ class ChunkRegionSearch(object):
...
@@ -1393,7 +1420,7 @@ class ChunkRegionSearch(object):
possible_chunk_regions
=
self
.
_search_possible_chunk_regions
(
possible_chunk_regions
=
self
.
_search_possible_chunk_regions
(
max_chunk_region
,
peak_node
max_chunk_region
,
peak_node
)
)
best_chunk_region
=
self
.
_search_best_chunk_region
(
possible_chunk_regions
)
best_chunk_region
=
self
.
_search_best_chunk_region
(
possible_
chunk_regions
,
chunk_regions
)
return
best_chunk_region
return
best_chunk_region
def
_stop_search
(
self
,
init_mem_peak
,
mem_peak
):
def
_stop_search
(
self
,
init_mem_peak
,
mem_peak
):
...
@@ -1919,5 +1946,5 @@ if CODEGEN_AVAILABLE:
...
@@ -1919,5 +1946,5 @@ if CODEGEN_AVAILABLE:
{
prologue
}
{
prologue
}
{
code
}
"""
{
code
}
"""
print
(
fn_code
)
#
print(fn_code)
return
PythonCode
(
fn_code
,
globals_
)
return
PythonCode
(
fn_code
,
globals_
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment