Unverified Commit 6e629908 authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

[BUG] finding leaf modules (#2241)

parent 5c8cb258
...@@ -229,42 +229,39 @@ class ModelSpeedup: ...@@ -229,42 +229,39 @@ class ModelSpeedup:
list list
a list of scope name of all the leaf modules a list of scope name of all the leaf modules
""" """
pieces = [] # each element is a dict class SNode:
def __init__(self, name):
self.sname = name
self.childs = {}
root = None
for node in graph.nodes(): for node in graph.nodes():
scope_name = node.scopeName() scope_name = node.scopeName()
if scope_name == '': if scope_name == '':
continue continue
segs = scope_name.split('/') segs = scope_name.split('/')
segs_len = len(segs) if root is None:
# increase the length of `pieces` if not enough root = SNode(segs[0])
for _ in range(segs_len - len(pieces)): curr = root
pieces.append({}) for seg in segs[1:]:
# process internal segments of the scope name if not seg in curr.childs:
# 'L' means leaf segment curr.childs[seg] = SNode(seg)
# 'I' means internal segment curr = curr.childs[seg]
# internal segment can replace leaf segment at the same position of `pieces`
for i, seg in enumerate(segs[:-1]): leaf_nodes = []
seg_name_dict = pieces[i] def traverse_tree(node, scope_name):
if seg in seg_name_dict: if scope_name == '':
if seg_name_dict[seg][0] == 'L': sn = node.sname
seg_name_dict[seg] = ('I', node) else:
else: sn = scope_name + '/' + node.sname
seg_name_dict[seg] = ('I', node) if not node.childs:
# process the leaf segment of the scope name if node.sname[-1] == ']':
last_segs_dict = pieces[len(segs) - 1] leaf_nodes.append(sn)
if not segs[-1] in last_segs_dict: else:
last_segs_dict[segs[-1]] = ('L', node) for key in node.childs:
# traverse `pieces` to obtain all the leaf modules which are labeled with 'L' traverse_tree(node.childs[key], sn)
leaf_modules = [] traverse_tree(root, '')
for piece in pieces: return leaf_nodes
for _, value in piece.items():
if value[0] == 'L':
assert value[1].scopeName() not in leaf_modules
# if this is a leaf module, the last segment of its scope name
# must be in pattern `xxx[xxx]`
if value[1].scopeName()[-1] == ']':
leaf_modules.append(value[1].scopeName())
return leaf_modules
def _build_graph(self): def _build_graph(self):
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment