Unverified commit 43de0118, authored by QuanluZhang, committed by GitHub

compression speedup: small code refactor (#2065)

parent e6cedb89
```diff
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
+import logging
 import torch
 from .infer_shape import ModuleMasks
+_logger = logging.getLogger(__name__)
 replace_module = {
     'BatchNorm2d': lambda module, mask: replace_batchnorm2d(module, mask),
     'Conv2d': lambda module, mask: replace_conv2d(module, mask),
@@ -16,6 +19,7 @@ def no_replace(module, mask):
     """
     No need to replace
     """
+    _logger.debug("no need to replace")
     return module
 
 def replace_linear(linear, mask):
@@ -37,9 +41,8 @@ def replace_linear(linear, mask):
     assert mask.output_mask is None
     assert not mask.param_masks
     index = mask.input_mask.mask_index[-1]
-    print(mask.input_mask.mask_index)
     in_features = index.size()[0]
-    print('linear: ', in_features)
+    _logger.debug("replace linear with new in_features: %d", in_features)
     new_linear = torch.nn.Linear(in_features=in_features,
                                  out_features=linear.out_features,
                                  bias=linear.bias is not None)
@@ -67,7 +70,7 @@ def replace_batchnorm2d(norm, mask):
     assert 'weight' in mask.param_masks and 'bias' in mask.param_masks
     index = mask.param_masks['weight'].mask_index[0]
     num_features = index.size()[0]
-    print("replace batchnorm2d: ", num_features, index)
+    _logger.debug("replace batchnorm2d with num_features: %d", num_features)
     new_norm = torch.nn.BatchNorm2d(num_features=num_features,
                                     eps=norm.eps,
                                     momentum=norm.momentum,
@@ -106,6 +109,7 @@ def replace_conv2d(conv, mask):
     else:
         out_channels_index = mask.output_mask.mask_index[1]
         out_channels = out_channels_index.size()[0]
+    _logger.debug("replace conv2d with in_channels: %d, out_channels: %d", in_channels, out_channels)
     new_conv = torch.nn.Conv2d(in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=conv.kernel_size,
@@ -128,6 +132,5 @@ def replace_conv2d(conv, mask):
     assert tmp_weight_data is not None, "Conv2d weight should be updated based on masks"
     new_conv.weight.data.copy_(tmp_weight_data)
     if conv.bias is not None:
-        print('final conv.bias is not None')
         new_conv.bias.data.copy_(conv.bias.data if tmp_bias_data is None else tmp_bias_data)
     return new_conv
@@ -158,7 +158,7 @@ class ModelSpeedup:
         """
         # TODO: scope name could be empty
         node_name = '.'.join([node.scopeName(), node.kind(), str(self.global_count)])
-        #print('node_name: ', node_name)
+        _logger.debug("expand non-prim node, node name: %s", node_name)
         self.global_count += 1
         op_type = node.kind()
@@ -173,7 +173,6 @@ class ModelSpeedup:
             input_name = _input.debugName()
             if input_name in output_to_node and output_to_node[input_name] in nodes:
                 predecessor_node = output_to_node[input_name]
-                #print("predecessor_node: ", predecessor_node)
                 if predecessor_node.kind().startswith('prim::'):
                     node_group.append(predecessor_node)
                     node_queue.put(predecessor_node)
@@ -231,7 +230,7 @@ class ModelSpeedup:
         """
         graph = self.trace_graph.graph
         # if torch 1.4.0 is used, consider run torch._C._jit_pass_inline(graph) here
-        #print(graph)
+        #_logger.debug(graph)
         # build output mapping, from output debugName to its node
         output_to_node = dict()
         # build input mapping, from input debugName to its node
@@ -301,10 +300,8 @@ class ModelSpeedup:
                 m_inputs.append(_input)
             elif not output_to_node[_input] in nodes:
                 m_inputs.append(_input)
-        print("module node_name: ", module_name)
         if module_name == '':
-            for n in nodes:
-                print(n)
+            _logger.warning("module_name is empty string")
         g_node = GNode(module_name, 'module', module_to_type[module_name], m_inputs, m_outputs, nodes)
         self.g_nodes.append(g_node)
@@ -345,10 +342,7 @@ class ModelSpeedup:
         predecessors = []
         for _input in self.name_to_gnode[module_name].inputs:
             if not _input in self.output_to_gnode:
-                print(_input)
-            if not _input in self.output_to_gnode:
-                # TODO: check _input which does not have node
-                print("output with no gnode: ", _input)
+                _logger.debug("cannot find gnode with %s as its output", _input)
             else:
                 g_node = self.output_to_gnode[_input]
                 predecessors.append(g_node.name)
@@ -407,15 +401,15 @@ class ModelSpeedup:
         self.inferred_masks[module_name] = module_masks
 
         m_type = self.name_to_gnode[module_name].op_type
-        print("infer_module_mask: {}, module type: {}".format(module_name, m_type))
+        _logger.debug("infer mask of module %s with op_type %s", module_name, m_type)
         if mask is not None:
-            #print("mask is not None")
+            _logger.debug("mask is not None")
             if not m_type in infer_from_mask:
                 raise RuntimeError("Has not supported infering \
                     input/output shape from mask for module/function: `{}`".format(m_type))
             input_cmask, output_cmask = infer_from_mask[m_type](module_masks, mask)
         if in_shape is not None:
-            #print("in_shape is not None")
+            _logger.debug("in_shape is not None")
            if not m_type in infer_from_inshape:
                 raise RuntimeError("Has not supported infering \
                     output shape from input shape for module/function: `{}`".format(m_type))
@@ -426,23 +420,19 @@ class ModelSpeedup:
         else:
             output_cmask = infer_from_inshape[m_type](module_masks, in_shape)
         if out_shape is not None:
-            #print("out_shape is not None")
+            _logger.debug("out_shape is not None")
             if not m_type in infer_from_outshape:
                 raise RuntimeError("Has not supported infering \
                     input shape from output shape for module/function: `{}`".format(m_type))
             input_cmask = infer_from_outshape[m_type](module_masks, out_shape)
         if input_cmask:
-            #print("input_cmask is not None")
             predecessors = self._find_predecessors(module_name)
             for _module_name in predecessors:
-                print("input_cmask, module_name: ", _module_name)
                 self.infer_module_mask(_module_name, out_shape=input_cmask)
         if output_cmask:
-            #print("output_cmask is not None")
             successors = self._find_successors(module_name)
             for _module_name in successors:
-                print("output_cmask, module_name: ", _module_name)
                 self.infer_module_mask(_module_name, in_shape=output_cmask)
 
     def infer_modules_masks(self):
@@ -463,16 +453,19 @@ class ModelSpeedup:
         """
         for module_name in self.inferred_masks:
             g_node = self.name_to_gnode[module_name]
-            print(module_name, g_node.op_type)
+            _logger.debug("replace %s, in %s type, with op_type %s",
+                          module_name, g_node.type, g_node.op_type)
             if g_node.type == 'module':
                 super_module, leaf_module = get_module_by_name(self.bound_model, module_name)
                 m_type = g_node.op_type
                 if not m_type in replace_module:
                     raise RuntimeError("Has not supported replacing the module: `{}`".format(m_type))
+                _logger.info("replace module (name: %s, op_type: %s)", module_name, m_type)
                 compressed_module = replace_module[m_type](leaf_module, self.inferred_masks[module_name])
                 setattr(super_module, module_name.split('.')[-1], compressed_module)
             elif g_node.type == 'func':
-                print("Warning: Cannot replace func...")
+                _logger.info("Warning: cannot replace (name: %s, op_type: %s) which is func type",
+                             module_name, g_node.op_type)
             else:
                 raise RuntimeError("Unsupported GNode type: {}".format(g_node.type))
@@ -482,10 +475,12 @@ class ModelSpeedup:
         first, do mask/shape inference,
         second, replace modules
         """
-        #print("start to compress")
+        _logger.info("start to speed up the model")
+        _logger.info("infer module masks...")
         self.infer_modules_masks()
+        _logger.info("replace compressed modules...")
         self.replace_compressed_modules()
-        #print("finished compressing")
+        _logger.info("speedup done")
         # resume the model mode to that before the model is speed up
         if self.is_training:
             self.bound_model.train()
...
```
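Since the refactor routes all diagnostics through the standard `logging` module via `logging.getLogger(__name__)`, the new `_logger.debug(...)` messages are silent unless the calling script raises the log level. A minimal sketch of how a user might surface them; the `"nni.compression"` logger name below is an assumption for illustration, as the exact namespace depends on where these modules live in the package:

```python
import logging

# Show DEBUG records from all loggers; the module-level loggers created with
# logging.getLogger(__name__) in this commit inherit the root configuration.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s %(name)s %(levelname)s: %(message)s",
)

# Or raise verbosity only for the speedup code
# (the "nni.compression" namespace is a hypothetical example):
logging.getLogger("nni.compression").setLevel(logging.DEBUG)
```

One benefit of this pattern over the removed `print()` calls is that verbosity is controlled per namespace by the consumer of the library, not hard-coded at the call site.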