Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
378a49dc
"tests/vscode:/vscode.git/clone" did not exist on "608cffaed3821bacdfce7c44cdf09e6cd38d32c2"
Commit
378a49dc
authored
Dec 27, 2022
by
oahzxl
Browse files
code style
parent
8f5a0edf
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
63 additions
and
38 deletions
+63
-38
chunk_codegen.py
chunk_codegen.py
+63
-38
No files found.
chunk_codegen.py
View file @
378a49dc
...
@@ -986,20 +986,20 @@ class IndexTracer(object):
...
@@ -986,20 +986,20 @@ class IndexTracer(object):
return
chunk_info
return
chunk_info
def
_reassgin_reshape_size
(
self
,
chunk_info
):
def
_reassgin_reshape_size
(
self
,
chunk_info
):
chunk_region
=
chunk_info
[
'
region
'
]
chunk_region
=
chunk_info
[
"
region
"
]
reshape_size
=
{}
reshape_size
=
{}
for
node
in
self
.
node_list
[
chunk_region
[
0
]:
chunk_region
[
1
]
+
1
]:
for
node
in
self
.
node_list
[
chunk_region
[
0
]
:
chunk_region
[
1
]
+
1
]:
if
any
(
i
in
node
.
name
for
i
in
[
'
reshape
'
,
'
view
'
]):
if
any
(
i
in
node
.
name
for
i
in
[
"
reshape
"
,
"
view
"
]):
reshape_args
=
node
.
args
[
1
:]
reshape_args
=
node
.
args
[
1
:]
reshape_log
=
self
.
idx_view_list
[
node
]
reshape_log
=
self
.
idx_view_list
[
node
]
chunk_dim
=
chunk_info
[
'
node_chunk_dim
'
][
node
][
'
chunk_dim
'
]
chunk_dim
=
chunk_info
[
"
node_chunk_dim
"
][
node
][
"
chunk_dim
"
]
reshape_size
[
node
.
name
]
=
{}
reshape_size
[
node
.
name
]
=
{}
for
reshape_arg_dim
,
reshape_arg
in
enumerate
(
reshape_args
):
for
reshape_arg_dim
,
reshape_arg
in
enumerate
(
reshape_args
):
if
reshape_arg_dim
in
reshape_log
[
'
dim_to
'
]:
if
reshape_arg_dim
in
reshape_log
[
"
dim_to
"
]:
continue
continue
if
reshape_arg_dim
==
chunk_dim
:
if
reshape_arg_dim
==
chunk_dim
:
reshape_size
[
node
.
name
][
reshape_arg
.
name
]
=
"chunk_size"
reshape_size
[
node
.
name
][
reshape_arg
.
name
]
=
"chunk_size"
chunk_info
[
'
reshape_size
'
]
=
reshape_size
chunk_info
[
"
reshape_size
"
]
=
reshape_size
return
chunk_info
return
chunk_info
def
_get_reorder_map
(
self
,
chunk_info
):
def
_get_reorder_map
(
self
,
chunk_info
):
...
@@ -1213,7 +1213,7 @@ class MemoryEstimator(object):
...
@@ -1213,7 +1213,7 @@ class MemoryEstimator(object):
if
node
not
in
chunk_node_dim
:
if
node
not
in
chunk_node_dim
:
return
1.0
return
1.0
node_shape
=
_get_node_shape
(
node
)
node_shape
=
_get_node_shape
(
node
)
chunk_dim
=
chunk_node_dim
[
node
][
'
chunk_dim
'
]
chunk_dim
=
chunk_node_dim
[
node
][
"
chunk_dim
"
]
if
chunk_dim
is
None
:
if
chunk_dim
is
None
:
return
1.0
return
1.0
else
:
else
:
...
@@ -1381,7 +1381,9 @@ class MemoryEstimator(object):
...
@@ -1381,7 +1381,9 @@ class MemoryEstimator(object):
# self._print_mem_log(act_memory_peak_log, node_list, "peak")
# self._print_mem_log(act_memory_peak_log, node_list, "peak")
# self._print_mem_log(act_memory_after_node_log, node_list, "after")
# self._print_mem_log(act_memory_after_node_log, node_list, "after")
self
.
_print_compute_op_mem_log
(
act_memory_peak_log
,
node_list
,
"peak"
)
self
.
_print_compute_op_mem_log
(
act_memory_peak_log
,
node_list
,
"peak"
)
self
.
_print_compute_op_mem_log
(
act_memory_after_node_log
,
node_list
,
"after"
)
self
.
_print_compute_op_mem_log
(
act_memory_after_node_log
,
node_list
,
"after"
)
# param_memory = parameter_size(gm)
# param_memory = parameter_size(gm)
# all_memory = act_memory + param_memory
# all_memory = act_memory + param_memory
...
@@ -1389,26 +1391,37 @@ class MemoryEstimator(object):
...
@@ -1389,26 +1391,37 @@ class MemoryEstimator(object):
class
ChunkSelector
(
object
):
class
ChunkSelector
(
object
):
def
__init__
(
self
,
index_tracer
:
IndexTracer
,
memory_estimator
:
MemoryEstimator
,
stratge
):
def
__init__
(
self
,
index_tracer
:
IndexTracer
,
memory_estimator
:
MemoryEstimator
,
stratge
):
self
.
index_tracer
=
index_tracer
self
.
index_tracer
=
index_tracer
self
.
memory_estimator
=
memory_estimator
self
.
memory_estimator
=
memory_estimator
assert
stratge
in
[
'
min_memory
'
,
'
fit_memory
'
]
assert
stratge
in
[
"
min_memory
"
,
"
fit_memory
"
]
self
.
stratge
=
stratge
self
.
stratge
=
stratge
self
.
max_memory
=
600
# MB
self
.
max_memory
=
600
# MB
def
_select_best_chunk_region
(
self
,
possible_chunk_regions
,
def
_select_best_chunk_region
(
chunk_infos
,
peak_node
,
max_chunk_region
,
mem_peak
):
self
,
possible_chunk_regions
,
chunk_infos
,
peak_node
,
max_chunk_region
,
mem_peak
if
self
.
stratge
==
'min_memory'
:
):
best_region
=
self
.
_select_min_memory_chunk_region
(
possible_chunk_regions
,
chunk_infos
)
if
self
.
stratge
==
"min_memory"
:
elif
self
.
stratge
==
'fit_memory'
:
best_region
=
self
.
_select_min_memory_chunk_region
(
possible_chunk_regions
,
chunk_infos
)
elif
self
.
stratge
==
"fit_memory"
:
best_region
=
self
.
_select_fit_memory_chunk_region
(
best_region
=
self
.
_select_fit_memory_chunk_region
(
possible_chunk_regions
,
chunk_infos
,
peak_node
,
max_chunk_region
,
mem_peak
)
possible_chunk_regions
,
chunk_infos
,
peak_node
,
max_chunk_region
,
mem_peak
,
)
else
:
else
:
raise
RuntimeError
()
raise
RuntimeError
()
return
best_region
return
best_region
def
_select_fit_memory_chunk_region
(
self
,
possible_chunk_regions
,
def
_select_fit_memory_chunk_region
(
chunk_infos
,
peak_node
,
max_chunk_region
,
mem_peak
):
self
,
possible_chunk_regions
,
chunk_infos
,
peak_node
,
max_chunk_region
,
mem_peak
):
# stop chunk if max memory satisfy memory limit
# stop chunk if max memory satisfy memory limit
if
max
(
mem_peak
)
<
self
.
max_memory
:
if
max
(
mem_peak
)
<
self
.
max_memory
:
return
None
return
None
...
@@ -1427,15 +1440,22 @@ class ChunkSelector(object):
...
@@ -1427,15 +1440,22 @@ class ChunkSelector(object):
for
region
in
possible_chunk_regions
:
for
region
in
possible_chunk_regions
:
cur_chunk_infos
=
chunk_infos
+
[
region
]
cur_chunk_infos
=
chunk_infos
+
[
region
]
cur_mem_peak
=
self
.
memory_estimator
.
estimate_chunk_inference_mem
(
cur_mem_peak
=
self
.
memory_estimator
.
estimate_chunk_inference_mem
(
self
.
index_tracer
.
node_list
,
cur_chunk_infos
)[
0
]
self
.
index_tracer
.
node_list
,
cur_chunk_infos
cur_chunk_region_peak
=
cur_mem_peak
[
max_chunk_region
[
0
]:
max_chunk_region
[
1
]
+
1
]
)[
0
]
cur_chunk_region_peak
=
cur_mem_peak
[
max_chunk_region
[
0
]
:
max_chunk_region
[
1
]
+
1
]
cur_chunk_region_max_peak
=
max
(
cur_chunk_region_peak
)
cur_chunk_region_max_peak
=
max
(
cur_chunk_region_peak
)
if
cur_chunk_region_max_peak
<
self
.
max_memory
:
if
cur_chunk_region_max_peak
<
self
.
max_memory
:
regions_dict
.
append
({
regions_dict
.
append
(
{
"chunk_info"
:
region
,
"chunk_info"
:
region
,
"chunk_max_mem"
:
cur_chunk_region_max_peak
,
"chunk_max_mem"
:
cur_chunk_region_max_peak
,
"chunk_len"
:
self
.
_get_compute_node_num
(
region
[
'region'
][
0
],
region
[
'region'
][
1
]),
"chunk_len"
:
self
.
_get_compute_node_num
(
})
region
[
"region"
][
0
],
region
[
"region"
][
1
]
),
}
)
# no region found
# no region found
if
len
(
regions_dict
)
==
0
:
if
len
(
regions_dict
)
==
0
:
return
None
return
None
...
@@ -1448,7 +1468,7 @@ class ChunkSelector(object):
...
@@ -1448,7 +1468,7 @@ class ChunkSelector(object):
def
_get_compute_node_num
(
self
,
start
,
end
):
def
_get_compute_node_num
(
self
,
start
,
end
):
count
=
0
count
=
0
for
i
in
self
.
index_tracer
.
node_list
[
start
:
end
+
1
]:
for
i
in
self
.
index_tracer
.
node_list
[
start
:
end
+
1
]:
if
_is_non_compute_node
(
i
):
if
_is_non_compute_node
(
i
):
count
+=
1
count
+=
1
return
count
return
count
...
@@ -1490,7 +1510,9 @@ class ChunkRegionSearch(object):
...
@@ -1490,7 +1510,9 @@ class ChunkRegionSearch(object):
self
.
index_tracer
=
IndexTracer
(
list
(
gm
.
graph
.
nodes
))
self
.
index_tracer
=
IndexTracer
(
list
(
gm
.
graph
.
nodes
))
self
.
index_tracer
.
trace_index
()
self
.
index_tracer
.
trace_index
()
self
.
memory_estimator
=
MemoryEstimator
(
self
.
index_tracer
)
self
.
memory_estimator
=
MemoryEstimator
(
self
.
index_tracer
)
self
.
chunk_selector
=
ChunkSelector
(
self
.
index_tracer
,
self
.
memory_estimator
,
stratge
=
"fit_memory"
)
self
.
chunk_selector
=
ChunkSelector
(
self
.
index_tracer
,
self
.
memory_estimator
,
stratge
=
"fit_memory"
)
def
_find_peak_node
(
self
,
mem_peak
):
def
_find_peak_node
(
self
,
mem_peak
):
max_value
=
max
(
mem_peak
)
max_value
=
max
(
mem_peak
)
...
@@ -1812,6 +1834,7 @@ def _replace_reshape_size(context, node_name, reshape_size_dict):
...
@@ -1812,6 +1834,7 @@ def _replace_reshape_size(context, node_name, reshape_size_dict):
context
=
context
.
replace
(
size_name
,
size_value
)
context
=
context
.
replace
(
size_name
,
size_value
)
return
context
return
context
def
emit_code_with_chunk
(
def
emit_code_with_chunk
(
body
,
body
,
ckpt_func
,
ckpt_func
,
...
@@ -1883,7 +1906,9 @@ def emit_code_with_chunk(
...
@@ -1883,7 +1906,9 @@ def emit_code_with_chunk(
body
[
-
1
]
=
_replace_name
(
body
[
-
1
]
=
_replace_name
(
body
[
-
1
],
input_node
.
name
,
input_node
.
name
+
chunk_slice
body
[
-
1
],
input_node
.
name
,
input_node
.
name
+
chunk_slice
)
)
body
[
-
1
]
=
_replace_reshape_size
(
body
[
-
1
],
node
.
name
,
chunk_search
[
region_idx
][
'reshape_size'
])
body
[
-
1
]
=
_replace_reshape_size
(
body
[
-
1
],
node
.
name
,
chunk_search
[
region_idx
][
"reshape_size"
]
)
body
[
-
1
]
=
" "
+
body
[
-
1
]
body
[
-
1
]
=
" "
+
body
[
-
1
]
delete_unused_value_func
(
node
,
body
,
chunk_inputs_names
)
delete_unused_value_func
(
node
,
body
,
chunk_inputs_names
)
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment