Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
c3d72f7d
Commit
c3d72f7d
authored
Jan 06, 2023
by
oahzxl
Browse files
separate reorder
parent
6685a9d0
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
29 additions
and
21 deletions
+29
-21
colossalai/autochunk/autochunk_codegen.py
colossalai/autochunk/autochunk_codegen.py
+2
-2
colossalai/autochunk/chunk_region_search.py
colossalai/autochunk/chunk_region_search.py
+4
-3
colossalai/autochunk/chunk_selector.py
colossalai/autochunk/chunk_selector.py
+5
-3
colossalai/autochunk/index_tracer.py
colossalai/autochunk/index_tracer.py
+18
-13
No files found.
colossalai/autochunk/autochunk_codegen.py
View file @
c3d72f7d
...
...
@@ -103,7 +103,7 @@ def emit_code_with_chunk(
nodes
,
emit_node_func
,
delete_unused_value_func
,
chunk_region_search
,
chunk_region_search
:
ChunkRegionSearch
,
chunk_infos
,
):
"""Emit code with nested activation checkpoint
...
...
@@ -133,7 +133,7 @@ def emit_code_with_chunk(
chunk_outputs
=
[
i
[
"outputs"
][
0
]
for
i
in
chunk_infos
]
chunk_outputs_dim
=
[
i
[
"outputs_dim"
]
for
i
in
chunk_infos
]
node_list
=
chunk_region_search
.
index_tracer
.
reorder_node_list
(
node_list
)
node_list
=
chunk_region_search
.
reorder_graph
.
reorder_node_list
(
node_list
)
node_idx
=
0
region_idx
=
0
within_chunk_region
=
False
...
...
colossalai/autochunk/chunk_region_search.py
View file @
c3d72f7d
import
copy
from
.chunk_selector
import
ChunkSelector
from
.index_tracer
import
IndexTracer
from
.index_tracer
import
IndexTracer
,
ReorderGraph
from
.memory_estiamtor
import
MemoryEstimator
from
.utils
import
(
get_node_shape
,
...
...
@@ -16,9 +16,10 @@ class ChunkRegionSearch(object):
self
.
print_mem
=
print_mem
self
.
index_tracer
=
IndexTracer
(
list
(
gm
.
graph
.
nodes
))
self
.
index_tracer
.
trace_index
()
self
.
reorder_graph
=
ReorderGraph
(
self
.
index_tracer
)
self
.
memory_estimator
=
MemoryEstimator
()
self
.
chunk_selector
=
ChunkSelector
(
self
.
index_tracer
,
self
.
memory_estimator
,
max_memory
=
max_memory
self
.
index_tracer
,
self
.
memory_estimator
,
self
.
reorder_graph
,
max_memory
=
max_memory
)
def
_find_peak_node
(
self
,
mem_peak
):
...
...
@@ -175,7 +176,7 @@ class ChunkRegionSearch(object):
best_chunk_region
=
self
.
chunk_selector
.
_select_best_chunk_region
(
possible_chunk_regions
,
chunk_regions
,
peak_node
,
max_chunk_region
,
mem_peak
)
best_chunk_region
=
self
.
index_tracer
.
reorder_all
(
best_chunk_region
)
best_chunk_region
=
self
.
reorder_graph
.
reorder_all
(
best_chunk_region
)
return
best_chunk_region
def
_stop_search
(
self
,
init_mem_peak
,
mem_peak
):
...
...
colossalai/autochunk/chunk_selector.py
View file @
c3d72f7d
from
.index_tracer
import
IndexTracer
from
.index_tracer
import
IndexTracer
,
ReorderGraph
from
.memory_estiamtor
import
MemoryEstimator
from
.utils
import
is_non_compute_node
...
...
@@ -8,10 +8,12 @@ class ChunkSelector(object):
self
,
index_tracer
:
IndexTracer
,
memory_estimator
:
MemoryEstimator
,
reorder_graph
:
ReorderGraph
,
max_memory
=
None
,
):
self
.
index_tracer
=
index_tracer
self
.
memory_estimator
=
memory_estimator
self
.
reorder_graph
=
reorder_graph
if
max_memory
is
not
None
:
self
.
stratge
=
"fit_memory"
self
.
max_memory
=
max_memory
# MB
...
...
@@ -64,7 +66,7 @@ class ChunkSelector(object):
regions_dict
=
[]
for
region
in
possible_chunk_regions
:
cur_region
=
region
.
copy
()
cur_node_list
,
cur_region
=
self
.
index_tracer
.
tmp_reorder
(
cur_node_list
,
cur_region
=
self
.
reorder_graph
.
tmp_reorder
(
self
.
index_tracer
.
node_list
,
cur_region
)
cur_chunk_infos
=
chunk_infos
+
[
cur_region
]
...
...
@@ -174,7 +176,7 @@ class ChunkSelector(object):
regions_dict
=
[]
for
region
in
possible_chunk_regions
:
cur_region
=
region
.
copy
()
cur_node_list
,
cur_region
=
self
.
index_tracer
.
tmp_reorder
(
cur_node_list
,
cur_region
=
self
.
reorder_graph
.
tmp_reorder
(
self
.
index_tracer
.
node_list
,
cur_region
)
cur_chunk_infos
=
chunk_infos
+
[
cur_region
]
...
...
colossalai/autochunk/index_tracer.py
View file @
c3d72f7d
...
...
@@ -17,7 +17,6 @@ class IndexTracer(object):
self
.
idx_trace_equal
=
[]
self
.
idx_view_list
=
{}
self
.
idx_count
=
-
1
self
.
all_reorder_map
=
{
i
:
i
for
i
in
range
(
len
(
self
.
idx_trace_list
))}
def
_init_idx_trace_list
(
self
):
idx_trace_list
=
[]
...
...
@@ -981,24 +980,30 @@ class IndexTracer(object):
chunk_info
[
"reshape_size"
]
=
reshape_size
return
chunk_info
class
ReorderGraph
(
object
):
def
__init__
(
self
,
index_tracer
:
IndexTracer
)
->
None
:
self
.
index_tracer
=
index_tracer
self
.
all_reorder_map
=
{
i
:
i
for
i
in
range
(
len
(
self
.
index_tracer
.
idx_trace_list
))}
def
_get_reorder_map
(
self
,
chunk_info
):
reorder_map
=
{
i
:
i
for
i
in
range
(
len
(
self
.
node_list
))}
reorder_map
=
{
i
:
i
for
i
in
range
(
len
(
self
.
index_tracer
.
node_list
))}
chunk_region_start
=
chunk_info
[
"region"
][
0
]
chunk_region_end
=
chunk_info
[
"region"
][
1
]
chunk_prepose_nodes
=
chunk_info
[
"args"
][
"prepose_nodes"
]
chunk_prepose_nodes_idx
=
[
find_idx_by_name
(
i
.
name
,
self
.
node_list
)
for
i
in
chunk_prepose_nodes
find_idx_by_name
(
i
.
name
,
self
.
index_tracer
.
node_list
)
for
i
in
chunk_prepose_nodes
]
# put prepose nodes ahead
for
idx
,
n
in
enumerate
(
chunk_prepose_nodes
):
n_idx
=
chunk_prepose_nodes_idx
[
idx
]
reorder_map
[
n_idx
]
=
chunk_region_start
+
idx
# put other nodes after prepose nodes
for
n
in
self
.
node_list
[
chunk_region_start
:
chunk_region_end
+
1
]:
for
n
in
self
.
index_tracer
.
node_list
[
chunk_region_start
:
chunk_region_end
+
1
]:
if
n
in
chunk_prepose_nodes
:
continue
n_idx
=
find_idx_by_name
(
n
.
name
,
self
.
node_list
)
n_idx
=
find_idx_by_name
(
n
.
name
,
self
.
index_tracer
.
node_list
)
pos
=
sum
([
n_idx
<
i
for
i
in
chunk_prepose_nodes_idx
])
reorder_map
[
n_idx
]
=
n_idx
+
pos
...
...
@@ -1024,25 +1029,25 @@ class IndexTracer(object):
self
.
all_reorder_map
[
origin_idx
]
=
reorder_map
[
map_idx
]
def
_reorder_self_node_list
(
self
,
reorder_map
):
new_node_list
=
[
None
for
_
in
range
(
len
(
self
.
node_list
))]
new_node_list
=
[
None
for
_
in
range
(
len
(
self
.
index_tracer
.
node_list
))]
for
old_idx
,
new_idx
in
reorder_map
.
items
():
new_node_list
[
new_idx
]
=
self
.
node_list
[
old_idx
]
self
.
node_list
=
new_node_list
new_node_list
[
new_idx
]
=
self
.
index_tracer
.
node_list
[
old_idx
]
self
.
index_tracer
.
node_list
=
new_node_list
def
_reorder_idx_trace
(
self
,
reorder_map
):
# reorder list
new_idx_trace_list
=
[
None
for
_
in
range
(
len
(
self
.
idx_trace_list
))]
new_idx_trace_list
=
[
None
for
_
in
range
(
len
(
self
.
index_tracer
.
idx_trace_list
))]
for
old_idx
,
new_idx
in
reorder_map
.
items
():
new_idx_trace_list
[
new_idx
]
=
self
.
idx_trace_list
[
old_idx
]
self
.
idx_trace_list
=
new_idx_trace_list
new_idx_trace_list
[
new_idx
]
=
self
.
index_tracer
.
idx_trace_list
[
old_idx
]
self
.
index_tracer
.
idx_trace_list
=
new_idx_trace_list
# update compute
for
idx_trace
in
self
.
idx_trace_list
:
for
idx_trace
in
self
.
index_tracer
.
idx_trace_list
:
compute
=
idx_trace
[
"compute"
]
for
dim_compute
in
compute
:
for
idx
,
i
in
enumerate
(
dim_compute
):
dim_compute
[
idx
]
=
reorder_map
[
i
]
# update source
for
idx_trace
in
self
.
idx_trace_list
:
for
idx_trace
in
self
.
index_tracer
.
idx_trace_list
:
source
=
idx_trace
[
"source"
]
for
dim_idx
,
dim_source
in
enumerate
(
source
):
new_dim_source
=
{}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment