Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
8f5a0edf
Commit
8f5a0edf
authored
Dec 26, 2022
by
oahzxl
Browse files
add chunk select
parent
1b8a0665
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
112 additions
and
35 deletions
+112
-35
chunk_codegen.py
chunk_codegen.py
+112
-35
No files found.
chunk_codegen.py
View file @
8f5a0edf
...
...
@@ -69,7 +69,7 @@ class IndexTracer(object):
self
.
node_list
=
node_list
self
.
idx_trace_list
=
self
.
_init_idx_trace_list
()
self
.
idx_trace_equal
=
[]
self
.
idx_view_list
=
[]
self
.
idx_view_list
=
{}
self
.
idx_count
=
-
1
self
.
all_reorder_map
=
{
i
:
i
for
i
in
range
(
len
(
self
.
idx_trace_list
))}
...
...
@@ -576,7 +576,7 @@ class IndexTracer(object):
"idx_to"
:
[
self
.
idx_trace_list
[
node_idx
][
"idx"
][
i
]
for
i
in
dim_to
],
"dim_to"
:
dim_to
,
}
self
.
idx_view_list
.
append
(
view_dict
)
self
.
idx_view_list
[
node
]
=
view_dict
def
_merge_equal_idx
(
self
):
idx_equal
=
copy
.
deepcopy
(
self
.
idx_trace_equal
)
...
...
@@ -702,7 +702,7 @@ class IndexTracer(object):
for
node_dim
in
range
(
len
(
_get_node_shape
(
node
))):
if
(
input_node_idx
in
node_trace_source
[
node_dim
]
and
input_dim
in
node_trace_source
[
node_dim
][
input_node_idx
]
and
input_dim
[
0
]
in
node_trace_source
[
node_dim
][
input_node_idx
]
):
return
node_dim
return
None
...
...
@@ -875,6 +875,7 @@ class IndexTracer(object):
remove_inputs
=
[]
for
input_node
in
inputs
:
input_dict
=
{}
input_node_idx
=
_find_idx_by_name
(
input_node
.
name
,
self
.
node_list
)
for
user
in
input_node
.
users
.
keys
():
if
_is_non_compute_node
(
user
):
continue
...
...
@@ -882,7 +883,11 @@ class IndexTracer(object):
if
start_idx
<=
user_idx
<=
end_idx
:
chunk_dim
=
all_node_info
[
user
][
"chunk_dim"
]
if
chunk_dim
is
not
None
:
input_dict
[
user_idx
]
=
chunk_dim
user_source
=
self
.
_find_source_trace_from_node
(
user
)[
chunk_dim
]
if
input_node_idx
in
user_source
:
input_dict
[
user_idx
]
=
user_source
[
input_node_idx
]
else
:
return
None
if
len
(
input_dict
)
==
0
:
remove_inputs
.
append
(
input_node
)
else
:
...
...
@@ -898,6 +903,7 @@ class IndexTracer(object):
"inputs_dim"
:
inputs_dim
,
"outputs"
:
outputs
,
"outputs_dim"
:
end_dim
,
"node_chunk_dim"
:
all_node_info
,
"args"
:
{},
}
...
...
@@ -974,6 +980,26 @@ class IndexTracer(object):
if
i
not
in
chunk_info
[
"inputs"
]:
chunk_info
[
"inputs_non_chunk"
].
append
(
i
)
# reassgin reshape size, some size may have changed due to chunk
chunk_info
=
self
.
_reassgin_reshape_size
(
chunk_info
)
return
chunk_info
def
_reassgin_reshape_size
(
self
,
chunk_info
):
chunk_region
=
chunk_info
[
'region'
]
reshape_size
=
{}
for
node
in
self
.
node_list
[
chunk_region
[
0
]:
chunk_region
[
1
]
+
1
]:
if
any
(
i
in
node
.
name
for
i
in
[
'reshape'
,
'view'
]):
reshape_args
=
node
.
args
[
1
:]
reshape_log
=
self
.
idx_view_list
[
node
]
chunk_dim
=
chunk_info
[
'node_chunk_dim'
][
node
][
'chunk_dim'
]
reshape_size
[
node
.
name
]
=
{}
for
reshape_arg_dim
,
reshape_arg
in
enumerate
(
reshape_args
):
if
reshape_arg_dim
in
reshape_log
[
'dim_to'
]:
continue
if
reshape_arg_dim
==
chunk_dim
:
reshape_size
[
node
.
name
][
reshape_arg
.
name
]
=
"chunk_size"
chunk_info
[
'reshape_size'
]
=
reshape_size
return
chunk_info
def
_get_reorder_map
(
self
,
chunk_info
):
...
...
@@ -1183,23 +1209,15 @@ class MemoryEstimator(object):
not_contiguous_list
.
append
(
node
)
return
mem
def
_get_chunk_ratio
(
self
,
node
,
chunk_inputs
,
chunk_inputs_dim
,
chunk_size
):
def
_get_chunk_ratio
(
self
,
node
,
chunk_node_dim
,
chunk_size
):
if
node
not
in
chunk_node_dim
:
return
1.0
node_shape
=
_get_node_shape
(
node
)
node_source
=
self
.
index_tracer
.
_find_source_trace_from_node
(
node
)
for
(
input_node
,
input_node_dim
)
in
zip
(
chunk_inputs
,
chunk_inputs_dim
):
for
k
,
v
in
input_node_dim
.
items
():
# TODO: inherit dim should be list too, int now
inherit_dim
=
self
.
index_tracer
.
_find_inherit_dim
(
input_node
,
v
,
self
.
index_tracer
.
node_list
[
k
]
)
if
k
==
_find_idx_by_name
(
node
.
name
,
self
.
index_tracer
.
node_list
):
chunk_ratio
=
float
(
chunk_size
)
/
node_shape
[
inherit_dim
]
return
chunk_ratio
for
dim
,
source
in
enumerate
(
node_source
):
if
k
in
source
and
inherit_dim
in
source
[
k
]:
chunk_ratio
=
float
(
chunk_size
)
/
node_shape
[
dim
]
return
chunk_ratio
chunk_dim
=
chunk_node_dim
[
node
][
'chunk_dim'
]
if
chunk_dim
is
None
:
return
1.0
else
:
return
float
(
chunk_size
)
/
node_shape
[
chunk_dim
]
def
_get_chunk_delete_node_size
(
self
,
user
,
user_to_last_uses
,
chunk_ratio
,
chunk_inputs_names
...
...
@@ -1242,6 +1260,7 @@ class MemoryEstimator(object):
self
,
node_list
,
chunk_infos
=
None
,
print_mem
=
False
,
):
act_memory
=
0.0
act_memory_peak_log
=
[]
...
...
@@ -1271,6 +1290,7 @@ class MemoryEstimator(object):
j
.
name
for
i
in
chunk_inputs_non_chunk
for
j
in
i
]
chunk_outputs
=
[
i
[
"outputs"
][
0
]
for
i
in
chunk_infos
]
chunk_node_dim
=
[
i
[
"node_chunk_dim"
]
for
i
in
chunk_infos
]
for
idx
,
node
in
enumerate
(
node_list
):
# if node in chunk start nodes, change chunk ratio and add chunk_tensor
...
...
@@ -1285,8 +1305,7 @@ class MemoryEstimator(object):
if
chunk_within
:
chunk_ratio
=
self
.
_get_chunk_ratio
(
node
,
chunk_inputs
[
chunk_region_idx
],
chunk_inputs_dim
[
chunk_region_idx
],
chunk_node_dim
[
chunk_region_idx
],
chunk_size
,
)
...
...
@@ -1357,6 +1376,7 @@ class MemoryEstimator(object):
act_memory_after_node_log
.
append
(
act_memory
)
active_node_list_log
.
append
(
copy
.
deepcopy
(
active_node_list
))
if
print_mem
:
print
(
"with chunk"
if
use_chunk
else
"without chunk"
)
# self._print_mem_log(act_memory_peak_log, node_list, "peak")
# self._print_mem_log(act_memory_after_node_log, node_list, "after")
...
...
@@ -1369,21 +1389,70 @@ class MemoryEstimator(object):
class
ChunkSelector
(
object
):
def
__init__
(
self
,
index_tracer
:
IndexTracer
,
stratge
)
->
None
:
def
__init__
(
self
,
index_tracer
:
IndexTracer
,
memory_estimator
:
MemoryEstimator
,
stratge
):
self
.
index_tracer
=
index_tracer
self
.
memory_estimator
=
memory_estimator
assert
stratge
in
[
'min_memory'
,
'fit_memory'
]
self
.
stratge
=
stratge
self
.
max_memory
=
8
00
# MB
self
.
max_memory
=
6
00
# MB
def
_select_best_chunk_region
(
self
,
possible_chunk_regions
,
chunk_infos
):
def
_select_best_chunk_region
(
self
,
possible_chunk_regions
,
chunk_infos
,
peak_node
,
max_chunk_region
,
mem_peak
):
if
self
.
stratge
==
'min_memory'
:
best_region
=
self
.
_select_min_memory_chunk_region
(
possible_chunk_regions
,
chunk_infos
)
elif
self
.
stratge
==
'fit_memory'
:
pass
best_region
=
self
.
_select_fit_memory_chunk_region
(
possible_chunk_regions
,
chunk_infos
,
peak_node
,
max_chunk_region
,
mem_peak
)
else
:
raise
RuntimeError
()
return
best_region
def
_select_fit_memory_chunk_region
(
self
,
possible_chunk_regions
,
chunk_infos
,
peak_node
,
max_chunk_region
,
mem_peak
):
# stop chunk if max memory satisfy memory limit
if
max
(
mem_peak
)
<
self
.
max_memory
:
return
None
# remove illegal regions
illegal_regions
=
[]
for
i
in
possible_chunk_regions
:
if
not
self
.
_is_legal_region
(
i
,
chunk_infos
):
illegal_regions
.
append
(
i
)
for
i
in
illegal_regions
:
if
i
in
possible_chunk_regions
:
possible_chunk_regions
.
remove
(
i
)
# get mem for chunk region
regions_dict
=
[]
for
region
in
possible_chunk_regions
:
cur_chunk_infos
=
chunk_infos
+
[
region
]
cur_mem_peak
=
self
.
memory_estimator
.
estimate_chunk_inference_mem
(
self
.
index_tracer
.
node_list
,
cur_chunk_infos
)[
0
]
cur_chunk_region_peak
=
cur_mem_peak
[
max_chunk_region
[
0
]:
max_chunk_region
[
1
]
+
1
]
cur_chunk_region_max_peak
=
max
(
cur_chunk_region_peak
)
if
cur_chunk_region_max_peak
<
self
.
max_memory
:
regions_dict
.
append
({
"chunk_info"
:
region
,
"chunk_max_mem"
:
cur_chunk_region_max_peak
,
"chunk_len"
:
self
.
_get_compute_node_num
(
region
[
'region'
][
0
],
region
[
'region'
][
1
]),
})
# no region found
if
len
(
regions_dict
)
==
0
:
return
None
# select the min chunk len
chunk_len
=
[
i
[
"chunk_len"
]
for
i
in
regions_dict
]
best_region_idx
=
chunk_len
.
index
(
min
(
chunk_len
))
best_region
=
regions_dict
[
best_region_idx
][
"chunk_info"
]
return
best_region
def
_get_compute_node_num
(
self
,
start
,
end
):
count
=
0
for
i
in
self
.
index_tracer
.
node_list
[
start
:
end
+
1
]:
if
_is_non_compute_node
(
i
):
count
+=
1
return
count
def
_select_min_memory_chunk_region
(
self
,
possible_chunk_regions
,
chunk_infos
):
max_region_range
=
0
best_region
=
None
...
...
@@ -1421,7 +1490,7 @@ class ChunkRegionSearch(object):
self
.
index_tracer
=
IndexTracer
(
list
(
gm
.
graph
.
nodes
))
self
.
index_tracer
.
trace_index
()
self
.
memory_estimator
=
MemoryEstimator
(
self
.
index_tracer
)
self
.
chunk_selector
=
ChunkSelector
(
self
.
index_tracer
,
stratge
=
"
min
_memory"
)
self
.
chunk_selector
=
ChunkSelector
(
self
.
index_tracer
,
self
.
memory_estimator
,
stratge
=
"
fit
_memory"
)
def
_find_peak_node
(
self
,
mem_peak
):
max_value
=
max
(
mem_peak
)
...
...
@@ -1575,7 +1644,7 @@ class ChunkRegionSearch(object):
max_chunk_region
,
peak_node
)
best_chunk_region
=
self
.
chunk_selector
.
_select_best_chunk_region
(
possible_chunk_regions
,
chunk_regions
possible_chunk_regions
,
chunk_regions
,
peak_node
,
max_chunk_region
,
mem_peak
)
best_chunk_region
=
self
.
index_tracer
.
reorder_all
(
best_chunk_region
)
return
best_chunk_region
...
...
@@ -1608,7 +1677,7 @@ class ChunkRegionSearch(object):
_
,
active_node
,
)
=
self
.
memory_estimator
.
estimate_chunk_inference_mem
(
self
.
index_tracer
.
node_list
,
chunk_infos
self
.
index_tracer
.
node_list
,
chunk_infos
,
print_mem
=
True
)
if
self
.
_stop_search
(
init_mem_peak
,
mem_peak
):
break
...
...
@@ -1736,6 +1805,13 @@ def _replace_name(context, name_from, name_to):
return
context
def
_replace_reshape_size
(
context
,
node_name
,
reshape_size_dict
):
if
node_name
not
in
reshape_size_dict
:
return
context
for
size_name
,
size_value
in
reshape_size_dict
[
node_name
].
items
():
context
=
context
.
replace
(
size_name
,
size_value
)
return
context
def
emit_code_with_chunk
(
body
,
ckpt_func
,
...
...
@@ -1802,11 +1878,12 @@ def emit_code_with_chunk(
for
idx
,
dim
in
chunk_inputs_dim
[
region_idx
][
input_node_idx
].
items
():
if
idx
==
node_idx
:
chunk_slice
=
_gen_chunk_slice_dim
(
dim
,
"chunk_idx"
,
_get_node_shape
(
input_node
)
dim
[
0
]
,
"chunk_idx"
,
_get_node_shape
(
input_node
)
)
body
[
-
1
]
=
_replace_name
(
body
[
-
1
],
input_node
.
name
,
input_node
.
name
+
chunk_slice
)
body
[
-
1
]
=
_replace_reshape_size
(
body
[
-
1
],
node
.
name
,
chunk_search
[
region_idx
][
'reshape_size'
])
body
[
-
1
]
=
" "
+
body
[
-
1
]
delete_unused_value_func
(
node
,
body
,
chunk_inputs_names
)
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment