Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
966e4ea0
"...Chat/git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "6afeb1202aeedc68d9ef1f77b41f3e5f55e0e121"
Commit
966e4ea0
authored
Dec 31, 2022
by
oahzxl
Browse files
add reorder in mem estimator
parent
e5a5fbb8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
32 additions
and
11 deletions
+32
-11
chunk_codegen.py
chunk_codegen.py
+32
-11
No files found.
chunk_codegen.py
View file @
966e4ea0
...
...
@@ -1040,11 +1040,13 @@ class IndexTracer(object):
chunk_info
[
"region"
][
0
]
+
len
(
chunk_info
[
"args"
][
"prepose_nodes"
]),
chunk_info
[
"region"
][
1
],
)
new_inputs_dim
=
[]
for
idx
,
input_dim
in
enumerate
(
chunk_info
[
"inputs_dim"
]):
new_input_dim
=
{}
for
k
,
v
in
input_dim
.
items
():
new_input_dim
[
reorder_map
[
k
]]
=
v
chunk_info
[
"inputs_dim"
][
idx
]
=
new_input_dim
new_inputs_dim
.
append
(
new_input_dim
)
chunk_info
[
"inputs_dim"
]
=
new_inputs_dim
return
chunk_info
def
_update_all_reorder_map
(
self
,
reorder_map
):
...
...
@@ -1095,11 +1097,24 @@ class IndexTracer(object):
for
old_idx
,
new_idx
in
self
.
all_reorder_map
.
items
():
new_node_list
[
new_idx
]
=
node_list
[
old_idx
]
return
new_node_list
def tmp_reorder(self, node_list, chunk_info):
    """Return a reordered copy of ``node_list`` and a reordered ``chunk_info``.

    The instance-level ``all_reorder_map`` is not read or written here; the
    result is a throwaway ordering built only from ``chunk_info``.

    Args:
        node_list: sequence of nodes, indexed by the keys of the reorder map.
        chunk_info: dict that must contain ``chunk_info["args"]["prepose_nodes"]``.

    Returns:
        A ``(node_list, chunk_info)`` tuple. When ``prepose_nodes`` is empty
        both inputs are returned untouched (same objects). Otherwise the
        first element is a new list and the second is whatever
        ``self._reorder_chunk_info`` returns.
    """
    # Nothing to prepose: the original order is already the final one.
    if not chunk_info["args"]["prepose_nodes"]:
        return node_list, chunk_info

    reorder_map = self._get_reorder_map(chunk_info)

    # Build a fresh list so the caller's node_list is never mutated.
    # Slots whose index is not a value of reorder_map remain None.
    new_node_list = [None] * len(node_list)
    for old_idx, new_idx in reorder_map.items():
        new_node_list[new_idx] = node_list[old_idx]

    chunk_info = self._reorder_chunk_info(chunk_info, reorder_map)
    return new_node_list, chunk_info
class
MemoryEstimator
(
object
):
def __init__(self, index_tracer: IndexTracer) -> None:
    """Store the index tracer on the instance.

    Args:
        index_tracer: the ``IndexTracer`` kept as ``self.index_tracer``.
    """
    self.index_tracer = index_tracer
def
_get_meta_node_size
(
self
,
x
):
x
=
x
.
meta
[
"tensor_meta"
]
...
...
@@ -1453,9 +1468,11 @@ class ChunkSelector(object):
# get mem for chunk region
regions_dict
=
[]
for
region
in
possible_chunk_regions
:
cur_chunk_infos
=
chunk_infos
+
[
region
]
cur_region
=
region
.
copy
()
cur_node_list
,
cur_region
=
self
.
index_tracer
.
tmp_reorder
(
self
.
index_tracer
.
node_list
,
cur_region
)
cur_chunk_infos
=
chunk_infos
+
[
cur_region
]
cur_mem_peak
=
self
.
memory_estimator
.
estimate_chunk_inference_mem
(
self
.
index_tracer
.
node_list
,
cur_chunk_infos
cur_
node_list
,
cur_chunk_infos
)[
0
]
cur_chunk_region_peak
=
cur_mem_peak
[
max_chunk_region
[
0
]
:
max_chunk_region
[
1
]
+
1
...
...
@@ -1492,9 +1509,11 @@ class ChunkSelector(object):
while
cur_chunk_max_mem
<
self
.
max_memory
:
chunk_size
*=
2
chunk_info
[
"chunk_size"
]
=
chunk_size
cur_chunk_infos
=
chunk_infos
+
[
chunk_info
]
cur_chunk_info
=
chunk_info
.
copy
()
cur_node_list
,
cur_chunk_info
=
self
.
index_tracer
.
tmp_reorder
(
self
.
index_tracer
.
node_list
,
cur_chunk_info
)
cur_chunk_infos
=
chunk_infos
+
[
cur_chunk_info
]
cur_mem_peak
=
self
.
memory_estimator
.
estimate_chunk_inference_mem
(
self
.
index_tracer
.
node_list
,
cur_chunk_infos
cur_
node_list
,
cur_chunk_infos
)[
0
]
cur_chunk_max_mem
=
max
(
cur_mem_peak
[
chunk_info
[
"region"
][
0
]
:
chunk_info
[
"region"
][
1
]
+
1
]
...
...
@@ -1511,11 +1530,13 @@ class ChunkSelector(object):
else
:
gap
=
1
while
r
>=
l
+
gap
:
mid
=
int
(
l
+
(
r
-
l
)
/
2
)
mid
=
int
(
(
l
+
r
)
/
2
+
0.5
)
chunk_info
[
"chunk_size"
]
=
mid
cur_chunk_infos
=
chunk_infos
+
[
chunk_info
]
cur_chunk_info
=
chunk_info
.
copy
()
cur_node_list
,
cur_chunk_info
=
self
.
index_tracer
.
tmp_reorder
(
self
.
index_tracer
.
node_list
,
cur_chunk_info
)
cur_chunk_infos
=
chunk_infos
+
[
cur_chunk_info
]
cur_mem_peak
=
self
.
memory_estimator
.
estimate_chunk_inference_mem
(
self
.
index_tracer
.
node_list
,
cur_chunk_infos
cur_
node_list
,
cur_chunk_infos
)[
0
]
cur_chunk_max_mem
=
max
(
cur_mem_peak
[
chunk_info
[
"region"
][
0
]
:
chunk_info
[
"region"
][
1
]
+
1
]
...
...
@@ -1529,7 +1550,7 @@ class ChunkSelector(object):
def _get_compute_node_num(self, start, end):
    """Count the compute nodes in ``self.index_tracer.node_list[start:end + 1]``.

    A node is counted when ``_is_non_compute_node`` returns a falsy value
    for it. ``end`` is inclusive.

    Args:
        start: index of the first node to inspect.
        end: index of the last node to inspect (inclusive).

    Returns:
        The number of inspected nodes that are not non-compute nodes.
    """
    # The scraped diff kept the removed ``if _is_non_compute_node(i):`` line
    # next to the added ``if not ...`` line; only the negated check is kept,
    # otherwise the counter could never be incremented.
    return sum(
        1
        for node in self.index_tracer.node_list[start : end + 1]
        if not _is_non_compute_node(node)
    )
...
...
@@ -1547,7 +1568,7 @@ class ChunkSelector(object):
max_region_range
=
0
best_region
=
None
if
best_region
is
not
None
:
best_region
[
"chunk_size"
]
=
2
best_region
[
"chunk_size"
]
=
1
return
best_region
def
_is_legal_region
(
self
,
cur_chunk_info
,
chunk_infos
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment