Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
51ef8384
Commit
51ef8384
authored
Dec 23, 2022
by
oahzxl
Browse files
finish node reorder
parent
884a228e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
16 deletions
+15
-16
chunk_codegen.py
chunk_codegen.py
+15
-16
No files found.
chunk_codegen.py
View file @
51ef8384
...
...
@@ -1238,7 +1238,7 @@ class MemoryEstimator(object):
def
estimate_chunk_inference_mem
(
self
,
gm
:
torch
.
fx
.
GraphModule
,
node_list
,
chunk_infos
=
None
,
):
act_memory
=
0.0
...
...
@@ -1247,7 +1247,6 @@ class MemoryEstimator(object):
active_node_list
=
[]
active_node_list_log
=
[]
not_contiguous_list
=
[]
node_list
=
list
(
gm
.
graph
.
nodes
)
user_to_last_uses
=
self
.
_get_last_usr
(
node_list
)
user_to_last_uses_no_free_var
=
self
.
_get_last_usr
(
node_list
)
_delete_free_var_from_last_use
(
user_to_last_uses_no_free_var
)
...
...
@@ -1281,7 +1280,6 @@ class MemoryEstimator(object):
)
/
(
1024
**
2
)
# determine chunk ratio for current node
# TODO: adapt to prepose node memory
if
chunk_within
:
chunk_ratio
=
self
.
_get_chunk_ratio
(
node
,
...
...
@@ -1371,10 +1369,7 @@ class MemoryEstimator(object):
class
ChunkRegionSearch
(
object
):
def
__init__
(
self
,
gm
)
->
None
:
self
.
gm
=
gm
self
.
node_list
=
list
(
gm
.
graph
.
nodes
)
self
.
index_tracer
=
IndexTracer
(
self
.
node_list
)
# node list shared in index tracer
self
.
index_tracer
=
IndexTracer
(
list
(
gm
.
graph
.
nodes
))
self
.
index_tracer
.
trace_index
()
self
.
memory_estimator
=
MemoryEstimator
(
self
.
index_tracer
)
...
...
@@ -1385,7 +1380,7 @@ class ChunkRegionSearch(object):
def
_get_free_var
(
self
):
free_var_idx
=
[]
for
idx
,
n
in
enumerate
(
self
.
node_list
):
for
idx
,
n
in
enumerate
(
self
.
index_tracer
.
node_list
):
if
n
.
op
==
"placeholder"
:
free_var_idx
.
append
(
idx
)
return
free_var_idx
...
...
@@ -1455,13 +1450,13 @@ class ChunkRegionSearch(object):
def
_find_free_dim
(
self
,
input_trace
,
output_trace
,
start_idx
,
end_idx
):
start_traces
=
input_trace
[
start_idx
]
end_trace
=
output_trace
[
end_idx
]
end_node
=
self
.
node_list
[
end_idx
]
end_node
=
self
.
index_tracer
.
node_list
[
end_idx
]
chunk_infos
=
[]
for
end_dim
,
end_trace_idx
in
enumerate
(
end_trace
[
"idx"
]):
for
end_dim
,
_
in
enumerate
(
end_trace
[
"idx"
]):
if
len
(
start_traces
)
>
1
:
continue
for
start_node
,
start_trace
in
start_traces
.
items
():
for
start_dim
,
start_trace_idx
in
enumerate
(
start_trace
[
"idx"
]):
for
start_dim
,
_
in
enumerate
(
start_trace
[
"idx"
]):
# dim size cannot be 1
if
(
_get_node_shape
(
end_node
)[
end_dim
]
==
1
...
...
@@ -1494,7 +1489,7 @@ class ChunkRegionSearch(object):
possible_chunk_region
=
[]
output_trace
=
copy
.
deepcopy
(
self
.
index_tracer
.
idx_trace_list
)
input_trace
=
[]
# trace of a node's input nodes
for
_
,
n
in
enumerate
(
self
.
node_list
):
for
_
,
n
in
enumerate
(
self
.
index_tracer
.
node_list
):
cur_trace
=
{}
for
arg
in
n
.
args
:
if
type
(
arg
)
==
type
(
n
)
and
not
_is_non_compute_node_except_placeholder
(
...
...
@@ -1507,8 +1502,8 @@ class ChunkRegionSearch(object):
for
end_idx
in
range
(
peak_node
,
max_chunk_region
[
1
]
+
1
):
# skip non compute nodes
if
_is_non_compute_node
(
self
.
node_list
[
start_idx
]
)
or
_is_non_compute_node
(
self
.
node_list
[
end_idx
]):
self
.
index_tracer
.
node_list
[
start_idx
]
)
or
_is_non_compute_node
(
self
.
index_tracer
.
node_list
[
end_idx
]):
continue
# select free dim
...
...
@@ -1577,7 +1572,9 @@ class ChunkRegionSearch(object):
init_mem_peak
,
_
,
active_node
,
)
=
self
.
memory_estimator
.
estimate_chunk_inference_mem
(
self
.
gm
)
)
=
self
.
memory_estimator
.
estimate_chunk_inference_mem
(
self
.
index_tracer
.
node_list
)
mem_peak
=
init_mem_peak
while
True
:
...
...
@@ -1590,7 +1587,9 @@ class ChunkRegionSearch(object):
mem_peak
,
_
,
active_node
,
)
=
self
.
memory_estimator
.
estimate_chunk_inference_mem
(
self
.
gm
,
chunk_infos
)
)
=
self
.
memory_estimator
.
estimate_chunk_inference_mem
(
self
.
index_tracer
.
node_list
,
chunk_infos
)
if
self
.
_stop_search
(
init_mem_peak
,
mem_peak
):
break
return
chunk_infos
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment