Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
5f24f4fd
Commit
5f24f4fd
authored
Dec 31, 2022
by
oahzxl
Browse files
support ones_like, add prompt if fit mode search fail
parent
80efd70c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
4 deletions
+15
-4
chunk_codegen.py
chunk_codegen.py
+15
-4
No files found.
chunk_codegen.py
View file @
5f24f4fd
...
@@ -1406,9 +1406,9 @@ class MemoryEstimator(object):
...
@@ -1406,9 +1406,9 @@ class MemoryEstimator(object):
# self._print_mem_log(act_memory_peak_log, node_list, "peak")
# self._print_mem_log(act_memory_peak_log, node_list, "peak")
# self._print_mem_log(act_memory_after_node_log, node_list, "after")
# self._print_mem_log(act_memory_after_node_log, node_list, "after")
self
.
_print_compute_op_mem_log
(
act_memory_peak_log
,
node_list
,
"peak"
)
self
.
_print_compute_op_mem_log
(
act_memory_peak_log
,
node_list
,
"peak"
)
self
.
_print_compute_op_mem_log
(
#
self._print_compute_op_mem_log(
act_memory_after_node_log
,
node_list
,
"after"
#
act_memory_after_node_log, node_list, "after"
)
#
)
# param_memory = parameter_size(gm)
# param_memory = parameter_size(gm)
# all_memory = act_memory + param_memory
# all_memory = act_memory + param_memory
...
@@ -1465,6 +1465,9 @@ class ChunkSelector(object):
...
@@ -1465,6 +1465,9 @@ class ChunkSelector(object):
if
i
in
possible_chunk_regions
:
if
i
in
possible_chunk_regions
:
possible_chunk_regions
.
remove
(
i
)
possible_chunk_regions
.
remove
(
i
)
if
len
(
possible_chunk_regions
)
==
0
:
return
None
# get mem for chunk region
# get mem for chunk region
regions_dict
=
[]
regions_dict
=
[]
for
region
in
possible_chunk_regions
:
for
region
in
possible_chunk_regions
:
...
@@ -1492,7 +1495,7 @@ class ChunkSelector(object):
...
@@ -1492,7 +1495,7 @@ class ChunkSelector(object):
)
)
# no region found
# no region found
if
len
(
regions_dict
)
==
0
:
if
len
(
regions_dict
)
==
0
:
r
eturn
None
r
aise
RuntimeError
(
"Search failed. Try a larger memory threshold."
)
# select the min chunk len
# select the min chunk len
chunk_len
=
[
i
[
"chunk_len"
]
for
i
in
regions_dict
]
chunk_len
=
[
i
[
"chunk_len"
]
for
i
in
regions_dict
]
...
@@ -1995,6 +1998,14 @@ def emit_code_with_chunk(
...
@@ -1995,6 +1998,14 @@ def emit_code_with_chunk(
body
[
-
1
]
=
_replace_name
(
body
[
-
1
]
=
_replace_name
(
body
[
-
1
],
input_node
.
name
,
input_node
.
name
+
chunk_slice
body
[
-
1
],
input_node
.
name
,
input_node
.
name
+
chunk_slice
)
)
# ones like
if
"ones_like"
in
node
.
name
:
chunk_slice
=
_gen_chunk_slice_dim
(
chunk_search
[
region_idx
][
"node_chunk_dim"
][
chunk_region_search
.
index_tracer
.
node_list
[
node_idx
]][
"chunk_dim"
],
"chunk_idx"
,
_get_node_shape
(
node
)
)
body
[
-
1
]
=
_replace_name
(
body
[
-
1
],
node
.
args
[
0
].
name
,
node
.
args
[
0
].
name
+
chunk_slice
)
body
[
-
1
]
=
_replace_reshape_size
(
body
[
-
1
]
=
_replace_reshape_size
(
body
[
-
1
],
node
.
name
,
chunk_search
[
region_idx
][
"reshape_size"
]
body
[
-
1
],
node
.
name
,
chunk_search
[
region_idx
][
"reshape_size"
]
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment