Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
7fd3b45a
Commit
7fd3b45a
authored
Jan 02, 2023
by
oahzxl
Browse files
fix a bug in ones like, dont gen chunk if dim size is 1
parent
5f24f4fd
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
29 additions
and
16 deletions
+29
-16
autochunk_benchmark.py
autochunk_benchmark.py
+2
-2
chunk_codegen.py
chunk_codegen.py
+27
-14
No files found.
autochunk_benchmark.py
View file @
7fd3b45a
...
@@ -16,9 +16,9 @@ def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=N
...
@@ -16,9 +16,9 @@ def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=N
torch
.
cuda
.
reset_peak_memory_stats
()
torch
.
cuda
.
reset_peak_memory_stats
()
now_mem
=
torch
.
cuda
.
memory_allocated
()
/
1024
**
2
now_mem
=
torch
.
cuda
.
memory_allocated
()
/
1024
**
2
loop
=
16
loop
=
3
with
torch
.
no_grad
():
with
torch
.
no_grad
():
for
_
in
range
(
loop
//
4
):
for
_
in
range
(
loop
//
2
+
1
):
if
chunk_size
:
if
chunk_size
:
model
(
node
,
pair
,
chunk_size
)
model
(
node
,
pair
,
chunk_size
)
else
:
else
:
...
...
chunk_codegen.py
View file @
7fd3b45a
...
@@ -144,9 +144,7 @@ class IndexTracer(object):
...
@@ -144,9 +144,7 @@ class IndexTracer(object):
node_to_trace_source
[
node_to_dim
][
node_from_idx
]
=
[
node_from_dim
]
node_to_trace_source
[
node_to_dim
][
node_from_idx
]
=
[
node_from_dim
]
else
:
else
:
if
node_from_dim
not
in
node_to_trace_source
[
node_to_dim
][
node_from_idx
]:
if
node_from_dim
not
in
node_to_trace_source
[
node_to_dim
][
node_from_idx
]:
node_to_trace_source
[
node_to_dim
][
node_from_idx
].
append
(
node_to_trace_source
[
node_to_dim
][
node_from_idx
].
append
(
node_from_dim
)
node_from_dim
)
# update inputs source
# update inputs source
for
node_idx
,
node_dim
in
node_from_trace_source
[
node_from_dim
].
items
():
for
node_idx
,
node_dim
in
node_from_trace_source
[
node_from_dim
].
items
():
if
node_idx
not
in
node_to_trace_source
[
node_to_dim
]:
if
node_idx
not
in
node_to_trace_source
[
node_to_dim
]:
...
@@ -1472,7 +1470,9 @@ class ChunkSelector(object):
...
@@ -1472,7 +1470,9 @@ class ChunkSelector(object):
regions_dict
=
[]
regions_dict
=
[]
for
region
in
possible_chunk_regions
:
for
region
in
possible_chunk_regions
:
cur_region
=
region
.
copy
()
cur_region
=
region
.
copy
()
cur_node_list
,
cur_region
=
self
.
index_tracer
.
tmp_reorder
(
self
.
index_tracer
.
node_list
,
cur_region
)
cur_node_list
,
cur_region
=
self
.
index_tracer
.
tmp_reorder
(
self
.
index_tracer
.
node_list
,
cur_region
)
cur_chunk_infos
=
chunk_infos
+
[
cur_region
]
cur_chunk_infos
=
chunk_infos
+
[
cur_region
]
cur_mem_peak
=
self
.
memory_estimator
.
estimate_chunk_inference_mem
(
cur_mem_peak
=
self
.
memory_estimator
.
estimate_chunk_inference_mem
(
cur_node_list
,
cur_chunk_infos
cur_node_list
,
cur_chunk_infos
...
@@ -1490,7 +1490,7 @@ class ChunkSelector(object):
...
@@ -1490,7 +1490,7 @@ class ChunkSelector(object):
region
[
"region"
][
0
],
region
[
"region"
][
1
]
region
[
"region"
][
0
],
region
[
"region"
][
1
]
),
),
"reorder_chunk_info"
:
cur_region
,
"reorder_chunk_info"
:
cur_region
,
"reorder_node_list"
:
cur_node_list
"reorder_node_list"
:
cur_node_list
,
}
}
)
)
# no region found
# no region found
...
@@ -1508,7 +1508,7 @@ class ChunkSelector(object):
...
@@ -1508,7 +1508,7 @@ class ChunkSelector(object):
def
_get_fit_chunk_size
(
self
,
chunk_region_dict
,
chunk_infos
):
def
_get_fit_chunk_size
(
self
,
chunk_region_dict
,
chunk_infos
):
chunk_size
=
1
chunk_size
=
1
reorder_chunk_info
=
chunk_region_dict
[
'
reorder_chunk_info
'
]
reorder_chunk_info
=
chunk_region_dict
[
"
reorder_chunk_info
"
]
reorder_chunk_info
[
"chunk_size"
]
=
chunk_size
reorder_chunk_info
[
"chunk_size"
]
=
chunk_size
cur_chunk_max_mem
=
0
cur_chunk_max_mem
=
0
# search a region
# search a region
...
@@ -1517,10 +1517,13 @@ class ChunkSelector(object):
...
@@ -1517,10 +1517,13 @@ class ChunkSelector(object):
reorder_chunk_info
[
"chunk_size"
]
=
chunk_size
reorder_chunk_info
[
"chunk_size"
]
=
chunk_size
cur_chunk_infos
=
chunk_infos
+
[
reorder_chunk_info
]
cur_chunk_infos
=
chunk_infos
+
[
reorder_chunk_info
]
cur_mem_peak
=
self
.
memory_estimator
.
estimate_chunk_inference_mem
(
cur_mem_peak
=
self
.
memory_estimator
.
estimate_chunk_inference_mem
(
chunk_region_dict
[
'
reorder_node_list
'
],
cur_chunk_infos
chunk_region_dict
[
"
reorder_node_list
"
],
cur_chunk_infos
)[
0
]
)[
0
]
cur_chunk_max_mem
=
max
(
cur_chunk_max_mem
=
max
(
cur_mem_peak
[
reorder_chunk_info
[
"region"
][
0
]
:
reorder_chunk_info
[
"region"
][
1
]
+
1
]
cur_mem_peak
[
reorder_chunk_info
[
"region"
][
0
]
:
reorder_chunk_info
[
"region"
][
1
]
+
1
]
)
)
# search exact size
# search exact size
chunk_info
=
chunk_region_dict
[
"chunk_info"
]
chunk_info
=
chunk_region_dict
[
"chunk_info"
]
...
@@ -1534,13 +1537,13 @@ class ChunkSelector(object):
...
@@ -1534,13 +1537,13 @@ class ChunkSelector(object):
gap
=
4
gap
=
4
else
:
else
:
gap
=
1
gap
=
1
chunk_info
=
chunk_region_dict
[
'
reorder_chunk_info
'
]
chunk_info
=
chunk_region_dict
[
"
reorder_chunk_info
"
]
while
r
>=
l
+
gap
:
while
r
>=
l
+
gap
:
mid
=
int
((
l
+
r
)
/
2
+
0.5
)
mid
=
int
((
l
+
r
)
/
2
+
0.5
)
chunk_info
[
"chunk_size"
]
=
mid
chunk_info
[
"chunk_size"
]
=
mid
cur_chunk_infos
=
chunk_infos
+
[
chunk_info
]
cur_chunk_infos
=
chunk_infos
+
[
chunk_info
]
cur_mem_peak
=
self
.
memory_estimator
.
estimate_chunk_inference_mem
(
cur_mem_peak
=
self
.
memory_estimator
.
estimate_chunk_inference_mem
(
chunk_region_dict
[
'
reorder_node_list
'
],
cur_chunk_infos
chunk_region_dict
[
"
reorder_node_list
"
],
cur_chunk_infos
)[
0
]
)[
0
]
cur_chunk_max_mem
=
max
(
cur_chunk_max_mem
=
max
(
cur_mem_peak
[
chunk_info
[
"region"
][
0
]
:
chunk_info
[
"region"
][
1
]
+
1
]
cur_mem_peak
[
chunk_info
[
"region"
][
0
]
:
chunk_info
[
"region"
][
1
]
+
1
]
...
@@ -2000,8 +2003,18 @@ def emit_code_with_chunk(
...
@@ -2000,8 +2003,18 @@ def emit_code_with_chunk(
)
)
# ones like
# ones like
if
"ones_like"
in
node
.
name
:
if
"ones_like"
in
node
.
name
:
chunk_dim
=
chunk_search
[
region_idx
][
"node_chunk_dim"
][
chunk_region_search
.
index_tracer
.
node_list
[
node_idx
]
][
"chunk_dim"
]
if
(
_get_node_shape
(
chunk_region_search
.
index_tracer
.
node_list
[
node_idx
]
)[
chunk_dim
]
==
1
):
continue
chunk_slice
=
_gen_chunk_slice_dim
(
chunk_slice
=
_gen_chunk_slice_dim
(
chunk_
search
[
region_idx
][
"node_chunk_dim"
][
chunk_region_search
.
index_tracer
.
node_list
[
node_idx
]][
"chunk_dim"
]
,
"chunk_idx"
,
_get_node_shape
(
node
)
chunk_
dim
,
"chunk_idx"
,
_get_node_shape
(
node
)
)
)
body
[
-
1
]
=
_replace_name
(
body
[
-
1
]
=
_replace_name
(
body
[
-
1
],
node
.
args
[
0
].
name
,
node
.
args
[
0
].
name
+
chunk_slice
body
[
-
1
],
node
.
args
[
0
].
name
,
node
.
args
[
0
].
name
+
chunk_slice
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment