Unverified Commit 43de0118 authored by QuanluZhang, committed by GitHub

compression speedup: small code refactor (#2065)

parent e6cedb89
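In short, the refactor swaps ad-hoc print calls for a module-level logger, so verbosity is governed by the logging configuration and message arguments are formatted lazily. A minimal standalone sketch of the pattern the patch applies (the function and values here are illustrative, not from the patch):

import logging

_logger = logging.getLogger(__name__)  # one logger per module, keyed by module name

def replace_example(in_features):
    # %-style arguments are only formatted when DEBUG is actually enabled
    _logger.debug("replace linear with new in_features: %d", in_features)

logging.basicConfig(level=logging.DEBUG)
replace_example(128)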
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+import logging
 import torch
 from .infer_shape import ModuleMasks

+_logger = logging.getLogger(__name__)
+
 replace_module = {
     'BatchNorm2d': lambda module, mask: replace_batchnorm2d(module, mask),
     'Conv2d': lambda module, mask: replace_conv2d(module, mask),
@@ -16,6 +19,7 @@ def no_replace(module, mask):
     """
     No need to replace
     """
+    _logger.debug("no need to replace")
     return module

 def replace_linear(linear, mask):
@@ -37,9 +41,8 @@ def replace_linear(linear, mask):
     assert mask.output_mask is None
     assert not mask.param_masks
     index = mask.input_mask.mask_index[-1]
-    print(mask.input_mask.mask_index)
     in_features = index.size()[0]
-    print('linear: ', in_features)
+    _logger.debug("replace linear with new in_features: %d", in_features)
     new_linear = torch.nn.Linear(in_features=in_features,
                                  out_features=linear.out_features,
                                  bias=linear.bias is not None)
@@ -67,7 +70,7 @@ def replace_batchnorm2d(norm, mask):
     assert 'weight' in mask.param_masks and 'bias' in mask.param_masks
     index = mask.param_masks['weight'].mask_index[0]
     num_features = index.size()[0]
-    print("replace batchnorm2d: ", num_features, index)
+    _logger.debug("replace batchnorm2d with num_features: %d", num_features)
     new_norm = torch.nn.BatchNorm2d(num_features=num_features,
                                     eps=norm.eps,
                                     momentum=norm.momentum,
@@ -106,6 +109,7 @@ def replace_conv2d(conv, mask):
     else:
         out_channels_index = mask.output_mask.mask_index[1]
         out_channels = out_channels_index.size()[0]
+    _logger.debug("replace conv2d with in_channels: %d, out_channels: %d", in_channels, out_channels)
     new_conv = torch.nn.Conv2d(in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=conv.kernel_size,
@@ -128,6 +132,5 @@ def replace_conv2d(conv, mask):
     assert tmp_weight_data is not None, "Conv2d weight should be updated based on masks"
     new_conv.weight.data.copy_(tmp_weight_data)
     if conv.bias is not None:
-        print('final conv.bias is not None')
         new_conv.bias.data.copy_(conv.bias.data if tmp_bias_data is None else tmp_bias_data)
     return new_conv
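All the replace_* helpers above follow one recipe: read the kept-channel indices from the mask, construct a smaller module with the same hyperparameters, and copy the surviving parameters into it. A simplified sketch of that recipe for Conv2d (shrink_conv2d is a hypothetical helper; it ignores groups and dilation, which the real code must also carry over):

import torch

def shrink_conv2d(conv, in_idx, out_idx):
    # in_idx / out_idx: 1-D LongTensors of kept channel indices,
    # playing the role mask_index plays in the patch above
    new_conv = torch.nn.Conv2d(in_channels=in_idx.size(0),
                               out_channels=out_idx.size(0),
                               kernel_size=conv.kernel_size,
                               stride=conv.stride,
                               padding=conv.padding,
                               bias=conv.bias is not None)
    # weight layout: (out_channels, in_channels, kH, kW)
    weight = conv.weight.data.index_select(0, out_idx).index_select(1, in_idx)
    new_conv.weight.data.copy_(weight)
    if conv.bias is not None:
        new_conv.bias.data.copy_(conv.bias.data.index_select(0, out_idx))
    return new_conv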
@@ -158,7 +158,7 @@ class ModelSpeedup:
         """
         # TODO: scope name could be empty
         node_name = '.'.join([node.scopeName(), node.kind(), str(self.global_count)])
-        #print('node_name: ', node_name)
+        _logger.debug("expand non-prim node, node name: %s", node_name)
         self.global_count += 1
         op_type = node.kind()
@@ -173,7 +173,6 @@ class ModelSpeedup:
             input_name = _input.debugName()
             if input_name in output_to_node and output_to_node[input_name] in nodes:
                 predecessor_node = output_to_node[input_name]
-                #print("predecessor_node: ", predecessor_node)
                 if predecessor_node.kind().startswith('prim::'):
                     node_group.append(predecessor_node)
                     node_queue.put(predecessor_node)
@@ -231,7 +230,7 @@ class ModelSpeedup:
         """
         graph = self.trace_graph.graph
         # if torch 1.4.0 is used, consider run torch._C._jit_pass_inline(graph) here
-        #print(graph)
+        #_logger.debug(graph)
         # build output mapping, from output debugName to its node
         output_to_node = dict()
         # build input mapping, from input debugName to its node
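For orientation, self.trace_graph here is the product of torch.jit tracing, and this method walks its nodes. A minimal sketch of obtaining such a graph and enumerating its nodes, assuming the torch 1.x-era JIT API the comment above refers to:

import torch
import torchvision

model = torchvision.models.resnet18()
dummy_input = torch.randn(1, 3, 224, 224)
trace = torch.jit.trace(model, dummy_input)
# torch._C._jit_pass_inline(trace.graph)  # needed on torch 1.4.0, per the comment above
for node in trace.graph.nodes():
    # kind() is e.g. 'aten::_convolution' or 'prim::Constant';
    # scopeName() is the module path used to build node_name earlier
    print(node.kind(), node.scopeName())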
......@@ -301,10 +300,8 @@ class ModelSpeedup:
m_inputs.append(_input)
elif not output_to_node[_input] in nodes:
m_inputs.append(_input)
print("module node_name: ", module_name)
if module_name == '':
for n in nodes:
print(n)
_logger.warning("module_name is empty string")
g_node = GNode(module_name, 'module', module_to_type[module_name], m_inputs, m_outputs, nodes)
self.g_nodes.append(g_node)
@@ -345,10 +342,7 @@ class ModelSpeedup:
         predecessors = []
         for _input in self.name_to_gnode[module_name].inputs:
             if not _input in self.output_to_gnode:
-                print(_input)
-            if not _input in self.output_to_gnode:
                 # TODO: check _input which does not have node
-                print("output with no gnode: ", _input)
+                _logger.debug("cannot find gnode with %s as its output", _input)
             else:
                 g_node = self.output_to_gnode[_input]
                 predecessors.append(g_node.name)
@@ -407,15 +401,15 @@ class ModelSpeedup:
         self.inferred_masks[module_name] = module_masks

         m_type = self.name_to_gnode[module_name].op_type
-        print("infer_module_mask: {}, module type: {}".format(module_name, m_type))
+        _logger.debug("infer mask of module %s with op_type %s", module_name, m_type)
         if mask is not None:
-            #print("mask is not None")
+            _logger.debug("mask is not None")
             if not m_type in infer_from_mask:
                 raise RuntimeError("Has not supported infering \
                     input/output shape from mask for module/function: `{}`".format(m_type))
             input_cmask, output_cmask = infer_from_mask[m_type](module_masks, mask)
         if in_shape is not None:
-            #print("in_shape is not None")
+            _logger.debug("in_shape is not None")
             if not m_type in infer_from_inshape:
                 raise RuntimeError("Has not supported infering \
                     output shape from input shape for module/function: `{}`".format(m_type))
@@ -426,23 +420,19 @@ class ModelSpeedup:
         else:
             output_cmask = infer_from_inshape[m_type](module_masks, in_shape)
         if out_shape is not None:
-            #print("out_shape is not None")
+            _logger.debug("out_shape is not None")
             if not m_type in infer_from_outshape:
                 raise RuntimeError("Has not supported infering \
                     input shape from output shape for module/function: `{}`".format(m_type))
             input_cmask = infer_from_outshape[m_type](module_masks, out_shape)
         if input_cmask:
-            #print("input_cmask is not None")
             predecessors = self._find_predecessors(module_name)
             for _module_name in predecessors:
-                print("input_cmask, module_name: ", _module_name)
                 self.infer_module_mask(_module_name, out_shape=input_cmask)
         if output_cmask:
-            #print("output_cmask is not None")
             successors = self._find_successors(module_name)
             for _module_name in successors:
-                print("output_cmask, module_name: ", _module_name)
                 self.infer_module_mask(_module_name, in_shape=output_cmask)

     def infer_modules_masks(self):
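The two loops above are what make infer_module_mask settle on a fixed point: a shrunken input mask is pushed to every predecessor as an output-shape change, and a shrunken output mask to every successor as an input-shape change, recursively, until the graph is consistent. A toy illustration of that propagation over a hand-written successor map (hypothetical names, no real masks):

graph_successors = {'conv1': ['bn1'], 'bn1': ['conv2'], 'conv2': []}

def propagate_out_channels(node, channels, visited=None):
    visited = set() if visited is None else visited
    if node in visited:
        return
    visited.add(node)
    print('%s: channels -> %d' % (node, channels))
    for succ in graph_successors[node]:
        # a successor's input shape shrinks along with our output shape
        propagate_out_channels(succ, channels, visited)

propagate_out_channels('conv1', 32)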
@@ -463,16 +453,19 @@ class ModelSpeedup:
         """
         for module_name in self.inferred_masks:
             g_node = self.name_to_gnode[module_name]
-            print(module_name, g_node.op_type)
+            _logger.debug("replace %s, in %s type, with op_type %s",
+                          module_name, g_node.type, g_node.op_type)
             if g_node.type == 'module':
                 super_module, leaf_module = get_module_by_name(self.bound_model, module_name)
                 m_type = g_node.op_type
                 if not m_type in replace_module:
                     raise RuntimeError("Has not supported replacing the module: `{}`".format(m_type))
+                _logger.info("replace module (name: %s, op_type: %s)", module_name, m_type)
                 compressed_module = replace_module[m_type](leaf_module, self.inferred_masks[module_name])
                 setattr(super_module, module_name.split('.')[-1], compressed_module)
             elif g_node.type == 'func':
-                print("Warning: Cannot replace func...")
+                _logger.info("Warning: cannot replace (name: %s, op_type: %s) which is func type",
+                             module_name, g_node.op_type)
             else:
                 raise RuntimeError("Unsupported GNode type: {}".format(g_node.type))
@@ -482,10 +475,12 @@ class ModelSpeedup:
         first, do mask/shape inference,
         second, replace modules
         """
-        #print("start to compress")
+        _logger.info("start to speed up the model")
+        _logger.info("infer module masks...")
         self.infer_modules_masks()
+        _logger.info("replace compressed modules...")
         self.replace_compressed_modules()
-        #print("finished compressing")
+        _logger.info("speedup done")
         # resume the model mode to that before the model is speed up
         if self.is_training:
             self.bound_model.train()
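For context, the entry point shown above (mask inference followed by module replacement) is typically driven as below in NNI 1.x; the import path and masks-file argument are assumptions based on the API of that era, not part of this commit:

import torch
import torchvision
from nni.compression.speedup.torch import ModelSpeedup  # NNI 1.x path, assumption

model = torchvision.models.resnet18()
dummy_input = torch.randn(1, 3, 224, 224)
# 'mask.pth' is a placeholder for masks exported by an NNI pruner
m_speedup = ModelSpeedup(model, dummy_input, 'mask.pth')
m_speedup.speedup_model()  # infer_modules_masks() then replace_compressed_modules()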