Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
a68d240e
Commit
a68d240e
authored
Jan 09, 2023
by
oahzxl
Browse files
add doc for search chunk
parent
1951f7fa
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
55 additions
and
21 deletions
+55
-21
colossalai/autochunk/search_chunk.py
colossalai/autochunk/search_chunk.py
+55
-21
No files found.
colossalai/autochunk/search_chunk.py
View file @
a68d240e
import
copy
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Tuple
from
torch.fx.node
import
Node
from
.estimate_memory
import
EstimateMemory
from
.reorder_graph
import
ReorderGraph
...
...
@@ -13,6 +16,34 @@ from .utils import (
class
SearchChunk
(
object
):
"""
This is the core class for AutoChunk.
It defines the framework of the strategy of AutoChunk.
Chunks will be selected one by one until search stops.
The chunk search is as follows:
1. find the peak memory node
2. find the max chunk region according to the peak memory node
3. find all possible chunk regions in the max chunk region
4. find the best chunk region for current status
5. goto 1
Attributes:
gm: graph model
print_mem (bool): print estimated memory
trace_index: trace the flow of every dim of every node to find all free dims
trace_flow: determine the region chunk strategy
reorder_graph: reorder nodes to improve chunk efficiency
estimate_memory: estimate memory with chunk
select_chunk: select the best chunk region
Args:
gm: graph model
max_memory (int): max memory in MB
print_mem (bool): print estimated memory
"""
def
__init__
(
self
,
gm
,
max_memory
=
None
,
print_mem
=
False
)
->
None
:
self
.
gm
=
gm
self
.
print_mem
=
print_mem
...
...
@@ -33,24 +64,37 @@ class SearchChunk(object):
max_idx
=
mem_peak
.
index
(
max_value
)
return
max_idx
def
_get_free_var
(
self
):
def
_get_free_var_idx
(
self
)
->
List
:
"""
Get free var index
Returns:
free_var_idx (List): all indexs of free vars
"""
free_var_idx
=
[]
for
idx
,
n
in
enumerate
(
self
.
trace_index
.
node_list
):
if
n
.
op
==
"placeholder"
:
free_var_idx
.
append
(
idx
)
return
free_var_idx
def
_get_min_free_var
(
self
,
active_node_list
,
free_vars
):
min_len
=
999
for
idx
,
n
in
enumerate
(
active_node_list
):
if
idx
in
free_vars
:
continue
if
len
(
n
)
<
min_len
:
min_len
=
len
(
n
)
return
min_len
def
_search_max_chunk_region
(
self
,
active_node
:
List
,
peak_node
:
Node
,
chunk_regions
:
List
)
->
Tuple
:
"""
Search max chunk region according to peak memory node
Chunk region starts extending from the peak node, stops where free var num is min
def
_search_max_chunk_region
(
self
,
active_node
,
peak_node
,
chunk_regions
):
free_vars
=
self
.
_get_free_var
()
Args:
active_node (List): active node status for every node
peak_node (Node): peak memory node
chunk_regions (List): chunk region info
Returns:
chunk_region_start (int)
chunk_region_end (int)
"""
free_vars
=
self
.
_get_free_var_idx
()
free_var_num
=
len
(
free_vars
)
active_node_num
=
[
len
(
i
)
for
i
in
active_node
]
min_active_node_num
=
min
(
active_node_num
[
free_var_num
:])
...
...
@@ -92,16 +136,6 @@ class SearchChunk(object):
chunk_region_end
=
region
[
0
]
-
1
return
chunk_region_start
,
chunk_region_end
def
_is_not_compute
(
self
,
trace
,
chunk_range
,
dim_idx
):
if
trace
[
"idx"
][
dim_idx
]
not
in
trace
[
"compute"
]:
return
True
if
trace
[
"idx"
][
dim_idx
]
in
trace
[
"compute"
]
and
all
(
i
<
chunk_range
[
0
]
or
i
>
chunk_range
[
1
]
for
i
in
trace
[
"compute"
][
trace
[
"idx"
][
dim_idx
]]
):
return
True
return
False
def
_find_free_dim
(
self
,
input_trace
,
output_trace
,
start_idx
,
end_idx
):
start_traces
=
input_trace
[
start_idx
]
end_trace
=
output_trace
[
end_idx
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment