Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
065f0b4c
Commit
065f0b4c
authored
Jan 09, 2023
by
oahzxl
Browse files
add doc for search
parent
a68d240e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
68 additions
and
8 deletions
+68
-8
colossalai/autochunk/search_chunk.py
colossalai/autochunk/search_chunk.py
+68
-8
No files found.
colossalai/autochunk/search_chunk.py
View file @
065f0b4c
import
copy
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Tuple
from
typing
import
Dict
,
List
,
Tuple
from
torch.fx.node
import
Node
...
...
@@ -136,7 +136,24 @@ class SearchChunk(object):
chunk_region_end
=
region
[
0
]
-
1
return
chunk_region_start
,
chunk_region_end
def
_find_free_dim
(
self
,
input_trace
,
output_trace
,
start_idx
,
end_idx
):
def
_find_chunk_info
(
self
,
input_trace
,
output_trace
,
start_idx
,
end_idx
)
->
List
:
"""
Find chunk info for a region.
We are given the region start and region end, and need to find all chunk infos for it.
We first loop over every dim of the start node and end node to see if we can find a dim pair
that is linked in a flow and not computed.
If found, we then search the flow in the whole region to find all chunk infos.
Args:
input_trace (List): node's input trace in region
output_trace (List): node's output trace in region
start_idx (int): region start node index
end_idx (int): region end node index
Returns:
chunk_infos (List): possible chunk infos found for the region
"""
start_traces
=
input_trace
[
start_idx
]
end_trace
=
output_trace
[
end_idx
]
end_node
=
self
.
trace_index
.
node_list
[
end_idx
]
...
...
@@ -174,7 +191,19 @@ class SearchChunk(object):
chunk_infos
.
append
(
chunk_info
)
return
chunk_infos
def
_search_possible_chunk_regions
(
self
,
max_chunk_region
,
peak_node
):
def
_search_possible_chunk_regions
(
self
,
max_chunk_region
:
Tuple
,
peak_node
:
Node
)
->
List
:
"""
Search every possible region within the max chunk region.
Args:
max_chunk_region (Tuple)
peak_node (Node): peak memory node
Returns:
possible_chunk_region (List)
"""
possible_chunk_region
=
[]
output_trace
=
copy
.
deepcopy
(
self
.
trace_index
.
idx_trace_list
)
input_trace
=
[]
# trace of a node's input nodes
...
...
@@ -196,17 +225,39 @@ class SearchChunk(object):
continue
# select free dim
chunk_info
=
self
.
_find_
free_dim
(
chunk_info
=
self
.
_find_
chunk_info
(
input_trace
,
output_trace
,
start_idx
,
end_idx
)
if
len
(
chunk_info
)
>
0
:
possible_chunk_region
.
extend
(
chunk_info
)
return
possible_chunk_region
def
_step_search
(
self
,
mem_peak
,
active_node
,
chunk_regions
):
def
_step_search
(
self
,
mem_peak
:
List
[
float
],
active_node
:
List
[
List
[
Node
]],
chunk_infos
:
List
[
Dict
],
)
->
Dict
:
"""
Find one chunk region.
The chunk search proceeds as follows:
1. find the peak memory node
2. find the max chunk region according to the peak memory node
3. find all possible chunk regions in the max chunk region
4. find the best chunk region for current status
Args:
mem_peak (List): peak memory for every node
active_node (List[List[Node]]): active nodes for every node
chunk_infos (List[Dict]): all chunk infos
Returns:
best_chunk_region (Dict)
"""
peak_node
=
self
.
_find_peak_node
(
mem_peak
)
max_chunk_region
=
self
.
_search_max_chunk_region
(
active_node
,
peak_node
,
chunk_
region
s
active_node
,
peak_node
,
chunk_
info
s
)
if
max_chunk_region
==
None
:
return
None
...
...
@@ -214,7 +265,7 @@ class SearchChunk(object):
max_chunk_region
,
peak_node
)
best_chunk_region
=
self
.
select_chunk
.
_select_best_chunk_region
(
possible_chunk_regions
,
chunk_
region
s
,
peak_node
,
max_chunk_region
,
mem_peak
possible_chunk_regions
,
chunk_
info
s
,
peak_node
,
max_chunk_region
,
mem_peak
)
best_chunk_region
=
self
.
reorder_graph
.
reorder_all
(
best_chunk_region
)
return
best_chunk_region
...
...
@@ -225,7 +276,16 @@ class SearchChunk(object):
return
True
return
False
def
search_region
(
self
):
def
search_region
(
self
)
->
Dict
:
"""
Search all chunk regions:
1. Estimate current memory
2. Find best chunk for current memory
3. Go back to step 1
Returns:
chunk_infos (Dict)
"""
chunk_infos
=
[]
(
init_mem_peak
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment