Commit eab0da15 authored by QuanluZhang's avatar QuanluZhang Committed by GitHub

Dev compression speedup (#1999)

parent ed121315
@@ -17,7 +17,7 @@ class Mnist(torch.nn.Module):
         x = F.max_pool2d(x, 2, 2)
         x = F.relu(self.conv2(x))
         x = F.max_pool2d(x, 2, 2)
-        x = x.view(-1, 4 * 4 * 50)
+        x = x.view(x.size(0), -1)
         x = F.relu(self.fc1(x))
         x = self.fc2(x)
         return F.log_softmax(x, dim=1)
import argparse
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from models.cifar10.vgg import VGG
from nni.compression.speedup.torch import ModelSpeedup
from nni.compression.torch import apply_compression_results

torch.manual_seed(0)
use_mask = False

def apoz_speedup(masks_file, model_checkpoint):
    device = torch.device('cuda')
    model = VGG(depth=16)
    model.to(device)
    model.eval()

    dummy_input = torch.randn(64, 3, 32, 32)
    if use_mask:
        apply_compression_results(model, masks_file)
        dummy_input = dummy_input.to(device)
        start = time.time()
        for _ in range(32):
            out = model(dummy_input)
            #print(out.size(), out)
        print('mask elapsed time: ', time.time() - start)
        return
    else:
        #print("model before: ", model)
        m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file)
        m_speedup.speedup_model()
        #print("model after: ", model)
        dummy_input = dummy_input.to(device)
        start = time.time()
        for _ in range(32):
            out = model(dummy_input)
            #print(out.size(), out)
        print('speedup elapsed time: ', time.time() - start)
        return

def l1filter_speedup(masks_file, model_checkpoint):
    device = torch.device('cuda')
    model = VGG(depth=16)
    model.to(device)
    model.eval()

    dummy_input = torch.randn(64, 3, 32, 32)
    if use_mask:
        apply_compression_results(model, masks_file)
        dummy_input = dummy_input.to(device)
        start = time.time()
        for _ in range(32):
            out = model(dummy_input)
            #print(out.size(), out)
        print('mask elapsed time: ', time.time() - start)
        return
    else:
        #print("model before: ", model)
        m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file)
        m_speedup.speedup_model()
        #print("model after: ", model)
        dummy_input = dummy_input.to(device)
        start = time.time()
        for _ in range(32):
            out = model(dummy_input)
            #print(out.size(), out)
        print('speedup elapsed time: ', time.time() - start)
        return

def fpgm_speedup(masks_file, model_checkpoint):
    from fpgm_torch_mnist import Mnist
    device = torch.device('cpu')
    model = Mnist()
    model.to(device)
    model.print_conv_filter_sparsity()

    dummy_input = torch.randn(64, 1, 28, 28)
    if use_mask:
        apply_compression_results(model, masks_file)
        dummy_input = dummy_input.to(device)
        start = time.time()
        for _ in range(40):
            out = model(dummy_input)
        print('mask elapsed time: ', time.time() - start)
        #print(out.size(), out)
        return
    else:
        m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file)
        m_speedup.speedup_model()
        dummy_input = dummy_input.to(device)
        start = time.time()
        for _ in range(40):
            out = model(dummy_input)
        print('speedup elapsed time: ', time.time() - start)
        #print(out.size(), out)
        return

def slim_speedup(masks_file, model_checkpoint):
    device = torch.device('cuda')
    model = VGG(depth=19)
    model.to(device)
    model.eval()

    dummy_input = torch.randn(64, 3, 32, 32)
    if use_mask:
        apply_compression_results(model, masks_file)
        dummy_input = dummy_input.to(device)
        start = time.time()
        for _ in range(32):
            out = model(dummy_input)
            #print(out.size(), out)
        print('mask elapsed time: ', time.time() - start)
        return
    else:
        #print("model before: ", model)
        m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file)
        m_speedup.speedup_model()
        #print("model after: ", model)
        dummy_input = dummy_input.to(device)
        start = time.time()
        for _ in range(32):
            out = model(dummy_input)
            #print(out.size(), out)
        print('speedup elapsed time: ', time.time() - start)
        return

if __name__ == '__main__':
    parser = argparse.ArgumentParser("speedup")
    parser.add_argument("--example_name", type=str, default="slim", help="the name of pruning example")
    parser.add_argument("--masks_file", type=str, default=None, help="the path of the masks file")
    parser.add_argument("--model_checkpoint", type=str, default=None, help="the path of checkpointed model")
    args = parser.parse_args()

    if args.example_name == 'slim':
        if args.masks_file is None:
            args.masks_file = 'mask_vgg19_cifar10.pth'
        slim_speedup(args.masks_file, args.model_checkpoint)
    elif args.example_name == 'fpgm':
        if args.masks_file is None:
            args.masks_file = 'mask.pth'
        fpgm_speedup(args.masks_file, args.model_checkpoint)
    elif args.example_name == 'l1filter':
        if args.masks_file is None:
            args.masks_file = 'mask_vgg16_cifar10.pth'
        l1filter_speedup(args.masks_file, args.model_checkpoint)
    elif args.example_name == 'apoz':
        if args.masks_file is None:
            args.masks_file = 'mask_vgg16_cifar10.pth'
        apoz_speedup(args.masks_file, args.model_checkpoint)
    else:
        raise ValueError('unsupported example_name: {}'.format(args.example_name))
# Speed up Masked Model
*This feature is still in the alpha stage.*
## Introduction
Pruning algorithms usually use weight masks to simulate real pruning. Masks can be used
to check the model performance under a specific pruning rate (or sparsity), but they bring no real speedup.
Since model speedup is the ultimate goal of model pruning, we provide a tool that converts
a model into a smaller one based on user-provided masks (the masks come from the
pruning algorithms).
There are two types of pruning. Fine-grained pruning does not change the shape of the weights or the input/output tensors, so a sparse kernel is required to speed up a fine-grained pruned layer. Coarse-grained pruning (e.g., channel pruning) usually changes the shape of the weights and the input/output tensors; this kind of pruning needs no sparse kernel, because the pruned layer can simply be replaced with a smaller one. Since community support for sparse kernels is limited, we currently only support speedup for coarse-grained pruning and leave fine-grained pruning for future work.
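For intuition, here is a minimal sketch (plain PyTorch, not NNI API; the tensor shapes are illustrative) contrasting the two mask granularities on a `Conv2d` weight:
```python
import torch

weight = torch.randn(8, 3, 3, 3)  # Conv2d weight: (out_channels, in_channels, kH, kW)

# Fine-grained mask: arbitrary elements are zeroed, the weight shape is unchanged,
# so a sparse kernel would be needed to gain real speedup.
fine_mask = (torch.rand_like(weight) > 0.5).float()

# Coarse-grained (channel) mask: whole filters are zeroed, so the masked layer
# is equivalent to a Conv2d with fewer output channels.
channel_mask = torch.zeros(8)
channel_mask[[0, 2, 5]] = 1.  # keep filters 0, 2, and 5
coarse_mask = channel_mask.view(-1, 1, 1, 1).expand_as(weight)

# A speedup tool can therefore shrink the layer in the coarse-grained case:
kept = torch.nonzero(channel_mask, as_tuple=True)[0]
smaller_weight = torch.index_select(weight, 0, kept)  # shape: (3, 3, 3, 3)
```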
## Design and Implementation
To speed up a model, the pruned layers should be replaced, either with a smaller layer for a coarse-grained mask or with a sparse kernel for a fine-grained mask. A coarse-grained mask usually changes the shape of weights or input/output tensors, so we run shape inference to check whether other, unpruned layers must also be replaced due to the shape change. Therefore, our design has two main steps: first, run shape inference to find all the modules that should be replaced; second, replace the modules. The first step requires the topology (i.e., connections) of the model; we use `jit.trace` to obtain the model graph for PyTorch.
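As a minimal sketch of the graph-extraction step (standard `torch.jit` API, the same calls `ModelSpeedup` uses internally; `model` and `dummy_input` stand in for your own network and input):
```python
import torch

model.eval()  # the forward graph should be traced in eval mode
traced = torch.jit.trace(model, dummy_input)
for node in traced.graph.nodes():
    # each traced node exposes its op kind (e.g., 'aten::_convolution')
    # and the module scope it belongs to
    print(node.kind(), node.scopeName())
```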
For each module, we need to prepare four functions: three for shape inference and one for module replacement. The three shape inference functions are: given the weight shape, infer the input/output shape; given the input shape, infer the weight/output shape; and given the output shape, infer the weight/input shape. The module replacement function returns a newly created, smaller module.
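Concretely, these functions are registered in per-op-type lookup tables; the pseudo-signatures below mirror how `infer_module_mask` and `replace_compressed_modules` (shown later in this commit) invoke them:
```python
# Shape inference, one table per direction, keyed by module/op type:
#   infer_from_mask[m_type](module_masks, weight_mask)   -> input_cmask, output_cmask
#   infer_from_inshape[m_type](module_masks, in_cmask)   -> output_cmask
#   infer_from_outshape[m_type](module_masks, out_cmask) -> input_cmask
# Module replacement, keyed by module type:
#   replace_module[m_type](old_module, module_masks)     -> new, smaller module
```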
## Usage
```python
from nni.compression.speedup.torch import ModelSpeedup
# model: the model you want to speed up
# dummy_input: dummy input of the model, given to `jit.trace`
# masks_file: the mask file created by pruning algorithms
m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file)
m_speedup.speedup_model()
dummy_input = dummy_input.to(device)
start = time.time()
out = model(dummy_input)
print('elapsed time: ', time.time() - start)
```
For complete examples, please refer to [the code](https://github.com/microsoft/nni/tree/master/examples/model_compress/model_speedup.py).
NOTE: the current implementation only works on torch 1.3.1 and torchvision 0.4.2.
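For comparison, the example script also measures the baseline where masks are merely applied to the original architecture (no structural change); a sketch assuming the same `model`, `dummy_input`, and `masks_file` as above:
```python
from nni.compression.torch import apply_compression_results

# Baseline: keep the original architecture, only zero out weights via masks.
apply_compression_results(model, masks_file)
start = time.time()
out = model(dummy_input)
print('mask elapsed time: ', time.time() - start)
```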
## Limitations
Since every module requires four functions for shape inference and module replacement, implementing them all is a large amount of work, so we have only implemented the ones required by the examples. If you want to speed up your own model and it is not supported by the current implementation, you are welcome to contribute.
For PyTorch, we can only replace modules; if a function called in `forward` needs to be replaced, our current implementation does not work. One workaround is to wrap the function in a PyTorch module.
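For example, the `x.view(x.size(0), -1)` call in the MNIST model above could be wrapped as follows (a sketch of a custom module; alternatively `torch.nn.Flatten` can be used where available):
```python
import torch.nn as nn

class Flatten(nn.Module):
    """Module wrapper for the `x.view(x.size(0), -1)` call in `forward`."""
    def forward(self, x):
        return x.view(x.size(0), -1)

# In the model definition:
#   self.flatten = Flatten()   # in __init__
#   x = self.flatten(x)        # in forward, instead of calling x.view(...) directly
```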
## Speedup Results of Examples
The code of these experiments can be found [here](https://github.com/microsoft/nni/tree/master/examples/model_compress/model_speedup.py).
### slim pruner example
on one V100 GPU,
input tensor: `torch.randn(64, 3, 32, 32)`
| Times | Mask Latency (s) | Speedup Latency (s) |
|---|---|---|
| 1 | 0.01197 | 0.005107 |
| 2 | 0.02019 | 0.008769 |
| 4 | 0.02733 | 0.014809 |
| 8 | 0.04310 | 0.027441 |
| 16 | 0.07731 | 0.05008 |
| 32 | 0.14464 | 0.10027 |
### fpgm pruner example
on CPU,
input tensor: `torch.randn(64, 1, 28, 28)`,
the variance across runs is too large:
| Times | Mask Latency (s) | Speedup Latency (s) |
|---|---|---|
| 1 | 0.01383 | 0.01839 |
| 2 | 0.01167 | 0.003558 |
| 4 | 0.01636 | 0.01088 |
| 40 | 0.14412 | 0.08268 |
| 40 | 1.29385 | 0.14408 |
| 40 | 0.41035 | 0.46162 |
| 400 | 6.29020 | 5.82143 |
### l1filter pruner example
on one V100 GPU,
input tensor: `torch.randn(64, 3, 32, 32)`
| Times | Mask Latency (s) | Speedup Latency (s) |
|---|---|---|
| 1 | 0.01026 | 0.003677 |
| 2 | 0.01657 | 0.008161 |
| 4 | 0.02458 | 0.020018 |
| 8 | 0.03498 | 0.025504 |
| 16 | 0.06757 | 0.047523 |
| 32 | 0.10487 | 0.086442 |
### APoZ pruner example
on one V100 GPU,
input tensor: `torch.randn(64, 3, 32, 32)`
| Times | Mask Latency (s) | Speedup Latency (s) |
|---|---|---|
| 1 | 0.01389 | 0.004208 |
| 2 | 0.01628 | 0.008310 |
| 4 | 0.02521 | 0.014008 |
| 8 | 0.03386 | 0.023923 |
| 16 | 0.06042 | 0.046183 |
| 32 | 0.12421 | 0.087113 |
from .compressor import ModelSpeedup
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
from .infer_shape import CoarseMask, ModuleMasks

replace_module = {
    'BatchNorm2d': lambda module, mask: replace_batchnorm2d(module, mask),
    'Conv2d': lambda module, mask: replace_conv2d(module, mask),
    'MaxPool2d': lambda module, mask: no_replace(module, mask),
    'ReLU': lambda module, mask: no_replace(module, mask),
    'Linear': lambda module, mask: replace_linear(module, mask)
}

def no_replace(module, mask):
    """
    No need to replace
    """
    return module

def replace_linear(linear, mask):
    """
    Parameters
    ----------
    linear : torch.nn.Linear
        The linear module to be replaced
    mask : ModuleMasks
        The masks of this module

    Returns
    -------
    torch.nn.Linear
        The new linear module
    """
    assert isinstance(mask, ModuleMasks)
    assert mask.input_mask is not None
    assert mask.output_mask is None
    assert not mask.param_masks
    index = mask.input_mask.mask_index[-1]
    print(mask.input_mask.mask_index)
    in_features = index.size()[0]
    print('linear: ', in_features)
    new_linear = torch.nn.Linear(in_features=in_features,
                                 out_features=linear.out_features,
                                 bias=linear.bias is not None)
    new_linear.to(linear.weight.device)
    new_linear.weight.data = torch.index_select(linear.weight.data, -1, index.to(linear.weight.device))
    if linear.bias is not None:
        new_linear.bias.data.copy_(linear.bias.data)
    return new_linear

def replace_batchnorm2d(norm, mask):
    """
    Parameters
    ----------
    norm : torch.nn.BatchNorm2d
        The batchnorm module to be replaced
    mask : ModuleMasks
        The masks of this module

    Returns
    -------
    torch.nn.BatchNorm2d
        The new batchnorm module
    """
    assert isinstance(mask, ModuleMasks)
    assert 'weight' in mask.param_masks and 'bias' in mask.param_masks
    index = mask.param_masks['weight'].mask_index[0]
    num_features = index.size()[0]
    print("replace batchnorm2d: ", num_features, index)
    new_norm = torch.nn.BatchNorm2d(num_features=num_features,
                                    eps=norm.eps,
                                    momentum=norm.momentum,
                                    affine=norm.affine,
                                    track_running_stats=norm.track_running_stats)
    # assign weights
    new_norm.weight.data = torch.index_select(norm.weight.data, 0, index)
    new_norm.bias.data = torch.index_select(norm.bias.data, 0, index)
    if norm.track_running_stats:
        new_norm.running_mean.data = torch.index_select(norm.running_mean.data, 0, index)
        new_norm.running_var.data = torch.index_select(norm.running_var.data, 0, index)
    return new_norm

def replace_conv2d(conv, mask):
    """
    Parameters
    ----------
    conv : torch.nn.Conv2d
        The conv2d module to be replaced
    mask : ModuleMasks
        The masks of this module

    Returns
    -------
    torch.nn.Conv2d
        The new conv2d module
    """
    assert isinstance(mask, ModuleMasks)
    if mask.input_mask is None:
        in_channels = conv.in_channels
    else:
        in_channels_index = mask.input_mask.mask_index[1]
        in_channels = in_channels_index.size()[0]
    if mask.output_mask is None:
        out_channels = conv.out_channels
    else:
        out_channels_index = mask.output_mask.mask_index[1]
        out_channels = out_channels_index.size()[0]
    new_conv = torch.nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=conv.kernel_size,
                               stride=conv.stride,
                               padding=conv.padding,
                               dilation=conv.dilation,
                               groups=1,  # currently only groups == 1 is supported
                               bias=conv.bias is not None,
                               padding_mode=conv.padding_mode)
    new_conv.to(conv.weight.device)
    tmp_weight_data = tmp_bias_data = None
    if mask.output_mask is not None:
        tmp_weight_data = torch.index_select(conv.weight.data, 0, out_channels_index)
        if conv.bias is not None:
            tmp_bias_data = torch.index_select(conv.bias.data, 0, out_channels_index)
    # NOTE: does not support group
    if mask.input_mask is not None:
        tmp_weight_data = torch.index_select(conv.weight.data if tmp_weight_data is None else tmp_weight_data,
                                             1, in_channels_index)
    assert tmp_weight_data is not None, "Conv2d weight should be updated based on masks"
    new_conv.weight.data.copy_(tmp_weight_data)
    if conv.bias is not None:
        print('final conv.bias is not None')
        new_conv.bias.data.copy_(conv.bias.data if tmp_bias_data is None else tmp_bias_data)
    return new_conv
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import queue
import re
import torch
from .compress_modules import replace_module
from .infer_shape import ModuleMasks, infer_from_mask, infer_from_inshape, infer_from_outshape

_logger = logging.getLogger(__name__)

def get_module_by_name(model, module_name):
    """
    Get a module specified by its module name

    Parameters
    ----------
    model : pytorch model
        the pytorch model from which to get its module
    module_name : str
        the name of the required module

    Returns
    -------
    module, module
        the parent module of the required module, the required module
    """
    name_list = module_name.split(".")
    for name in name_list[:-1]:
        model = getattr(model, name)
    leaf_module = getattr(model, name_list[-1])
    return model, leaf_module

class GNode:
    """
    Represents a node in the model graph. In this graph a module is a node,
    and a function outside of a module (i.e., called in the ```forward``` function)
    can also be a node.
    """
    def __init__(self, node_name, node_type, op_type, inputs, outputs, nodes):
        """
        Parameters
        ----------
        node_name : str
            It is the module name if the node is a module, or ```scope_name.node_kind.seq``` if it is a func
        node_type : str
            It only has two options: `module` or `func`
        op_type : str
            The operation type of the module or func
        inputs : list of str
            All the inputs of this node, each element is the debugName of one input
        outputs : list of str
            All the outputs of this node, each element is the debugName of one output
        nodes : list of node
            All the trace graph nodes included in this module or func
        """
        self.name = node_name
        self.type = node_type
        self.op_type = op_type
        self.inputs = inputs
        self.outputs = outputs
        self.nodes = nodes
        # store supplementary information for different op types
        # for example, for ```view``` it stores the shape of its input and output
        self.auxiliary = None

class ModelSpeedup:
    """
    This class speeds up the model based on the provided weight masks
    """
    def __init__(self, model, dummy_input, masks_file):
        """
        Parameters
        ----------
        model : pytorch model
            The model user wants to speed up
        dummy_input : pytorch tensor
            The dummy input for ```jit.trace```, users should put it on the right device before passing it in
        masks_file : str
            The path of the user provided mask file
        """
        self.bound_model = model
        self.dummy_input = dummy_input
        self.masks = torch.load(masks_file)
        self.is_training = model.training
        # to obtain the forward graph, the model should be in ```eval``` mode
        if self.is_training:
            model.eval()
        self.trace_graph = torch.jit.trace(model, dummy_input)
        if self.is_training:
            model.train()
        self.inferred_masks = dict()  # key: module_name, value: ModuleMasks
        self.g_nodes = list()
        self.global_count = 0
        self.name_to_gnode, self.input_to_gnode, self.output_to_gnode = self._build_graph()

    def _build_index_for_gnodes(self, g_nodes):
        """
        Build indexes for quick search

        Parameters
        ----------
        g_nodes : list of GNode
            All the g_node in the processed model graph

        Returns
        -------
        dict
            use name to index g_nodes, key: node name, value: g_node
        dict
            use input (its name) to index g_nodes,
            key: input, value: list of g_nodes that take this input
        dict
            use output (its name) to index g_nodes,
            key: output, value: g_node that generates this output
        """
        name_to_gnode = dict()
        input_to_gnode = dict()
        output_to_gnode = dict()
        for node in g_nodes:
            name_to_gnode[node.name] = node
            for _input in node.inputs:
                if _input in input_to_gnode:
                    input_to_gnode[_input].append(node)
                else:
                    input_to_gnode[_input] = [node]
            for output in node.outputs:
                assert not output in output_to_gnode, \
                    "One output cannot be generated by multiple nodes"
                output_to_gnode[output] = node
        return name_to_gnode, input_to_gnode, output_to_gnode

    def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node):
        """
        Some trace graph nodes are not in modules; such nodes are usually generated by
        functions directly called in a module's ```forward```. Some of them are trivial
        ops labeled ```prim::```; the others are called non-prim ops. This function merges
        neighboring prim ops into a non-prim op to construct a GNode.

        Parameters
        ----------
        node : trace graph node
            The non-prim node to expand
        nodes : list of trace graph node
            All the trace graph nodes within the same scope as the non-prim node
        input_to_node : dict
            key: input name, value: a node that uses this input
        output_to_node : dict
            key: output name, value: a node that generates this output

        Returns
        -------
        GNode
            the expanded non-prim node in GNode format
        """
        # TODO: scope name could be empty
        node_name = '.'.join([node.scopeName(), node.kind(), str(self.global_count)])
        #print('node_name: ', node_name)
        self.global_count += 1
        op_type = node.kind()

        node_group = [node]
        inputs = list()
        outputs = list()
        node_queue = queue.Queue()
        node_queue.put(node)
        while not node_queue.empty():
            curr_node = node_queue.get()
            for _input in curr_node.inputs():
                input_name = _input.debugName()
                if input_name in output_to_node and output_to_node[input_name] in nodes:
                    predecessor_node = output_to_node[input_name]
                    #print("predecessor_node: ", predecessor_node)
                    if predecessor_node.kind().startswith('prim::'):
                        node_group.append(predecessor_node)
                        node_queue.put(predecessor_node)
                    else:
                        inputs.append(input_name)
                else:
                    inputs.append(input_name)
        for output in node.outputs():
            outputs.append(output.debugName())
        g_node = GNode(node_name, 'func', op_type, inputs, outputs, node_group)
        return g_node

    def _extract_shape_info(self, node):
        """
        Extract the shape information of an ```aten::view``` node

        Parameters
        ----------
        node : trace graph node
            It should be an ```aten::view``` node

        Returns
        -------
        dict
            Includes the shape of the input tensor and the shape of the output tensor
        """
        t_input = None
        for _input in node.inputs():
            t_input = _input
            break
        t_output = node.output()
        assert isinstance(t_input.type(), torch._C.TensorType)
        assert isinstance(t_output.type(), torch._C.TensorType)
        in_shape = t_input.type().sizes()
        out_shape = t_output.type().sizes()
        return {'in_shape': in_shape, 'out_shape': out_shape}

    def _build_graph(self):
        """
        Build a graph in our defined format from the jit trace.
        There are basically three steps: first, construct the necessary information (data structures);
        second, extract all the modules and convert them to GNode; third, extract all the functions
        and convert them to GNode.

        Returns
        -------
        dict
            use name to index g_nodes, key: node name, value: g_node
        dict
            use input (its name) to index g_nodes,
            key: input, value: list of g_nodes that take this input
        dict
            use output (its name) to index g_nodes,
            key: output, value: g_node that generates this output
        """
        graph = self.trace_graph.graph
        # if torch 1.4.0 is used, consider running torch._C._jit_pass_inline(graph) here
        #print(graph)
        # build output mapping, from output debugName to its node
        output_to_node = dict()
        # build input mapping, from input debugName to its node
        input_to_node = dict()
        # build module mapping, from module name to all nodes (as list) under this module scope
        module_to_nodes = dict()
        # module name to its type
        module_to_type = dict()
        # the mapping of function (non-module in forward) to nodes, key is scope name
        func_to_nodes = dict()

        graph_inputs = list()
        graph_outputs = list()
        for _input in graph.inputs():
            graph_inputs.append(_input.debugName())
        for output in graph.outputs():
            graph_outputs.append(output.debugName())

        for node in graph.nodes():
            # populate output_to_node and input_to_node
            for output in node.outputs():
                output_name = output.debugName()
                output_to_node[output_name] = node
            for _input in node.inputs():
                input_name = _input.debugName()
                input_to_node[input_name] = node
            scope_name = node.scopeName()  # example: scope_name, 'MyCell/Linear[linear]'
            module_name_slices = re.findall(r'\[(.*?)\]', scope_name)
            module_name = '.'.join(module_name_slices)
            # if module_name is empty, it is not a module
            if module_name == '':
                if scope_name == '':
                    continue
                else:
                    if scope_name in func_to_nodes:
                        func_to_nodes[scope_name].append(node)
                    else:
                        func_to_nodes[scope_name] = [node]
            else:
                scope_slice = scope_name.split('/')[-1]
                module_type = scope_slice.split('[')[0]
                module_to_type[module_name] = module_type
                if module_name in module_to_nodes:
                    module_to_nodes[module_name].append(node)
                else:
                    module_to_nodes[module_name] = [node]

        # construct GNode from module
        for module_name, nodes in module_to_nodes.items():
            inputs = set()
            outputs = set()
            for node in nodes:
                for output in node.outputs():
                    outputs.add(output.debugName())
                for _input in node.inputs():
                    inputs.add(_input.debugName())
            m_inputs = list()
            m_outputs = list()
            for output in outputs:
                # TODO: one input could be the input of multiple nodes
                if not output in input_to_node and output in graph_outputs:
                    m_outputs.append(output)
                elif not input_to_node[output] in nodes:
                    m_outputs.append(output)
            for _input in inputs:
                if not _input in output_to_node and _input in graph_inputs:
                    m_inputs.append(_input)
                elif not output_to_node[_input] in nodes:
                    m_inputs.append(_input)
            print("module node_name: ", module_name)
            if module_name == '':
                for n in nodes:
                    print(n)
            g_node = GNode(module_name, 'module', module_to_type[module_name], m_inputs, m_outputs, nodes)
            self.g_nodes.append(g_node)

        # each scope_name may have multiple funcs, we split them and create a GNode for each of them
        for scope_name, nodes in func_to_nodes.items():
            # extract non prim:: nodes
            non_prim_nodes = list()
            for node in nodes:
                if not node.kind().startswith('prim::'):
                    non_prim_nodes.append(node)
            # expand each non-prim node into a GNode
            for node in non_prim_nodes:
                g_node = self._expand_non_prim_node(node, nodes, input_to_node, output_to_node)
                self.g_nodes.append(g_node)
                # get shape info for the view (aten::view) func
                if g_node.op_type == 'aten::view':
                    g_node.auxiliary = self._extract_shape_info(node)

        # build index for g_nodes
        name_to_gnode, input_to_gnode, output_to_gnode = self._build_index_for_gnodes(self.g_nodes)
        return name_to_gnode, input_to_gnode, output_to_gnode

    def _find_predecessors(self, module_name):
        """
        Find the predecessor GNodes of the given GNode

        Parameters
        ----------
        module_name : str
            The name of the GNode

        Returns
        -------
        list
            a list of the names of the given GNode's predecessors
        """
        predecessors = []
        for _input in self.name_to_gnode[module_name].inputs:
            if not _input in self.output_to_gnode:
                # TODO: check _input which does not have a node
                print("output with no gnode: ", _input)
            else:
                g_node = self.output_to_gnode[_input]
                predecessors.append(g_node.name)
        return predecessors

    def _find_successors(self, module_name):
        """
        Find the successor GNodes of the given GNode

        Parameters
        ----------
        module_name : str
            The name of the GNode

        Returns
        -------
        list
            a list of the names of the given GNode's successors
        """
        successors = []
        for output in self.name_to_gnode[module_name].outputs:
            assert output in self.input_to_gnode, "No gnode with input {}".format(output)
            g_nodes = self.input_to_gnode[output]
            for g_node in g_nodes:
                successors.append(g_node.name)
        return successors

    def infer_module_mask(self, module_name, mask=None, in_shape=None, out_shape=None):
        """
        Infer input shape / output shape based on the module's weight mask / input shape / output shape.

        For a module:
            Infer its input and output shape from its weight mask
            Infer its output shape from its input shape
            Infer its input shape from its output shape

        If its input shape is changed, continue inferring its predecessors;
        if its output shape is changed, continue inferring its successors.

        Parameters
        ----------
        module_name : str
            The name of the GNode
        mask : tensor of mask or ModuleMasks
            Mask of the weights in this GNode (i.e., module)
        in_shape : ModuleMasks
            Input shape of this GNode
        out_shape : ModuleMasks
            Output shape of this GNode
        """
        input_cmask = output_cmask = None
        if module_name in self.inferred_masks:
            module_masks = self.inferred_masks[module_name]
        else:
            module_masks = ModuleMasks(module_name)
            self.inferred_masks[module_name] = module_masks

        m_type = self.name_to_gnode[module_name].op_type
        print("infer_module_mask: {}, module type: {}".format(module_name, m_type))
        if mask is not None:
            #print("mask is not None")
            if not m_type in infer_from_mask:
                raise RuntimeError("Inferring input/output shape from mask is not supported "
                                   "for module/function: `{}`".format(m_type))
            input_cmask, output_cmask = infer_from_mask[m_type](module_masks, mask)
        if in_shape is not None:
            #print("in_shape is not None")
            if not m_type in infer_from_inshape:
                raise RuntimeError("Inferring output shape from input shape is not supported "
                                   "for module/function: `{}`".format(m_type))
            if m_type == 'aten::view':
                output_cmask = infer_from_inshape[m_type](module_masks,
                                                          in_shape,
                                                          self.name_to_gnode[module_name].auxiliary)
            else:
                output_cmask = infer_from_inshape[m_type](module_masks, in_shape)
        if out_shape is not None:
            #print("out_shape is not None")
            if not m_type in infer_from_outshape:
                raise RuntimeError("Inferring input shape from output shape is not supported "
                                   "for module/function: `{}`".format(m_type))
            input_cmask = infer_from_outshape[m_type](module_masks, out_shape)

        if input_cmask:
            #print("input_cmask is not None")
            predecessors = self._find_predecessors(module_name)
            for _module_name in predecessors:
                print("input_cmask, module_name: ", _module_name)
                self.infer_module_mask(_module_name, out_shape=input_cmask)
        if output_cmask:
            #print("output_cmask is not None")
            successors = self._find_successors(module_name)
            for _module_name in successors:
                print("output_cmask, module_name: ", _module_name)
                self.infer_module_mask(_module_name, in_shape=output_cmask)

    def infer_modules_masks(self):
        """
        Do shape inference of the involved modules, including the shape of weights, inputs, output
        """
        for module_name, mask in self.masks.items():
            self.infer_module_mask(module_name, mask=mask)

    def replace_compressed_modules(self):
        """
        Replace all the modules whose shape (weights/inputs/output) has changed.
        The new module is created with the same arguments as the to-be-replaced module,
        and correctly inherits its weights.

        NOTE: a ```func``` type node cannot be replaced as it is not a module; thus, one
        limitation is that a ```func``` must not require replacement.
        """
        for module_name in self.inferred_masks:
            g_node = self.name_to_gnode[module_name]
            print(module_name, g_node.op_type)
            if g_node.type == 'module':
                super_module, leaf_module = get_module_by_name(self.bound_model, module_name)
                m_type = g_node.op_type
                if not m_type in replace_module:
                    raise RuntimeError("Replacing module `{}` is not supported".format(m_type))
                compressed_module = replace_module[m_type](leaf_module, self.inferred_masks[module_name])
                setattr(super_module, module_name.split('.')[-1], compressed_module)
            elif g_node.type == 'func':
                print("Warning: Cannot replace func...")
            else:
                raise RuntimeError("Unsupported GNode type: {}".format(g_node.type))

    def speedup_model(self):
        """
        There are basically two steps:
        first, do mask/shape inference;
        second, replace modules.
        """
        #print("start to compress")
        self.infer_modules_masks()
        self.replace_compressed_modules()
        #print("finished compressing")
        # restore the model to the mode it was in before speedup
        if self.is_training:
            self.bound_model.train()
        else:
            self.bound_model.eval()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
For each operation or module, there are two functions.
One infers the input shape and initialization parameters (e.g., weight shape) from the output shape.
The other infers the output shape and initialization parameters (e.g., weight shape) from the input shape.
"""

import torch

class CoarseMask:
    """
    Coarse-grained mask for a given tensor; here the tensor could be the weights,
    an input tensor, or an output tensor
    """
    def __init__(self, num_dim):
        """
        Parameters
        ----------
        num_dim : int
            The number of dimensions of the tensor that will be masked
        """
        self.mask_index = [None for _ in range(num_dim)]

    def add_index_mask(self, dim, index):
        """
        Add mask for the specified dimension

        Parameters
        ----------
        dim : int
            The dimension to add mask
        index : tensor
            The mask for this dimension, it is a 1-dimensional tensor which specifies
            the indexes of the elements that are not pruned
        """
        self.mask_index[dim] = index

    @staticmethod
    def merge_index(index_a, index_b):
        """
        Parameters
        ----------
        index_a : tensor
            One index (1-dimension) tensor
        index_b : tensor
            The other index (1-dimension) tensor

        Returns
        -------
        tensor
            The merged index (1-dimension) tensor
        """
        s = set()
        for num in index_a:
            s.add(num)
        for num in index_b:
            s.add(num)
        return torch.tensor(sorted(s))

    def merge(self, cmask):
        """
        Merge another CoarseMask

        Parameters
        ----------
        cmask : CoarseMask
            Another CoarseMask to merge

        Returns
        -------
        list
            The member variable ```mask_index```
        """
        assert isinstance(cmask, CoarseMask)
        assert len(self.mask_index) == len(cmask.mask_index), \
            "Only masks with the same number of dimensions can be merged"
        for i, index in enumerate(self.mask_index):
            if index is None:
                self.mask_index[i] = cmask.mask_index[i]
            elif cmask.mask_index[i] is not None:
                self.mask_index[i] = CoarseMask.merge_index(self.mask_index[i],
                                                            cmask.mask_index[i])
        return self.mask_index

class ModuleMasks:
    """
    The masks of a module, including the masks for weights, inputs, output
    """
    def __init__(self, module_name):
        """
        Parameters
        ----------
        module_name : str
            The name of the module or function
        """
        self.module_name = module_name
        self.param_masks = dict()
        self.input_mask = None
        self.output_mask = None

    def set_param_masks(self, name, mask):
        """
        Parameters
        ----------
        name : str
            The name of the weight
        mask : CoarseMask
            The mask for this weight
        """
        self.param_masks[name] = mask

    def set_input_mask(self, mask):
        """
        Parameters
        ----------
        mask : CoarseMask
            The mask for input
        """
        self.input_mask = mask

    def set_output_mask(self, mask):
        """
        Parameters
        ----------
        mask : CoarseMask
            The mask for output
        """
        self.output_mask = mask

"""
Infer input and output shape of a module/function from its weight mask
"""
infer_from_mask = {
    'BatchNorm2d': lambda module_masks, mask: batchnorm2d_mask(module_masks, mask),
    'Conv2d': lambda module_masks, mask: conv2d_mask(module_masks, mask)
}

"""
Infer output and weight shape of a module/function from its input shape
"""
infer_from_inshape = {
    'ReLU': lambda module_masks, mask: relu_inshape(module_masks, mask),
    'aten::relu': lambda module_masks, mask: relu_inshape(module_masks, mask),
    'Conv2d': lambda module_masks, mask: conv2d_inshape(module_masks, mask),
    'MaxPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
    'aten::max_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
    'aten::avg_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
    'AvgPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
    'aten::size': lambda module_masks, mask: size_inshape(module_masks, mask),
    'aten::view': lambda module_masks, mask, shape: view_inshape(module_masks, mask, shape),
    'Linear': lambda module_masks, mask: linear_inshape(module_masks, mask),
    'BatchNorm2d': lambda module_masks, mask: batchnorm2d_inshape(module_masks, mask)
}

"""
Infer input and weight shape of a module/function from its output shape
"""
infer_from_outshape = {
    'Conv2d': lambda module_masks, mask: conv2d_outshape(module_masks, mask)
}

def batchnorm2d_inshape(module_masks, mask):
    """
    We assume only the second dimension has a coarse-grained mask

    Parameters
    ----------
    module_masks : ModuleMasks
        The ModuleMasks instance of the batchnorm2d
    mask : CoarseMask
        The mask of its input tensor

    Returns
    -------
    CoarseMask
        The mask of its output tensor
    """
    assert isinstance(mask, CoarseMask)
    assert mask.mask_index[1] is not None
    assert mask.mask_index[0] is None
    assert mask.mask_index[2] is None
    assert mask.mask_index[3] is None
    module_masks.set_input_mask(mask)
    module_masks.set_output_mask(mask)
    weight_cmask = CoarseMask(num_dim=1)
    weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
    module_masks.set_param_masks('weight', weight_cmask)
    module_masks.set_param_masks('bias', weight_cmask)
    return mask

def linear_inshape(module_masks, mask):
    """
    A coarse-grained input mask does not change the shape of weights and output tensor

    Parameters
    ----------
    module_masks : ModuleMasks
        The ModuleMasks instance of the linear
    mask : CoarseMask
        The mask of its input tensor

    Returns
    -------
    CoarseMask
        The mask of its output tensor, ```None``` means the shape of the output tensor is not changed
    """
    assert isinstance(mask, CoarseMask)
    assert mask.mask_index[0] is None
    assert module_masks.input_mask is None
    module_masks.set_input_mask(mask)
    return None

def view_inshape(module_masks, mask, shape):
    """
    This is a limited support.

    TODO: consider replacing tensor.view with nn.Flatten, because tensor.view is not
    a module and thus cannot be replaced by our framework.

    Parameters
    ----------
    module_masks : ModuleMasks
        The ModuleMasks instance of the ```view``` op
    mask : CoarseMask
        The mask of its input tensor
    shape : dict
        Original shape of its input and output tensors

    Returns
    -------
    CoarseMask
        The mask of its output tensor
    """
    # NOTE: the supported case is constrained by the following four asserts
    assert shape['in_shape'][0] == shape['out_shape'][0]
    assert len(shape['in_shape']) == 4
    assert len(shape['out_shape']) == 2
    assert shape['out_shape'][1] == shape['in_shape'][1]*shape['in_shape'][2]*shape['in_shape'][3]

    assert isinstance(mask, CoarseMask)
    assert mask.mask_index[1] is not None
    assert mask.mask_index[0] is None
    assert mask.mask_index[2] is None
    assert mask.mask_index[3] is None
    assert module_masks.input_mask is None
    module_masks.set_input_mask(mask)
    output_cmask = CoarseMask(num_dim=2)
    index = []
    step_size = shape['in_shape'][2] * shape['in_shape'][3]
    for loc in mask.mask_index[1]:
        index.extend([loc * step_size + i for i in range(step_size)])
    output_cmask.add_index_mask(dim=1, index=torch.tensor(index))
    module_masks.set_output_mask(output_cmask)
    return output_cmask

def size_inshape(module_masks, mask):
    """
    No need to do anything for this ```size``` op
    """
    return None

def maxpool2d_inshape(module_masks, mask):
    """
    Assume only the second dimension is masked

    Parameters
    ----------
    module_masks : ModuleMasks
        The ModuleMasks instance of the maxpool2d
    mask : CoarseMask
        The mask of its input tensor

    Returns
    -------
    CoarseMask
        The mask of its output tensor
    """
    assert isinstance(mask, CoarseMask)
    assert mask.mask_index[1] is not None
    assert mask.mask_index[0] is None
    assert mask.mask_index[2] is None
    assert mask.mask_index[3] is None
    assert module_masks.input_mask is None
    module_masks.set_input_mask(mask)
    module_masks.set_output_mask(mask)
    return mask

def relu_inshape(module_masks, mask):
    """
    Parameters
    ----------
    module_masks : ModuleMasks
        The ModuleMasks instance of the relu
    mask : CoarseMask
        The mask of its input tensor

    Returns
    -------
    CoarseMask
        The mask of its output tensor
    """
    assert isinstance(mask, CoarseMask)
    # TODO: double check this assert, is it possible that a module is passed twice
    assert module_masks.input_mask is None, "A relu op can only be processed once"
    module_masks.set_input_mask(mask)
    module_masks.set_output_mask(mask)
    return mask

def batchnorm2d_mask(module_masks, mask):
    """
    Infer input and output shape from weight mask

    Parameters
    ----------
    module_masks : ModuleMasks
        The ModuleMasks instance of the batchnorm2d
    mask : dict
        The mask of its weights, from the user provided mask file

    Returns
    -------
    CoarseMask, CoarseMask
        The mask of its input tensor, the mask of its output tensor
    """
    assert 'weight' in mask and 'bias' in mask
    sum_mask = mask['weight'] + mask['bias']
    nonzero_index = torch.nonzero(sum_mask, as_tuple=True)[0]
    # infer shape of parameters
    param_cmask = CoarseMask(num_dim=1)
    param_cmask.add_index_mask(dim=0, index=nonzero_index)
    module_masks.set_param_masks('weight', param_cmask)
    module_masks.set_param_masks('bias', param_cmask)
    # infer shape of input tensor
    input_cmask = CoarseMask(num_dim=4)
    input_cmask.add_index_mask(dim=1,
                               index=torch.nonzero(mask['weight'], as_tuple=True)[0])
    module_masks.set_input_mask(input_cmask)
    # infer shape of output tensor
    output_cmask = CoarseMask(num_dim=4)
    output_cmask.add_index_mask(dim=1, index=nonzero_index)
    module_masks.set_output_mask(output_cmask)
    return input_cmask, output_cmask

def conv2d_mask(module_masks, mask):
    """
    Infer input and output shape from weight mask

    Parameters
    ----------
    module_masks : ModuleMasks
        The ModuleMasks instance of the conv2d
    mask : dict
        The mask of its weights, from the user provided mask file

    Returns
    -------
    CoarseMask, CoarseMask
        The mask of its input tensor, the mask of its output tensor
    """
    def convert_to_coarse_mask(mask):
        """
        Parameters
        ----------
        mask : dict
            Weight mask from user provided mask file

        Returns
        -------
        LongTensor, CoarseMask, CoarseMask
            Index of the masked dimension, weight mask, bias mask
        """
        assert 'weight' in mask
        assert isinstance(mask['weight'], torch.Tensor)
        weight_mask = mask['weight']
        shape = weight_mask.size()
        ones = torch.ones(shape[1:]).to(weight_mask.device)
        zeros = torch.zeros(shape[1:]).to(weight_mask.device)
        index = []
        for i in range(shape[0]):
            if torch.all(torch.eq(weight_mask[i], ones)):
                index.append(i)
            elif torch.all(torch.eq(weight_mask[i], zeros)):
                continue
            else:
                # the mask is not a clean per-filter mask, i.e., it is fine grained
                index = None
                break
        if index is None:
            return None, None, None
        else:
            index = torch.LongTensor(index).to(weight_mask.device)
            weight_cmask = CoarseMask(num_dim=4)
            weight_cmask.add_index_mask(dim=0, index=index)
            bias_cmask = None
            if 'bias' in mask and mask['bias'] is not None:
                bias_index = torch.nonzero(mask['bias'], as_tuple=True)[0]
                assert torch.all(torch.eq(index, bias_index)), \
                    "bias mask should be consistent with weight mask"
                bias_cmask = CoarseMask(num_dim=1)
                bias_cmask.add_index_mask(dim=0, index=bias_index)
            return index, weight_cmask, bias_cmask

    index, weight_cmask, bias_cmask = convert_to_coarse_mask(mask)
    if index is None:
        # TODO: fine grained mask speedup
        return None, None
    # deal with coarse-grained mask
    if 'weight' in module_masks.param_masks:
        module_masks.param_masks['weight'].merge(weight_cmask)
        module_masks.param_masks['bias'].merge(bias_cmask)
    else:
        module_masks.set_param_masks('weight', weight_cmask)
        module_masks.set_param_masks('bias', bias_cmask)
    output_cmask = CoarseMask(num_dim=4)
    output_cmask.add_index_mask(dim=1, index=index)
    if module_masks.output_mask is None:
        module_masks.set_output_mask(output_cmask)
    else:
        module_masks.output_mask.merge(output_cmask)
    return None, module_masks.output_mask

def conv2d_inshape(module_masks, mask):
    """
    A shape change of the input tensor does not affect the shape of its output tensor

    Parameters
    ----------
    module_masks : ModuleMasks
        The ModuleMasks instance of the conv2d
    mask : CoarseMask
        The mask of its input tensor

    Returns
    -------
    CoarseMask
        The mask of its output tensor
    """
    assert isinstance(mask, CoarseMask)
    assert module_masks.input_mask is None
    module_masks.set_input_mask(mask)
    return None

def conv2d_outshape(module_masks, mask):
    """
    Assume only the second dimension is masked

    Parameters
    ----------
    module_masks : ModuleMasks
        The ModuleMasks instance of the conv2d
    mask : CoarseMask
        The mask of its output tensor

    Returns
    -------
    CoarseMask
        The mask of its input tensor
    """
    assert isinstance(mask, CoarseMask)
    assert mask.mask_index[1] is not None
    assert mask.mask_index[0] is None
    assert mask.mask_index[2] is None
    assert mask.mask_index[3] is None

    if module_masks.output_mask is not None:
        assert isinstance(module_masks.output_mask, CoarseMask)
        # merge into the existing output mask; ```merge``` updates it in place and
        # returns ```mask_index``` (a list), so keep using the merged CoarseMask
        module_masks.output_mask.merge(mask)
        mask = module_masks.output_mask
    else:
        module_masks.output_mask = mask
    # infer shape of parameters
    weight_cmask = CoarseMask(num_dim=4)
    weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
    bias_cmask = CoarseMask(num_dim=1)
    bias_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
    module_masks.set_param_masks('weight', weight_cmask)
    module_masks.set_param_masks('bias', bias_cmask)
    # input shape is not changed
    return None
@@ -6,3 +6,4 @@ from .pruners import *
 from .weight_rank_filter_pruners import *
 from .activation_rank_filter_pruners import *
 from .quantizers import *
+from .apply_compression import apply_compression_results
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import torch
from .compressor import Pruner

logger = logging.getLogger('torch apply compression')

def apply_compression_results(model, masks_file):
    """
    Apply the masks from ```masks_file``` to the model

    Parameters
    ----------
    model : torch.nn.module
        The model to be compressed
    masks_file : str
        The path of the mask file
    """
    apply_comp = ApplyCompression(model, masks_file)
    apply_comp.compress()

class ApplyCompression(Pruner):
    """
    This class does not generate masks; it applies existing masks
    """
    def __init__(self, model, masks_file):
        """
        Parameters
        ----------
        model : torch.nn.module
            Model to be masked
        masks_file : str
            The path of user provided mask file
        """
        self.bound_model = model
        self.masks = torch.load(masks_file)
        for module_name in self.masks:
            print('module_name: ', module_name)
        config_list = self._build_config()
        super().__init__(model, config_list)

    def _build_config(self):
        op_names = []
        for module_name in self.masks:
            op_names.append(module_name)
        return [{'sparsity': 1, 'op_types': ['default', 'BatchNorm2d'], 'op_names': op_names}]

    def calc_mask(self, layer, config, **kwargs):
        """
        Directly return the corresponding mask

        Parameters
        ----------
        layer : LayerInfo
            The layer to be pruned
        config : dict
            Pruning configurations for this weight
        kwargs : dict
            Auxiliary information

        Returns
        -------
        dict
            Mask of the layer
        """
        assert layer.name in self.masks
        return self.masks[layer.name]