Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
c3a2bf48
Commit
c3a2bf48
authored
Jan 06, 2023
by
oahzxl
Browse files
code style
parent
a6cdbf91
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
46 additions
and
36 deletions
+46
-36
colossalai/autochunk/autochunk_codegen.py
colossalai/autochunk/autochunk_codegen.py
+6
-8
colossalai/autochunk/reorder_graph.py
colossalai/autochunk/reorder_graph.py
+18
-15
colossalai/autochunk/search_chunk.py
colossalai/autochunk/search_chunk.py
+7
-4
colossalai/autochunk/select_chunk.py
colossalai/autochunk/select_chunk.py
+6
-6
colossalai/autochunk/trace_flow.py
colossalai/autochunk/trace_flow.py
+9
-3
No files found.
colossalai/autochunk/autochunk_codegen.py
View file @
c3a2bf48
...
...
@@ -103,7 +103,7 @@ def emit_code_with_chunk(
nodes
,
emit_node_func
,
delete_unused_value_func
,
chunk_region_
search
:
SearchChunk
,
search
_chunk
:
SearchChunk
,
chunk_infos
,
):
"""Emit code with nested activation checkpoint
...
...
@@ -133,7 +133,7 @@ def emit_code_with_chunk(
chunk_outputs
=
[
i
[
"outputs"
][
0
]
for
i
in
chunk_infos
]
chunk_outputs_dim
=
[
i
[
"outputs_dim"
]
for
i
in
chunk_infos
]
node_list
=
chunk_region_
search
.
reorder_graph
.
reorder_node_list
(
node_list
)
node_list
=
search
_chunk
.
reorder_graph
.
reorder_node_list
(
node_list
)
node_idx
=
0
region_idx
=
0
within_chunk_region
=
False
...
...
@@ -167,7 +167,7 @@ def emit_code_with_chunk(
)
# ones like
if
"ones_like"
in
node
.
name
:
meta_node
=
chunk_region_
search
.
trace_index
.
node_list
[
node_idx
]
meta_node
=
search
_chunk
.
trace_index
.
node_list
[
node_idx
]
chunk_dim
=
chunk_infos
[
region_idx
][
"node_chunk_dim"
][
meta_node
][
"chunk_dim"
]
...
...
@@ -220,10 +220,8 @@ if CODEGEN_AVAILABLE:
self
.
max_memory
=
max_memory
self
.
meta_node
=
list
(
meta_graph
.
graph
.
nodes
)
# find the chunk regions
self
.
chunk_region_search
=
SearchChunk
(
meta_graph
,
max_memory
,
print_mem
)
self
.
chunk_infos
=
self
.
chunk_region_search
.
search_region
()
self
.
search_chunk
=
SearchChunk
(
meta_graph
,
max_memory
,
print_mem
)
self
.
chunk_infos
=
self
.
search_chunk
.
search_region
()
def
_gen_python_code
(
self
,
nodes
,
root_module
:
str
,
namespace
:
_Namespace
...
...
@@ -458,7 +456,7 @@ if CODEGEN_AVAILABLE:
nodes
,
emit_node
,
delete_unused_values
,
self
.
chunk_region_
search
,
self
.
search
_chunk
,
self
.
chunk_infos
,
)
...
...
colossalai/autochunk/reorder_graph.py
View file @
c3a2bf48
...
...
@@ -3,28 +3,31 @@ from .utils import find_idx_by_name
class
ReorderGraph
(
object
):
def
__init__
(
self
,
index_tracer
:
TraceIndex
)
->
None
:
self
.
index_tracer
=
index_tracer
self
.
all_reorder_map
=
{
i
:
i
for
i
in
range
(
len
(
self
.
index_tracer
.
idx_trace_list
))}
def
__init__
(
self
,
trace_index
:
TraceIndex
)
->
None
:
self
.
trace_index
=
trace_index
self
.
all_reorder_map
=
{
i
:
i
for
i
in
range
(
len
(
self
.
trace_index
.
idx_trace_list
))
}
def
_get_reorder_map
(
self
,
chunk_info
):
reorder_map
=
{
i
:
i
for
i
in
range
(
len
(
self
.
index_tracer
.
node_list
))}
reorder_map
=
{
i
:
i
for
i
in
range
(
len
(
self
.
trace_index
.
node_list
))}
chunk_region_start
=
chunk_info
[
"region"
][
0
]
chunk_region_end
=
chunk_info
[
"region"
][
1
]
chunk_prepose_nodes
=
chunk_info
[
"args"
][
"prepose_nodes"
]
chunk_prepose_nodes_idx
=
[
find_idx_by_name
(
i
.
name
,
self
.
index_tracer
.
node_list
)
for
i
in
chunk_prepose_nodes
find_idx_by_name
(
i
.
name
,
self
.
trace_index
.
node_list
)
for
i
in
chunk_prepose_nodes
]
# put prepose nodes ahead
for
idx
,
n
in
enumerate
(
chunk_prepose_nodes
):
n_idx
=
chunk_prepose_nodes_idx
[
idx
]
reorder_map
[
n_idx
]
=
chunk_region_start
+
idx
# put other nodes after prepose nodes
for
n
in
self
.
index_tracer
.
node_list
[
chunk_region_start
:
chunk_region_end
+
1
]:
for
n
in
self
.
trace_index
.
node_list
[
chunk_region_start
:
chunk_region_end
+
1
]:
if
n
in
chunk_prepose_nodes
:
continue
n_idx
=
find_idx_by_name
(
n
.
name
,
self
.
index_tracer
.
node_list
)
n_idx
=
find_idx_by_name
(
n
.
name
,
self
.
trace_index
.
node_list
)
pos
=
sum
([
n_idx
<
i
for
i
in
chunk_prepose_nodes_idx
])
reorder_map
[
n_idx
]
=
n_idx
+
pos
...
...
@@ -50,25 +53,25 @@ class ReorderGraph(object):
self
.
all_reorder_map
[
origin_idx
]
=
reorder_map
[
map_idx
]
def
_reorder_self_node_list
(
self
,
reorder_map
):
new_node_list
=
[
None
for
_
in
range
(
len
(
self
.
index_tracer
.
node_list
))]
new_node_list
=
[
None
for
_
in
range
(
len
(
self
.
trace_index
.
node_list
))]
for
old_idx
,
new_idx
in
reorder_map
.
items
():
new_node_list
[
new_idx
]
=
self
.
index_tracer
.
node_list
[
old_idx
]
self
.
index_tracer
.
node_list
=
new_node_list
new_node_list
[
new_idx
]
=
self
.
trace_index
.
node_list
[
old_idx
]
self
.
trace_index
.
node_list
=
new_node_list
def
_reorder_idx_trace
(
self
,
reorder_map
):
# reorder list
new_idx_trace_list
=
[
None
for
_
in
range
(
len
(
self
.
index_tracer
.
idx_trace_list
))]
new_idx_trace_list
=
[
None
for
_
in
range
(
len
(
self
.
trace_index
.
idx_trace_list
))]
for
old_idx
,
new_idx
in
reorder_map
.
items
():
new_idx_trace_list
[
new_idx
]
=
self
.
index_tracer
.
idx_trace_list
[
old_idx
]
self
.
index_tracer
.
idx_trace_list
=
new_idx_trace_list
new_idx_trace_list
[
new_idx
]
=
self
.
trace_index
.
idx_trace_list
[
old_idx
]
self
.
trace_index
.
idx_trace_list
=
new_idx_trace_list
# update compute
for
idx_trace
in
self
.
index_tracer
.
idx_trace_list
:
for
idx_trace
in
self
.
trace_index
.
idx_trace_list
:
compute
=
idx_trace
[
"compute"
]
for
dim_compute
in
compute
:
for
idx
,
i
in
enumerate
(
dim_compute
):
dim_compute
[
idx
]
=
reorder_map
[
i
]
# update source
for
idx_trace
in
self
.
index_tracer
.
idx_trace_list
:
for
idx_trace
in
self
.
trace_index
.
idx_trace_list
:
source
=
idx_trace
[
"source"
]
for
dim_idx
,
dim_source
in
enumerate
(
source
):
new_dim_source
=
{}
...
...
colossalai/autochunk/search_chunk.py
View file @
c3a2bf48
import
copy
from
.select_chunk
import
SelectChunk
from
.trace_index
import
TraceIndex
from
.reorder_graph
import
ReorderGraph
from
.estiamte_memory
import
EstimateMemory
from
.reorder_graph
import
ReorderGraph
from
.select_chunk
import
SelectChunk
from
.trace_flow
import
TraceFlow
from
.trace_index
import
TraceIndex
from
.utils
import
(
get_node_shape
,
is_non_compute_node
,
...
...
@@ -22,7 +22,10 @@ class SearchChunk(object):
self
.
reorder_graph
=
ReorderGraph
(
self
.
trace_index
)
self
.
estimate_memory
=
EstimateMemory
()
self
.
select_chunk
=
SelectChunk
(
self
.
trace_index
,
self
.
estimate_memory
,
self
.
reorder_graph
,
max_memory
=
max_memory
self
.
trace_index
,
self
.
estimate_memory
,
self
.
reorder_graph
,
max_memory
=
max_memory
,
)
def
_find_peak_node
(
self
,
mem_peak
):
...
...
colossalai/autochunk/select_chunk.py
View file @
c3a2bf48
from
.trace_index
import
TraceIndex
from
.reorder_graph
import
ReorderGraph
from
.estiamte_memory
import
EstimateMemory
from
.reorder_graph
import
ReorderGraph
from
.trace_index
import
TraceIndex
from
.utils
import
is_non_compute_node
class
SelectChunk
(
object
):
def
__init__
(
self
,
index_tracer
:
TraceIndex
,
memory_
estimator
:
EstimateMemory
,
trace_index
:
TraceIndex
,
estimat
e_mem
or
y
:
EstimateMemory
,
reorder_graph
:
ReorderGraph
,
max_memory
=
None
,
):
self
.
index_tracer
=
index_tracer
self
.
memory_estimator
=
memory_
estimator
self
.
index_tracer
=
trace_index
self
.
memory_estimator
=
estimat
e_mem
or
y
self
.
reorder_graph
=
reorder_graph
if
max_memory
is
not
None
:
self
.
stratge
=
"fit_memory"
...
...
colossalai/autochunk/trace_flow.py
View file @
c3a2bf48
...
...
@@ -81,7 +81,9 @@ class TraceFlow(object):
input_dim_after_node
=
{}
for
input_node_idx
,
input_node
in
enumerate
(
chunk_infos
[
"inputs"
]):
for
k
,
v
in
chunk_infos
[
"inputs_dim"
][
input_node_idx
].
items
():
inherit_dim
=
self
.
_find_inherit_dim
(
input_node
,
v
,
self
.
trace_index
.
node_list
[
k
])
inherit_dim
=
self
.
_find_inherit_dim
(
input_node
,
v
,
self
.
trace_index
.
node_list
[
k
]
)
if
inherit_dim
:
input_dim_after_node
[
k
]
=
inherit_dim
...
...
@@ -217,7 +219,9 @@ class TraceFlow(object):
for
arg
in
arg_list
:
if
not
(
start_idx
<=
find_idx_by_name
(
arg
.
name
,
self
.
trace_index
.
node_list
)
<=
find_idx_by_name
(
arg
.
name
,
self
.
trace_index
.
node_list
)
<
end_idx
):
continue
...
...
@@ -255,7 +259,9 @@ class TraceFlow(object):
if
start_idx
<=
user_idx
<=
end_idx
:
chunk_dim
=
all_node_info
[
user
][
"chunk_dim"
]
if
chunk_dim
is
not
None
:
user_source
=
self
.
trace_index
.
_find_source_trace_from_node
(
user
)[
chunk_dim
]
user_source
=
self
.
trace_index
.
_find_source_trace_from_node
(
user
)[
chunk_dim
]
if
input_node_idx
in
user_source
:
input_dict
[
user_idx
]
=
user_source
[
input_node_idx
]
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment