Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
7330d907
Commit
7330d907
authored
Dec 04, 2022
by
oahzxl
Browse files
add possible region search
parent
d9ca2f89
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
109 additions
and
7 deletions
+109
-7
chunk_codegen.py
chunk_codegen.py
+109
-7
No files found.
chunk_codegen.py
View file @
7330d907
...
...
@@ -356,7 +356,17 @@ class NodeIndexTracer(object):
"idx_to"
:
[
new_trace
[
i
]
for
i
in
dim_to
],
"dim_to"
:
dim_to
}
self
.
idx_view_list
.
append
(
view_dict
)
def
_merge_equal_idx
(
self
):
idx_equal
=
copy
.
deepcopy
(
self
.
idx_trace_equal
)
idx_equal
.
reverse
()
for
idx
in
idx_equal
:
merge_to
=
min
(
idx
)
merge_from
=
max
(
idx
)
for
trace
in
self
.
idx_trace_list
:
if
merge_from
in
trace
[
'idx'
]:
trace
[
'idx'
]
=
[
merge_to
if
i
==
merge_from
else
i
for
i
in
trace
[
'idx'
]]
def
trace_node_idx
(
self
):
for
idx
,
node
in
enumerate
(
self
.
nodes_list
):
if
node
.
op
==
'placeholder'
:
...
...
@@ -396,6 +406,7 @@ class NodeIndexTracer(object):
continue
else
:
raise
NotImplementedError
(
node
.
op
,
"op not implemented yet!"
)
self
.
_merge_equal_idx
()
class
MemoryEstimator
(
object
):
...
...
@@ -433,6 +444,8 @@ class MemoryEstimator(object):
for
i
in
range
(
len
(
out_node
)):
if
out_node
[
i
][
0
]
>
0
:
delete_node
.
append
(
out_node
[
i
][
1
][
0
])
elif
nodes_to_delete
[
i
].
op
==
'placeholder'
:
delete_node
.
append
(
nodes_to_delete
[
i
].
name
)
return
delete_size
,
delete_node
def
_get_delete_node_size
(
self
,
user
,
user_to_last_uses
):
...
...
@@ -516,8 +529,9 @@ class MemoryEstimator(object):
active_node_list_log
=
[]
not_contiguous_list
=
[]
node_list
=
list
(
gm
.
graph
.
nodes
)
user_to_last_uses
=
self
.
_get_last_usr
(
list
(
gm
.
graph
.
nodes
))
_delete_free_var_from_last_use
(
user_to_last_uses
)
user_to_last_uses
=
self
.
_get_last_usr
(
node_list
)
user_to_last_uses_no_free_var
=
self
.
_get_last_usr
(
node_list
)
_delete_free_var_from_last_use
(
user_to_last_uses_no_free_var
)
use_chunk
=
all
(
i
is
not
None
for
i
in
[
start_nodes
,
end_nodes
,
chunk_dims
,
chunk_sizes
])
chunk_within
=
False
...
...
@@ -535,6 +549,7 @@ class MemoryEstimator(object):
if
node
.
op
==
'placeholder'
:
act_memory
+=
self
.
_get_meta_node_size
(
node
)
*
chunk_ratio
/
(
1024
**
2
)
act_memory_peak_log
.
append
(
act_memory
)
active_node_list
.
append
(
node
.
name
)
# skip output
elif
node
.
op
==
'output'
:
continue
...
...
@@ -549,10 +564,10 @@ class MemoryEstimator(object):
act_memory
-=
self
.
_get_contiguous_memory
(
node
,
not_contiguous_list
,
delete
=
True
)
*
chunk_ratio
/
(
1024
**
2
)
if
chunk_within
:
act_memory
-=
self
.
_get_chunk_delete_node_size
(
node
,
user_to_last_uses
,
chunk_ratio
,
node_list
,
node
,
user_to_last_uses
_no_free_var
,
chunk_ratio
,
node_list
,
start_nodes
[
chunk_region_idx
],
end_nodes
[
chunk_region_idx
])
/
(
1024
**
2
)
else
:
act_memory
-=
self
.
_get_delete_node_size
(
node
,
user_to_last_uses
)
/
(
1024
**
2
)
act_memory
-=
self
.
_get_delete_node_size
(
node
,
user_to_last_uses
_no_free_var
)
/
(
1024
**
2
)
# log active node
self
.
_add_active_node
(
node
,
active_node_list
)
...
...
@@ -572,8 +587,92 @@ class MemoryEstimator(object):
self
.
_print_mem_log
(
act_memory_peak_log
,
node_list
,
"peak"
)
self
.
_print_mem_log
(
act_memory_after_node_log
,
node_list
,
"after"
)
param_memory
=
parameter_size
(
gm
)
return
act_memory
+
param_memory
,
param_memory
# param_memory = parameter_size(gm)
# all_memory = act_memory + param_memory
return
act_memory_peak_log
,
act_memory_after_node_log
,
active_node_list_log
class ChunkRegionSearch(object):
    """Search a traced graph for regions around the memory peak that may be chunked.

    On construction the graph's nodes are listed, a ``MemoryEstimator`` is
    created and a ``NodeIndexTracer`` is run over the graph. ``search_region``
    then combines the memory estimate with the traced indices to enumerate
    candidate ``(start, end)`` node-index pairs together with the dimensions
    that appear free to chunk along.
    """

    def __init__(self, gm) -> None:
        self.gm = gm
        self.node_list = list(gm.graph.nodes)
        self.memory_estimator = MemoryEstimator()
        self.index_tracer = NodeIndexTracer(gm)
        self.index_tracer.trace_node_idx()

    def _find_peak_node(self, mem_peak):
        """Return a one-element list holding the index of the first maximum of ``mem_peak``.

        Raises:
            ValueError: if ``mem_peak`` is empty (raised by ``max``).
        """
        max_value = max(mem_peak)
        return [mem_peak.index(max_value)]

    def _get_free_var(self):
        """Return the positions in ``self.node_list`` whose node op is ``'placeholder'``."""
        return [idx for idx, n in enumerate(self.node_list) if n.op == 'placeholder']

    def _get_min_free_var(self, active_node_list, free_vars):
        """Return the smallest entry length in ``active_node_list``, ignoring ``free_vars`` positions.

        The result is capped at 999, which is also what is returned when no
        entry qualifies.
        """
        min_len_cap = 999
        skipped = set(free_vars)
        lengths = [len(n) for idx, n in enumerate(active_node_list) if idx not in skipped]
        return min([min_len_cap, *lengths])

    def _search_max_chunk_region(self, active_node, peak_node):
        """Find the widest region around ``peak_node`` bounded by minimal-activity nodes.

        Args:
            active_node: per-node sequences of active entries, indexed by node
                position (presumably the active-node log produced by the
                memory estimator -- verify against caller).
            peak_node: node index at which the scan starts in both directions.

        Returns:
            ``(chunk_region_start, chunk_region_end)``. ``chunk_region_end`` is
            ``None`` when the forward scan finds no boundary.

        Raises:
            RuntimeError: if either scan reaches a placeholder position (or
                index 0) before finding a node whose active-entry count equals
                the minimum.
        """
        free_vars = self._get_free_var()
        min_var = self._get_min_free_var(active_node, free_vars)
        free_var_set = set(free_vars)    # O(1) membership inside the scans

        # from peak_node to free_var
        chunk_region_start = None
        for i in range(peak_node, -1, -1):
            if len(active_node[i]) == min_var:
                chunk_region_start = i + 1
                break
            if i in free_var_set or i == 0:
                raise RuntimeError(
                    "no chunk region start found before reaching node %d (peak node %d)" % (i, peak_node))

        # from peak_node to len-2
        chunk_region_end = None
        for i in range(peak_node, len(active_node) - 1):
            if len(active_node[i]) == min_var:
                chunk_region_end = i - 1
                break
            if i in free_var_set or i == 0:
                raise RuntimeError(
                    "no chunk region end found before reaching node %d (peak node %d)" % (i, peak_node))
        return chunk_region_start, chunk_region_end

    def _search_possible_chunk_regions(self, max_chunk_region, peak_node):
        """Enumerate candidate chunk regions inside ``max_chunk_region``.

        Every pair ``(before_idx, after_idx)`` with ``before_idx`` in
        ``[max_chunk_region[0], peak_node)`` and ``after_idx`` in
        ``[peak_node, max_chunk_region[1])`` is considered. Pairs whose
        endpoints are placeholder/get_attr/output nodes, or whose names
        contain ``'getitem'`` or ``'getattr'``, are skipped.

        Returns:
            A list of ``{'region': (before_idx, after_idx), 'dim': [...]}``
            dicts, where ``'dim'`` lists the positions at which both traces
            carry the same index and that index is in neither trace's
            ``'compute'`` list. The list is empty when either bound of
            ``max_chunk_region`` is ``None``.
        """
        region_start = max_chunk_region[0]
        region_end = max_chunk_region[1]
        # An unresolved bound means there is nothing to enumerate; without this
        # guard range() would raise TypeError on a None bound.
        if region_start is None or region_end is None:
            return []

        non_compute_ops = ('placeholder', 'get_attr', 'output')
        skipped_name_parts = ('getitem', 'getattr')
        trace_list = self.index_tracer.idx_trace_list

        possible_chunk_region = []
        for before_idx in range(region_start, peak_node):
            for after_idx in range(peak_node, region_end):
                before_node = self.node_list[before_idx]
                after_node = self.node_list[after_idx]
                # skip non compute nodes
                if before_node.op in non_compute_ops or after_node.op in non_compute_ops:
                    continue
                if any(part in name
                       for name in (before_node.name, after_node.name)
                       for part in skipped_name_parts):
                    continue

                # select free dim
                before_trace = trace_list[before_idx]
                after_trace = trace_list[after_idx]
                free_dim = []
                # zip stops at the shorter trace, like the original min(len, len).
                for dim, (before_i, after_i) in enumerate(zip(before_trace['idx'], after_trace['idx'])):
                    if (before_i == after_i
                            and before_i not in before_trace['compute']
                            and after_i not in after_trace['compute']):
                        free_dim.append(dim)
                possible_chunk_region.append({'region': (before_idx, after_idx), 'dim': free_dim})
        return possible_chunk_region

    def search_region(self):
        """Estimate memory for ``self.gm`` and collect candidate chunk regions around its peak.

        Returns:
            The concatenated results of ``_search_possible_chunk_regions`` for
            every peak node reported by ``_find_peak_node``.

        Raises:
            RuntimeError: propagated from ``_search_max_chunk_region``.
        """
        mem_peak, _mem_after, active_node = self.memory_estimator.estimate_chunk_inference_mem(self.gm)
        peak_nodes = self._find_peak_node(mem_peak)
        possible_chunk_regions = []
        for peak_node in peak_nodes:
            max_chunk_region = self._search_max_chunk_region(active_node, peak_node)
            possible_chunk_regions.extend(
                self._search_possible_chunk_regions(max_chunk_region, peak_node))
        return possible_chunk_regions
def
_gen_chunk_slice_dim
(
chunk_dim
,
chunk_idx_name
,
shape
):
...
...
@@ -696,6 +795,9 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
node_index_tracer
=
NodeIndexTracer
(
meta_graph
)
node_index_tracer
.
trace_node_idx
()
chunk_region_search
=
ChunkRegionSearch
(
meta_graph
)
chunk_region_search
.
search_region
()
# find the input and output var names for each offload region
for
idx
,
(
start
,
end
)
in
enumerate
(
chunk_regions
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment