Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
6be89a3b
Commit
6be89a3b
authored
Dec 27, 2022
by
oahzxl
Browse files
add chunksize in emit, fix bug in reassgin shape
parent
378a49dc
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
52 additions
and
4 deletions
+52
-4
chunk_codegen.py
chunk_codegen.py
+52
-4
No files found.
chunk_codegen.py
View file @
6be89a3b
...
@@ -988,6 +988,7 @@ class IndexTracer(object):
...
@@ -988,6 +988,7 @@ class IndexTracer(object):
def
_reassgin_reshape_size
(
self
,
chunk_info
):
def
_reassgin_reshape_size
(
self
,
chunk_info
):
chunk_region
=
chunk_info
[
"region"
]
chunk_region
=
chunk_info
[
"region"
]
reshape_size
=
{}
reshape_size
=
{}
chunk_shape
=
_get_node_shape
(
chunk_info
[
"outputs"
][
0
])[
chunk_info
[
"outputs_dim"
]]
for
node
in
self
.
node_list
[
chunk_region
[
0
]
:
chunk_region
[
1
]
+
1
]:
for
node
in
self
.
node_list
[
chunk_region
[
0
]
:
chunk_region
[
1
]
+
1
]:
if
any
(
i
in
node
.
name
for
i
in
[
"reshape"
,
"view"
]):
if
any
(
i
in
node
.
name
for
i
in
[
"reshape"
,
"view"
]):
reshape_args
=
node
.
args
[
1
:]
reshape_args
=
node
.
args
[
1
:]
...
@@ -998,7 +999,7 @@ class IndexTracer(object):
...
@@ -998,7 +999,7 @@ class IndexTracer(object):
if
reshape_arg_dim
in
reshape_log
[
"dim_to"
]:
if
reshape_arg_dim
in
reshape_log
[
"dim_to"
]:
continue
continue
if
reshape_arg_dim
==
chunk_dim
:
if
reshape_arg_dim
==
chunk_dim
:
reshape_size
[
node
.
name
][
reshape_arg
.
name
]
=
"chunk_size
"
reshape_size
[
node
.
name
][
reshape_arg
.
name
]
=
"
min(
chunk_size
, %d - chunk_idx)"
%
chunk_shape
chunk_info
[
"reshape_size"
]
=
reshape_size
chunk_info
[
"reshape_size"
]
=
reshape_size
return
chunk_info
return
chunk_info
...
@@ -1276,7 +1277,6 @@ class MemoryEstimator(object):
...
@@ -1276,7 +1277,6 @@ class MemoryEstimator(object):
chunk_within
=
False
chunk_within
=
False
chunk_region_idx
=
None
chunk_region_idx
=
None
chunk_ratio
=
1
# use it to estimate chunk mem
chunk_ratio
=
1
# use it to estimate chunk mem
chunk_size
=
1
chunk_inputs_names
=
[]
chunk_inputs_names
=
[]
if
use_chunk
:
if
use_chunk
:
...
@@ -1285,12 +1285,14 @@ class MemoryEstimator(object):
...
@@ -1285,12 +1285,14 @@ class MemoryEstimator(object):
chunk_ends
=
[
i
[
1
]
for
i
in
chunk_regions
]
chunk_ends
=
[
i
[
1
]
for
i
in
chunk_regions
]
chunk_inputs
=
[
i
[
"inputs"
]
for
i
in
chunk_infos
]
chunk_inputs
=
[
i
[
"inputs"
]
for
i
in
chunk_infos
]
chunk_inputs_non_chunk
=
[
i
[
"inputs_non_chunk"
]
for
i
in
chunk_infos
]
chunk_inputs_non_chunk
=
[
i
[
"inputs_non_chunk"
]
for
i
in
chunk_infos
]
chunk_inputs_dim
=
[
i
[
"inputs_dim"
]
for
i
in
chunk_infos
]
chunk_inputs_names
=
[
j
.
name
for
i
in
chunk_inputs
for
j
in
i
]
+
[
chunk_inputs_names
=
[
j
.
name
for
i
in
chunk_inputs
for
j
in
i
]
+
[
j
.
name
for
i
in
chunk_inputs_non_chunk
for
j
in
i
j
.
name
for
i
in
chunk_inputs_non_chunk
for
j
in
i
]
]
chunk_outputs
=
[
i
[
"outputs"
][
0
]
for
i
in
chunk_infos
]
chunk_outputs
=
[
i
[
"outputs"
][
0
]
for
i
in
chunk_infos
]
chunk_node_dim
=
[
i
[
"node_chunk_dim"
]
for
i
in
chunk_infos
]
chunk_node_dim
=
[
i
[
"node_chunk_dim"
]
for
i
in
chunk_infos
]
chunk_sizes
=
[
i
[
"chunk_size"
]
if
"chunk_size"
in
i
else
1
for
i
in
chunk_infos
]
for
idx
,
node
in
enumerate
(
node_list
):
for
idx
,
node
in
enumerate
(
node_list
):
# if node in chunk start nodes, change chunk ratio and add chunk_tensor
# if node in chunk start nodes, change chunk ratio and add chunk_tensor
...
@@ -1306,7 +1308,7 @@ class MemoryEstimator(object):
...
@@ -1306,7 +1308,7 @@ class MemoryEstimator(object):
chunk_ratio
=
self
.
_get_chunk_ratio
(
chunk_ratio
=
self
.
_get_chunk_ratio
(
node
,
node
,
chunk_node_dim
[
chunk_region_idx
],
chunk_node_dim
[
chunk_region_idx
],
chunk_size
,
chunk_size
s
[
chunk_region_idx
]
,
)
)
# if node is placeholder, just add the size of the node
# if node is placeholder, just add the size of the node
...
@@ -1464,8 +1466,53 @@ class ChunkSelector(object):
...
@@ -1464,8 +1466,53 @@ class ChunkSelector(object):
chunk_len
=
[
i
[
"chunk_len"
]
for
i
in
regions_dict
]
chunk_len
=
[
i
[
"chunk_len"
]
for
i
in
regions_dict
]
best_region_idx
=
chunk_len
.
index
(
min
(
chunk_len
))
best_region_idx
=
chunk_len
.
index
(
min
(
chunk_len
))
best_region
=
regions_dict
[
best_region_idx
][
"chunk_info"
]
best_region
=
regions_dict
[
best_region_idx
][
"chunk_info"
]
# get max chunk size
best_region
=
self
.
_get_fit_chunk_size
(
best_region
,
chunk_infos
)
return
best_region
return
best_region
def _get_fit_chunk_size(self, chunk_info, chunk_infos):
    """Pick a chunk size for ``chunk_info`` that keeps its region under budget.

    The search runs in two phases:

    1. Double the chunk size (2, 4, 8, ...) until the estimated peak memory
       inside the chunk region reaches ``self.max_memory``.
    2. Binary-search between the last size below the budget and the first
       size that reached it.

    Args:
        chunk_info: dict describing the candidate chunk. It must provide
            ``"region"`` (start and end node indices, end inclusive). Its
            ``"chunk_size"`` entry is overwritten in place while probing.
        chunk_infos: list of chunk-info dicts that are already selected.
            The list itself is not modified.

    Returns:
        The same ``chunk_info`` dict, with ``"chunk_size"`` set to the
        selected size.
    """
    chunk_size = 1
    chunk_info["chunk_size"] = chunk_size
    cur_chunk_max_mem = 0

    # Phase 1: exponential search for an upper bound on the chunk size.
    # NOTE(review): this loop only terminates if the estimate eventually
    # reaches max_memory as chunk_size grows -- confirm the estimator
    # guarantees that.
    while cur_chunk_max_mem < self.max_memory:
        chunk_size *= 2
        chunk_info["chunk_size"] = chunk_size
        cur_chunk_max_mem = self._estimate_region_peak(chunk_info, chunk_infos)

    # Phase 2: refine between the last size under budget and the first
    # size that reached it.
    chunk_info["chunk_size"] = self._chunk_size_binary_search(
        chunk_size // 2, chunk_size, chunk_info, chunk_infos
    )
    return chunk_info

def _estimate_region_peak(self, chunk_info, chunk_infos):
    """Return the estimated peak memory over the region of ``chunk_info``.

    The estimate is made with ``chunk_info`` appended to ``chunk_infos``,
    using whatever ``chunk_info["chunk_size"]`` currently holds.
    """
    cur_chunk_infos = chunk_infos + [chunk_info]
    # Only the first element of the estimator's result is used here; it is
    # indexed by node position.
    cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
        self.index_tracer.node_list, cur_chunk_infos
    )[0]
    region_start = chunk_info["region"][0]
    region_end = chunk_info["region"][1]
    return max(cur_mem_peak[region_start : region_end + 1])

def _chunk_size_binary_search(self, l, r, chunk_info, chunk_infos):
    """Search ``[l, r]`` for the largest probed chunk size under budget.

    Args:
        l: lower bound of the search. The caller passes a size already
            estimated below ``self.max_memory`` (or the minimum size 1).
        r: upper bound of the search.
        chunk_info: candidate chunk dict; its ``"chunk_size"`` entry is
            overwritten with every probed size.
        chunk_infos: list of chunk-info dicts that are already selected.

    Returns:
        The largest probed size whose estimated region peak stayed below
        ``self.max_memory``, or ``l`` when no probed size fits.
    """
    # Large sizes are searched on a coarser grid to save estimator calls.
    gap = 4 if l >= 16 else 1

    # Only sizes that were actually evaluated and found to fit are ever
    # returned. Returning the moving lower bound instead could yield a size
    # that was never checked, including one already known to exceed the
    # budget.
    best = l
    while l <= r:
        mid = l + (r - l) // 2
        chunk_info["chunk_size"] = mid
        cur_chunk_max_mem = self._estimate_region_peak(chunk_info, chunk_infos)
        if cur_chunk_max_mem >= self.max_memory:
            r = mid - gap
        else:
            best = mid
            l = mid + gap
    return best
def
_get_compute_node_num
(
self
,
start
,
end
):
def
_get_compute_node_num
(
self
,
start
,
end
):
count
=
0
count
=
0
for
i
in
self
.
index_tracer
.
node_list
[
start
:
end
+
1
]:
for
i
in
self
.
index_tracer
.
node_list
[
start
:
end
+
1
]:
...
@@ -1891,6 +1938,7 @@ def emit_code_with_chunk(
...
@@ -1891,6 +1938,7 @@ def emit_code_with_chunk(
chunk_inputs
[
region_idx
],
chunk_inputs
[
region_idx
],
chunk_outputs
[
region_idx
],
chunk_outputs
[
region_idx
],
chunk_outputs_dim
[
region_idx
],
chunk_outputs_dim
[
region_idx
],
chunk_size
=
chunk_search
[
region_idx
][
"chunk_size"
]
)
)
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment