Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
6e629908
Unverified
Commit
6e629908
authored
Mar 27, 2020
by
QuanluZhang
Committed by
GitHub
Mar 27, 2020
Browse files
[BUG] finding leaf modules (#2241)
parent
5c8cb258
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
28 additions
and
31 deletions
+28
-31
src/sdk/pynni/nni/compression/speedup/torch/compressor.py
src/sdk/pynni/nni/compression/speedup/torch/compressor.py
+28
-31
No files found.
src/sdk/pynni/nni/compression/speedup/torch/compressor.py
View file @
6e629908
...
@@ -229,42 +229,39 @@ class ModelSpeedup:
...
@@ -229,42 +229,39 @@ class ModelSpeedup:
list
list
a list of scope name of all the leaf modules
a list of scope name of all the leaf modules
"""
"""
pieces
=
[]
# each element is a dict
class
SNode
:
def
__init__
(
self
,
name
):
self
.
sname
=
name
self
.
childs
=
{}
root
=
None
for
node
in
graph
.
nodes
():
for
node
in
graph
.
nodes
():
scope_name
=
node
.
scopeName
()
scope_name
=
node
.
scopeName
()
if
scope_name
==
''
:
if
scope_name
==
''
:
continue
continue
segs
=
scope_name
.
split
(
'/'
)
segs
=
scope_name
.
split
(
'/'
)
segs_len
=
len
(
segs
)
if
root
is
None
:
# increase the length of `pieces` if not enough
root
=
SNode
(
segs
[
0
])
for
_
in
range
(
segs_len
-
len
(
pieces
)):
curr
=
root
pieces
.
append
({})
for
seg
in
segs
[
1
:]:
# process internal segments of the scope name
if
not
seg
in
curr
.
childs
:
# 'L' means leaf segment
curr
.
childs
[
seg
]
=
SNode
(
seg
)
# 'I' means internal segment
curr
=
curr
.
childs
[
seg
]
# internal segment can replace leaf segment at the same position of `pieces`
for
i
,
seg
in
enumerate
(
segs
[:
-
1
]):
leaf_nodes
=
[]
seg_name_dict
=
pieces
[
i
]
def
traverse_tree
(
node
,
scope_name
):
if
seg
in
seg_name_dict
:
if
scope_name
==
''
:
if
seg_name_dict
[
seg
][
0
]
==
'L'
:
sn
=
node
.
sname
seg_name_dict
[
seg
]
=
(
'I'
,
node
)
else
:
else
:
sn
=
scope_name
+
'/'
+
node
.
sname
seg_name_dict
[
seg
]
=
(
'I'
,
node
)
if
not
node
.
childs
:
# process the leaf segment of the scope name
if
node
.
sname
[
-
1
]
==
']'
:
last_segs_dict
=
pieces
[
len
(
segs
)
-
1
]
leaf_nodes
.
append
(
sn
)
if
not
segs
[
-
1
]
in
last_segs_dict
:
else
:
last_segs_dict
[
segs
[
-
1
]]
=
(
'L'
,
node
)
for
key
in
node
.
childs
:
# traverse `pieces` to obtain all the leaf modules which are labeled with 'L'
traverse_tree
(
node
.
childs
[
key
],
sn
)
leaf_modules
=
[]
traverse_tree
(
root
,
''
)
for
piece
in
pieces
:
return
leaf_nodes
for
_
,
value
in
piece
.
items
():
if
value
[
0
]
==
'L'
:
assert
value
[
1
].
scopeName
()
not
in
leaf_modules
# if this is a leaf module, the last segment of its scope name
# must be in pattern `xxx[xxx]`
if
value
[
1
].
scopeName
()[
-
1
]
==
']'
:
leaf_modules
.
append
(
value
[
1
].
scopeName
())
return
leaf_modules
def
_build_graph
(
self
):
def
_build_graph
(
self
):
"""
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment