Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
f24c418b
Commit
f24c418b
authored
Dec 06, 2022
by
oahzxl
Browse files
finish chunk define
parent
3b7d6712
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
4 deletions
+13
-4
chunk_codegen.py
chunk_codegen.py
+13
-4
No files found.
chunk_codegen.py
View file @
f24c418b
...
...
@@ -827,7 +827,7 @@ def _find_input_and_output_nodes(nodes: List[Node]):
for
node
in
nodes
:
for
input_node
in
node
.
_input_nodes
.
keys
():
node_repr
=
repr
(
input_node
)
if
input_node
not
in
nodes
and
node_repr
not
in
input_nodes
:
if
input_node
not
in
nodes
and
input_node
not
in
input_nodes
:
input_nodes
.
append
(
input_node
)
# if a node has a user node which is not in the node list
...
...
@@ -835,7 +835,7 @@ def _find_input_and_output_nodes(nodes: List[Node]):
for
node
in
nodes
:
for
output_node
in
node
.
users
.
keys
():
node_repr
=
repr
(
node
)
if
output_node
not
in
nodes
and
node_repr
not
in
output_nodes
:
if
output_node
not
in
nodes
and
output_node
not
in
output_nodes
:
output_nodes
.
append
(
output_node
)
return
input_nodes
,
output_nodes
...
...
@@ -848,6 +848,16 @@ def _find_idx_by_name(name, nodes_list):
raise
RuntimeError
(
"name %s not found in node list"
%
name
)
def
_replace_name
(
context
,
name_from
,
name_to
):
patterns
=
[(
" "
,
" "
),
(
" "
,
"."
),
(
" "
,
","
),
(
"("
,
")"
),
(
"("
,
","
)]
for
p
in
patterns
:
source
=
p
[
0
]
+
name_from
+
p
[
1
]
target
=
p
[
0
]
+
name_to
+
p
[
1
]
if
source
in
context
:
context
=
context
.
replace
(
source
,
target
)
return
context
def
emit_code_with_chunk
(
body
,
ckpt_func
,
nodes
,
emit_node_func
,
delete_unused_value_func
,
meta_nodes
,
meta_graph
):
"""Emit code with nested activation checkpoint
When we detect some of the node.activation_checkpoint is a List, we will use
...
...
@@ -905,8 +915,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
if
within_chunk_region
:
emit_node_func
(
node
,
body
)
# replace input var with chunk var
if
node_idx
in
chunk_starts
:
body
[
-
1
]
=
body
[
-
1
].
replace
(
chunk_inputs
[
region_idx
][
0
].
name
,
'chunk_tensor'
)
body
[
-
1
]
=
_replace_name
(
body
[
-
1
],
chunk_inputs
[
region_idx
][
0
].
name
,
'chunk_tensor'
)
body
[
-
1
]
=
' '
+
body
[
-
1
]
delete_unused_value_func
(
node
,
body
,
chunk_inputs_names
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment