Unverified Commit ebcd6024 authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

compression speedup: fix bug (#2072)

parent a09d3581
......@@ -11,6 +11,7 @@ replace_module = {
'BatchNorm2d': lambda module, mask: replace_batchnorm2d(module, mask),
'Conv2d': lambda module, mask: replace_conv2d(module, mask),
'MaxPool2d': lambda module, mask: no_replace(module, mask),
'AvgPool2d': lambda module, mask: no_replace(module, mask),
'ReLU': lambda module, mask: no_replace(module, mask),
'Linear': lambda module, mask: replace_linear(module, mask)
}
......
......@@ -210,6 +210,60 @@ class ModelSpeedup:
out_shape = t_output.type().sizes()
return {'in_shape': in_shape, 'out_shape': out_shape}
def _extract_leaf_modules(self, graph):
"""
Extract leaf modules from the given graph. Leaf module means it does not have submodules.
To extract leaf modules because only leaf module can be replaced. And shape inference can
be done in leaf module level. Other shape inference is done in lower level i.e.,
operation level.
Parameters
----------
graph : jit trace graph
the graph generated from jit trace
Returns
-------
list
a list of scope name of all the leaf modules
"""
pieces = [] # each element is a dict
for node in graph.nodes():
scope_name = node.scopeName()
if scope_name == '':
continue
segs = scope_name.split('/')
segs_len = len(segs)
# increase the length of `pieces` if not enough
for _ in range(segs_len - len(pieces)):
pieces.append({})
# process internal segments of the scope name
# 'L' means leaf segment
# 'I' means internal segment
# internal segment can replace leaf segment at the same position of `pieces`
for i, seg in enumerate(segs[:-1]):
seg_name_dict = pieces[i]
if seg in seg_name_dict:
if seg_name_dict[seg][0] == 'L':
seg_name_dict[seg] = ('I', node)
else:
seg_name_dict[seg] = ('I', node)
# process the leaf segment of the scope name
last_segs_dict = pieces[len(segs) - 1]
if not segs[-1] in last_segs_dict:
last_segs_dict[segs[-1]] = ('L', node)
# traverse `pieces` to obtain all the leaf modules which are labeled with 'L'
leaf_modules = []
for piece in pieces:
for _, value in piece.items():
if value[0] == 'L':
assert value[1].scopeName() not in leaf_modules
# if this is a leaf module, the last segment of its scope name
# must be in pattern `xxx[xxx]`
if value[1].scopeName()[-1] == ']':
leaf_modules.append(value[1].scopeName())
return leaf_modules
def _build_graph(self):
"""
Build graph using our defined format from jit trace.
......@@ -230,7 +284,7 @@ class ModelSpeedup:
"""
graph = self.trace_graph.graph
# if torch 1.4.0 is used, consider run torch._C._jit_pass_inline(graph) here
#_logger.debug(graph)
_logger.debug(graph)
# build output mapping, from output debugName to its node
output_to_node = dict()
# build input mapping, from input debugName to its node
......@@ -249,6 +303,9 @@ class ModelSpeedup:
for output in graph.outputs():
graph_outputs.append(output.debugName())
leaf_modules = self._extract_leaf_modules(graph)
_logger.debug(leaf_modules)
for node in graph.nodes():
# populate output_to_node and input_to_node
for output in node.outputs():
......@@ -258,10 +315,8 @@ class ModelSpeedup:
input_name = _input.debugName()
input_to_node[input_name] = node
scope_name = node.scopeName() # example: scope_name, 'MyCell/Linear[linear]'
module_name_slices = re.findall(r'\[(.*?)\]', scope_name)
module_name = '.'.join(module_name_slices)
# if module_name is empty, it is not a module
if module_name == '':
if not scope_name in leaf_modules:
if scope_name == '':
continue
else:
......@@ -270,6 +325,8 @@ class ModelSpeedup:
else:
func_to_nodes[scope_name] = [node]
else:
module_name_slices = re.findall(r'\[(.*?)\]', scope_name)
module_name = '.'.join(module_name_slices)
scope_slice = scope_name.split('/')[-1]
module_type = scope_slice.split('[')[0]
module_to_type[module_name] = module_type
......@@ -405,14 +462,16 @@ class ModelSpeedup:
if mask is not None:
_logger.debug("mask is not None")
if not m_type in infer_from_mask:
raise RuntimeError("Has not supported infering \
input/output shape from mask for module/function: `{}`".format(m_type))
raise RuntimeError(
"Has not supported infering input/output shape from mask for module/function: `{}`, {}"
.format(m_type, module_name))
input_cmask, output_cmask = infer_from_mask[m_type](module_masks, mask)
if in_shape is not None:
_logger.debug("in_shape is not None")
if not m_type in infer_from_inshape:
raise RuntimeError("Has not supported infering \
output shape from input shape for module/function: `{}`".format(m_type))
raise RuntimeError(
"Has not supported infering output shape from input shape for module/function: `{}`, {}"
.format(m_type, module_name))
if m_type == 'aten::view':
output_cmask = infer_from_inshape[m_type](module_masks,
in_shape,
......@@ -422,8 +481,9 @@ class ModelSpeedup:
if out_shape is not None:
_logger.debug("out_shape is not None")
if not m_type in infer_from_outshape:
raise RuntimeError("Has not supported infering \
input shape from output shape for module/function: `{}`".format(m_type))
raise RuntimeError(
"Has not supported infering input shape from output shape for module/function: `{}`, {}"
.format(m_type, module_name))
input_cmask = infer_from_outshape[m_type](module_masks, out_shape)
if input_cmask:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment