Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
5cdfcfe1
Commit
5cdfcfe1
authored
Dec 12, 2022
by
oahzxl
Browse files
code style
parent
b7b67c32
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
45 deletions
+4
-45
chunk_codegen.py
chunk_codegen.py
+4
-45
No files found.
chunk_codegen.py
View file @
5cdfcfe1
...
...
@@ -92,24 +92,10 @@ class FlowTracer(object):
self
.
_add_trace
(
i
.
name
)
self
.
_add_node
(
i
.
name
,
i
)
def
_is_non_compute_node
(
self
,
node
):
if
any
(
i
in
node
.
op
for
i
in
[
"placeholder"
,
"get_attr"
,
"output"
])
or
any
(
i
in
node
.
name
for
i
in
[
"getitem"
,
"getattr"
]
):
return
True
return
False
def
_is_non_compute_node_except_placeholder
(
self
,
node
):
if
any
(
i
in
node
.
op
for
i
in
[
"get_attr"
,
"output"
])
or
any
(
i
in
node
.
name
for
i
in
[
"getitem"
,
"getattr"
]
):
return
True
return
False
def
_find_flow_for_node
(
self
,
node
):
if
type
(
self
.
node_list
[
0
])
!=
type
(
node
):
return
None
if
self
.
_is_non_compute_node_except_placeholder
(
node
):
if
_is_non_compute_node_except_placeholder
(
node
):
return
None
for
name
,
trace
in
self
.
flow_trace
.
items
():
for
i
in
trace
:
...
...
@@ -135,7 +121,7 @@ class FlowTracer(object):
raise
RuntimeError
(
"invalid node"
)
def
_get_flow_mix_node
(
self
,
node
):
if
self
.
_is_non_compute_node
(
node
):
if
_is_non_compute_node
(
node
):
return
None
_
,
node_trace
=
self
.
find_node_flow
(
node
)
if
len
(
node_trace
[
"outside_depend"
])
==
0
:
...
...
@@ -160,10 +146,9 @@ class FlowTracer(object):
for
node
in
self
.
node_list
:
# skip if non compute node
if
all
(
type
(
arg
)
!=
type
(
node
)
or
self
.
_is_non_compute_node_except_placeholder
(
arg
)
type
(
arg
)
!=
type
(
node
)
or
_is_non_compute_node_except_placeholder
(
arg
)
for
arg
in
node
.
args
)
or
self
.
_is_non_compute_node
(
node
):
)
or
_is_non_compute_node
(
node
):
continue
node_input_flows
=
[
self
.
_find_flow_for_node
(
arg
)
for
arg
in
node
.
args
]
...
...
@@ -1411,32 +1396,6 @@ def _gen_loop_end(
return
context
def
_find_input_and_output_nodes
(
nodes
:
List
[
Node
]):
"""
Find the input and output node names which are not found in the given list of nodes.
"""
input_nodes
=
[]
output_nodes
=
[]
# if a node has an input node which is not in the node list
# we treat that input node as the input of the checkpoint function
for
node
in
nodes
:
for
input_node
in
node
.
_input_nodes
.
keys
():
node_repr
=
repr
(
input_node
)
if
input_node
not
in
nodes
and
input_node
not
in
input_nodes
:
input_nodes
.
append
(
input_node
)
# if a node has a user node which is not in the node list
# we treat that user node as the node receiving the current node output
for
node
in
nodes
:
for
output_node
in
node
.
users
.
keys
():
node_repr
=
repr
(
node
)
if
output_node
not
in
nodes
and
output_node
not
in
output_nodes
:
output_nodes
.
append
(
output_node
)
return
input_nodes
,
output_nodes
def
_find_chunk_all_input_nodes
(
nodes
:
List
[
Node
]):
"""
Find non-compute input and output node names.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment