Unverified Commit e6817d22 authored by Ningxin Zheng, committed by GitHub

Support the Resnet/Squeezenet/Mobilenet for speedup (#2579)

parent 3d4f122a
@@ -18,6 +18,16 @@
.. autoclass:: nni.compression.torch.utils.shape_dependency.ChannelDependency
    :members:

.. autoclass:: nni.compression.torch.utils.shape_dependency.GroupDependency
    :members:

.. autoclass:: nni.compression.torch.utils.mask_conflict.CatMaskPadding
    :members:

.. autoclass:: nni.compression.torch.utils.mask_conflict.GroupMaskConflict
    :members:

.. autoclass:: nni.compression.torch.utils.mask_conflict.ChannelMaskConflict
    :members:
```
@@ -116,8 +116,6 @@ Set 12,layer4.1.conv1
When the masks of different layers in a model conflict (for example, when different sparsities are assigned to layers that have a channel dependency), we can fix the conflict with `fix_mask_conflict`. Specifically, it loads the masks exported by the pruners (L1FilterPruner, etc.), checks whether any masks conflict, and, if so, sets the conflicting masks to the same value.
```
from nni.compression.torch.utils.mask_conflict import fix_mask_conflict
fixed_mask = fix_mask_conflict('./resnet18_mask', net, data)
```
@@ -10,6 +10,7 @@ import torch
from torch.utils.tensorboard._pytorch_graph import NodePy, NodePyIO, NodePyOP, GraphPy
CLASSTYPE_KIND = 'ClassType'
GETATTR_KIND = 'prim::GetAttr'
CAT_KIND = 'aten::cat'
_logger = logging.getLogger(__name__)
@@ -236,6 +237,7 @@ class TorchModuleGraph(TorchGraph):
super().__init__(model, dummy_input, traced_model)
self.global_count = 0
self.name_to_node, self.input_to_node, self.output_to_node = self._build_graph()
self._extract_auxiliary_info()
def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node,
module_type):
@@ -364,6 +366,58 @@ class TorchModuleGraph(TorchGraph):
node_group, inputs=inputs, outputs=outputs)
return nodepy
def _extract_cat_info(self, node_group, cpp_node):
"""
Extract the detailed information of the cat operation,
such as the order of the input tensors, the shape of each
input tensor, the output shape, and the cat dimension.
Parameters
----------
node_group : NodePyGroup
cpp_node: torch._C.Node
It should be an ```aten::cat``` node
Returns
-------
dict
Includes auxiliary information for the cat operation.
This dict object has four keys: 'cat_dim', 'out_shape',
'in_order' and 'in_shape'. cat_dim is the dimension along
which the cat operation concatenates the input tensors.
out_shape is the shape of the output tensor of the cat operation.
in_order is an ordered list which contains the corresponding
parent operation nodes of the input tensors. in_shape is also
an ordered list that contains the shapes of the input
tensors.
"""
# only support the cat operation
assert cpp_node.kind() == CAT_KIND
cat_info = {}
# get the shape of the output tensor
t_output = cpp_node.output()
out_shape = t_output.type().sizes()
cat_info['out_shape'] = out_shape
# get the cat dimension
inputs = cpp_node.inputs()
cat_dim = list(inputs)[1].toIValue()
cat_info['cat_dim'] = cat_dim
# get the order of the input tensors
# To get the order of the input tensors, we need
# to be aware of the topology of the model, which
# means we should extract the auxiliary information
# after the build_index function.
input_order = []
list_construct_cpp = list(cpp_node.inputs())[0].node()
input_tensors = list(list_construct_cpp.inputs())
for _tensor in input_tensors:
debug_name = _tensor.debugName()
input_order.append(self.output_to_node[debug_name].unique_name)
cat_info['in_order'] = input_order
input_shapes = [t.type().sizes() for t in input_tensors]
cat_info['in_shape'] = input_shapes
return cat_info
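For illustration, consider a hypothetical model that concatenates the outputs of two convs, `conv1` ([1, 16, 32, 32]) and `conv2` ([1, 32, 32, 32]), along dimension 1. The returned auxiliary dict would look like the following sketch (the layer names are placeholders):
```
# Hypothetical result of _extract_cat_info for
# torch.cat([conv1_out, conv2_out], dim=1)
cat_info = {
    'cat_dim': 1,                                    # dimension being concatenated
    'out_shape': [1, 48, 32, 32],                    # 16 + 32 channels
    'in_order': ['conv1', 'conv2'],                  # unique names of the parent nodes
    'in_shape': [[1, 16, 32, 32], [1, 32, 32, 32]]   # input shapes, in the same order
}
```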
def _extract_shape_info(self, node):
"""
Extract the shape information of ```aten::view``` node
@@ -541,8 +595,8 @@ class TorchModuleGraph(TorchGraph):
node, nodes, input_to_node, output_to_node, 'func')
nodes_py.nodes_op.append(node_group)
# get shape info for view (aten::view) func
# if node_group.op_type in ['aten::view', 'aten::flatten']:
#     node_group.auxiliary = self._extract_shape_info(node)
for node in graph.outputs():  # Create sink nodes for output ops
node_py = NodePyIO(node, 'output')
@@ -552,6 +606,26 @@ class TorchModuleGraph(TorchGraph):
# build index
return self._build_index(self.nodes_py.nodes_op)
def _extract_auxiliary_info(self):
"""
Extract the auxiliary information for the nodegroups
if necessary. For example, view/flatten operations may
need the shape of the input tensor and output tensor.
"""
# extract the input & output shape for the view and flatten
for node_group in self.nodes_py.nodes_op:
if node_group.op_type in ['aten::view', 'aten::flatten', 'aten::mean', 'aten::reshape']:
# get shape info for view-like funcs (view/flatten/mean/reshape)
cpp_node = list(filter(lambda x: x.kind() == node_group.op_type,
node_group.node_cpps))[0]
node_group.auxiliary = self._extract_shape_info(cpp_node)
elif node_group.op_type == CAT_KIND:
# get the detail information for cat func
cpp_node = list(filter(lambda x: x.kind() == node_group.op_type,
node_group.node_cpps))[0]
node_group.auxiliary = self._extract_cat_info(
node_group, cpp_node)
def find_predecessors(self, unique_name):
"""
Find predecessor node of the given node
...
@@ -14,7 +14,11 @@ replace_module = {
'AvgPool2d': lambda module, mask: no_replace(module, mask),
'AdaptiveAvgPool2d': lambda module, mask: no_replace(module, mask),
'ReLU': lambda module, mask: no_replace(module, mask),
'ReLU6': lambda module, mask: no_replace(module, mask),
'Linear': lambda module, mask: replace_linear(module, mask),
'Dropout': lambda module, mask: no_replace(module, mask),
'Dropout2d': lambda module, mask: no_replace(module, mask),
'Dropout3d': lambda module, mask: no_replace(module, mask)
}
def no_replace(module, mask):
@@ -111,6 +115,7 @@ def replace_conv2d(conv, mask):
else:
out_channels_index = mask.output_mask.mask_index[1]
out_channels = out_channels_index.size()[0]
_logger.debug("replace conv2d with in_channels: %d, out_channels: %d", in_channels, out_channels)
new_conv = torch.nn.Conv2d(in_channels=in_channels,
out_channels=out_channels,
@@ -118,21 +123,45 @@ def replace_conv2d(conv, mask):
stride=conv.stride,
padding=conv.padding,
dilation=conv.dilation,
groups=conv.groups,
bias=conv.bias is not None,
padding_mode=conv.padding_mode)
new_conv.to(conv.weight.device)
tmp_weight_data = tmp_bias_data = None
if mask.output_mask is not None:
tmp_weight_data = torch.index_select(conv.weight.data, 0, out_channels_index)
if conv.bias is not None:
tmp_bias_data = torch.index_select(conv.bias.data, 0, out_channels_index)
else:
tmp_weight_data = conv.weight.data
# For the convolutional layers that have more than one group,
# we need to copy the weight group by group, because the input
# channel is also divided into several groups and each group
# filter may have different input channel indexes.
input_step = int(conv.in_channels / conv.groups)
in_channels_group = int(in_channels / conv.groups)
filter_step = int(out_channels / conv.groups)
if mask.input_mask is not None:
for groupid in range(conv.groups):
start = groupid * input_step
end = (groupid + 1) * input_step
current_input_index = list(filter(lambda x: start <= x and x < end, in_channels_index.tolist()))
# shift the global index into the group index
current_input_index = [x-start for x in current_input_index]
# if the number of groups is larger than 1, the input channels
# of each group should be pruned evenly.
assert len(current_input_index) == in_channels_group, \
'Input channels of each group are not pruned evenly'
current_input_index = torch.tensor(current_input_index).to(tmp_weight_data.device) # pylint: disable=not-callable
f_start = groupid * filter_step
f_end = (groupid + 1) * filter_step
new_conv.weight.data[f_start:f_end] = torch.index_select(tmp_weight_data[f_start:f_end], 1, current_input_index)
else:
new_conv.weight.data.copy_(tmp_weight_data)
if conv.bias is not None:
new_conv.bias.data.copy_(conv.bias.data if tmp_bias_data is None else tmp_bias_data)
return new_conv
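A minimal runnable sketch of the group-wise copy above, assuming a toy conv with in_channels=4, out_channels=4, groups=2, where pruning keeps input channels [0, 2] (one channel per group, so each group is pruned evenly):
```
import torch

groups, in_channels, out_channels = 2, 4, 4
# grouped conv weight shape: (out_channels, in_channels // groups, kH, kW)
weight = torch.randn(out_channels, in_channels // groups, 3, 3)
in_channels_index = torch.tensor([0, 2])   # surviving global input channels

input_step = in_channels // groups         # original input channels per group: 2
in_channels_group = len(in_channels_index) // groups   # remaining per group: 1
filter_step = out_channels // groups       # filters per group: 2

new_weight = torch.empty(out_channels, in_channels_group, 3, 3)
for groupid in range(groups):
    start, end = groupid * input_step, (groupid + 1) * input_step
    # global indexes falling into this group, shifted to group-local indexes
    local = [x - start for x in in_channels_index.tolist() if start <= x < end]
    assert len(local) == in_channels_group, 'channels not pruned evenly'
    f_start, f_end = groupid * filter_step, (groupid + 1) * filter_step
    new_weight[f_start:f_end] = torch.index_select(
        weight[f_start:f_end], 1, torch.tensor(local))
```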
@@ -4,6 +4,7 @@
import logging
import torch
from nni._graph_utils import build_module_graph
from nni.compression.torch.utils.mask_conflict import fix_mask_conflict
from .compress_modules import replace_module
from .infer_shape import ModuleMasks, infer_from_mask, infer_from_inshape, infer_from_outshape
@@ -53,9 +54,10 @@ class ModelSpeedup:
self.bound_model = model
self.masks = torch.load(masks_file, map_location)
self.inferred_masks = dict()  # key: module_name, value: ModuleMasks
self.dummy_input = dummy_input
self.torch_graph = build_module_graph(model, dummy_input)
def infer_module_mask(self, module_name, last_module, mask=None, in_shape=None, out_shape=None):
"""
Infer input shape / output shape based on the module's weight mask / input shape / output shape.
@@ -71,6 +73,8 @@ class ModelSpeedup:
----------
module_name : str
The name of the node
last_module : str
The name of the last visited node
mask : tensor of mask or ModuleMasks
Mask of the weights in this node (i.e., module)
in_shape : ModuleMasks
@@ -100,10 +104,17 @@ class ModelSpeedup:
raise RuntimeError(
"Has not supported inferring output shape from input shape for module/function: `{}`, {}"
.format(m_type, module_name))
if m_type in ['aten::view', 'aten::flatten', 'aten::mean', 'aten::reshape']:
output_cmask = infer_from_inshape[m_type](module_masks,
in_shape,
self.torch_graph.name_to_node[module_name].auxiliary)
elif m_type in ['aten::cat']:
# To calculate the mask for the concat operation, the output shape,
# the cat dimension, and the order of the input parameters are needed.
output_cmask = infer_from_inshape[m_type](module_masks,
in_shape,
self.torch_graph.name_to_node[module_name].auxiliary,
last_module)
else:
output_cmask = infer_from_inshape[m_type](module_masks, in_shape)
if out_shape is not None:
@@ -117,18 +128,19 @@ class ModelSpeedup:
if input_cmask:
predecessors = self.torch_graph.find_predecessors(module_name)
for _module_name in predecessors:
self.infer_module_mask(_module_name, module_name, out_shape=input_cmask)
if output_cmask:
successors = self.torch_graph.find_successors(module_name)
for _module_name in successors:
self.infer_module_mask(_module_name, module_name, in_shape=output_cmask)
def infer_modules_masks(self):
"""
Do shape inference of involved modules, including the shape of weights, inputs, output
"""
for module_name, mask in self.masks.items():
_logger.debug('Start mask inference from %s', module_name)
self.infer_module_mask(module_name, None, mask=mask)
def replace_compressed_modules(self):
"""
@@ -144,19 +156,20 @@ class ModelSpeedup:
_logger.debug("replace %s, in %s type, with op_type %s", _logger.debug("replace %s, in %s type, with op_type %s",
module_name, g_node.type, g_node.op_type) module_name, g_node.type, g_node.op_type)
if g_node.type == 'module': if g_node.type == 'module':
super_module, leaf_module = get_module_by_name(self.bound_model, module_name) super_module, leaf_module = get_module_by_name(self.bound_model, g_node.name)
m_type = g_node.op_type m_type = g_node.op_type
if not m_type in replace_module: if not m_type in replace_module:
raise RuntimeError("Has not supported replacing the module: `{}`".format(m_type)) raise RuntimeError("Has not supported replacing the module: `{}`".format(m_type))
_logger.info("replace module (name: %s, op_type: %s)", module_name, m_type) _logger.info("replace module (name: %s, op_type: %s)", g_node.name, m_type)
compressed_module = replace_module[m_type](leaf_module, self.inferred_masks[module_name]) compressed_module = replace_module[m_type](leaf_module, self.inferred_masks[module_name])
setattr(super_module, module_name.split('.')[-1], compressed_module) setattr(super_module, g_node.name.split('.')[-1], compressed_module)
elif g_node.type == 'func': elif g_node.type == 'func':
_logger.info("Warning: cannot replace (name: %s, op_type: %s) which is func type", _logger.info("Warning: cannot replace (name: %s, op_type: %s) which is func type",
module_name, g_node.op_type) module_name, g_node.op_type)
else: else:
raise RuntimeError("Unsupported node type: {}".format(g_node.type)) raise RuntimeError("Unsupported node type: {}".format(g_node.type))
def speedup_model(self):
"""
There are basically two steps:
@@ -165,6 +178,8 @@ class ModelSpeedup:
"""
training = self.bound_model.training
_logger.info("start to speed up the model")
_logger.info("fix the mask conflict of the interdependent layers")
fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)
_logger.info("infer module masks...") _logger.info("infer module masks...")
self.infer_modules_masks() self.infer_modules_masks()
_logger.info("replace compressed modules...") _logger.info("replace compressed modules...")
......
@@ -8,11 +8,13 @@ The other is given input shape, infer its output shape and initialization parame
import torch
class CoarseMask:
"""
Coarse grained mask for a given tensor, here tensor could be weights,
input tensor, or output tensor
"""
def __init__(self, num_dim):
"""
Parameters
@@ -50,13 +52,26 @@ class CoarseMask:
-------
tensor
The merged index (1-dimension) tensor
Note that the output tensor will be moved
to the same device as index_a.
"""
device = index_a.device
s = set()
for num in index_a.tolist():
# we need to transfer the tensor to a list here:
# iterating the tensor directly with a for loop
# yields tensor(x) objects; even when the values
# are the same, they are different tensor objects,
# so the set would contain multiple tensor objects
# that hold the same value. For example:
# for num in torch.ones(2):
#     s.add(num)
# s will be {tensor(1), tensor(1)}
s.add(num)
for num in index_b.tolist():
s.add(num)
# move the output tensor to the same device as index_a
return torch.tensor(sorted(s)).to(device) # pylint: disable=not-callable
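A small demonstration of the pitfall documented in the comments above, together with what merge_index computes (pure torch, runnable standalone):
```
import torch

# Iterating a tensor yields distinct 0-dim tensor objects whose hash is
# identity-based, so a set fails to deduplicate equal values:
s = set()
for num in torch.ones(2):
    s.add(num)
assert len(s) == 2                      # {tensor(1.), tensor(1.)}

# Converting to a python list first yields plain numbers that deduplicate:
assert len(set(torch.ones(2).tolist())) == 1

# merge_index returns the sorted union of two index tensors, kept on the
# device of the first argument; the equivalent computation is:
a, b = torch.tensor([0, 2, 4]), torch.tensor([2, 3])
merged = torch.tensor(sorted(set(a.tolist()) | set(b.tolist()))).to(a.device)
assert merged.tolist() == [0, 2, 3, 4]
```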
def merge(self, cmask):
"""
@@ -86,10 +101,65 @@ class CoarseMask:
def __repr__(self):
return 'mask_index: {}'.format(self.mask_index)
def eq_on_dim(self, other, dim):
assert isinstance(other, CoarseMask)
if self.mask_index[dim] is None and other.mask_index[dim] is None:
return True
elif isinstance(self.mask_index[dim], torch.Tensor) \
and isinstance(other.mask_index[dim], torch.Tensor):
return torch.equal(self.mask_index[dim], other.mask_index[dim])
else:
return False
def __eq__(self, other):
assert isinstance(other, CoarseMask)
if len(self.mask_index) != len(other.mask_index):
return False
for i in range(len(self.mask_index)):
if not self.eq_on_dim(other, i):
return False
return True
def __lt__(self, other):
"""
Judge if the mask is a subset of another CoarseMask.
"""
assert isinstance(other, CoarseMask)
for dim, _ in enumerate(self.mask_index):
# if self has more dimensions
if dim >= len(other.mask_index):
return False
if self.mask_index[dim] is None:
# if there is no mask on this dimension, then we have
# fewer masks than the other CoarseMask.
continue
elif other.mask_index[dim] is None:
return False
else:
s1 = set(self.mask_index[dim].tolist())
s2 = set(other.mask_index[dim].tolist())
if not s1 < s2:
return False
return True
def __le__(self, other):
"""
Return whether self's mask is less than or equal to other's mask.
"""
assert isinstance(other, CoarseMask)
if self.__lt__(other) or self.__eq__(other):
return True
return False
def __ne__(self, other):
return not self.__eq__(other)
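A short sketch of these comparison semantics; the import path is an assumption for illustration, CoarseMask being the class defined above:
```
import torch
# module path assumed for illustration
from nni.compression.torch.speedup.infer_shape import CoarseMask

m1 = CoarseMask(num_dim=4)
m1.add_index_mask(dim=1, index=torch.tensor([0, 1]))
m2 = CoarseMask(num_dim=4)
m2.add_index_mask(dim=1, index=torch.tensor([0, 1, 2]))

assert m1 < m2    # {0, 1} is a strict subset of {0, 1, 2} on dim 1
assert m1 <= m2   # __le__ is __lt__ or __eq__
assert m1 != m2   # the index tensors differ on dim 1
```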
class ModuleMasks:
"""
The masks of a module, including the masks for weights, inputs, output
"""
def __init__(self, module_name):
"""
Parameters
@@ -136,6 +206,7 @@ class ModuleMasks:
self.input_mask, self.output_mask, self.param_masks
)
"""
Infer input and output shape of a module/function from its weight mask
"""
@@ -149,18 +220,27 @@ Infer output and weight shape of a module/function from its input shape
"""
infer_from_inshape = {
'ReLU': lambda module_masks, mask: relu_inshape(module_masks, mask),
'ReLU6': lambda module_masks, mask: relu_inshape(module_masks, mask),
'aten::relu': lambda module_masks, mask: relu_inshape(module_masks, mask),
'Conv2d': lambda module_masks, mask: conv2d_inshape(module_masks, mask),
'MaxPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
'aten::max_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
'aten::avg_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
'aten::adaptive_avg_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
'AvgPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
'AdaptiveAvgPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
'aten::size': lambda module_masks, mask: size_inshape(module_masks, mask),
'aten::view': lambda module_masks, mask, shape: view_inshape(module_masks, mask, shape),
'aten::reshape': lambda module_masks, mask, shape: view_inshape(module_masks, mask, shape),
# support only start_dim=1
'aten::flatten': lambda module_masks, mask, shape: view_inshape(module_masks, mask, shape),
'Linear': lambda module_masks, mask: linear_inshape(module_masks, mask),
'BatchNorm2d': lambda module_masks, mask: batchnorm2d_inshape(module_masks, mask),
'aten::add_': lambda module_masks, mask: add_inshape(module_masks, mask),
'aten::add': lambda module_mask, mask: add_inshape(module_mask, mask),
'aten::cat': lambda module_mask, mask, cat_info, last_visited: cat_inshape(module_mask, mask, cat_info, last_visited),
'aten::mean': lambda module_masks, mask, shape: mean_inshape(module_masks, mask, shape),
'Dropout': lambda module_masks, mask: dropout_inshape(module_masks, mask)
}
""" """
...@@ -170,6 +250,120 @@ infer_from_outshape = { ...@@ -170,6 +250,120 @@ infer_from_outshape = {
'Conv2d': lambda module_masks, mask: conv2d_outshape(module_masks, mask) 'Conv2d': lambda module_masks, mask: conv2d_outshape(module_masks, mask)
} }
def dropout_inshape(module_masks, mask):
if module_masks.input_mask is None:
module_masks.set_input_mask(mask)
module_masks.set_output_mask(mask)
return module_masks.output_mask
# if already visited
assert module_masks.input_mask <= mask
if module_masks.input_mask == mask:
return None
module_masks.set_input_mask(mask)
module_masks.set_output_mask(mask)
return module_masks.output_mask
def cat_inshape(module_masks, mask, cat_info, last_visited):
"""
Infer the output mask of the cat operation from the
input mask.
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the cat operation
mask : CoarseMask
The mask of its input tensor
cat_info: dict
Dict object that records the necessary information
of cat operation, such as the order of the input
tensors.
last_visited: str
The unique_name of the last visited node group.
Returns
-------
CoarseMask
The mask of its output tensor
"""
assert isinstance(mask, CoarseMask)
out_shape = cat_info['out_shape']
cat_dim = cat_info['cat_dim']
in_order = cat_info['in_order']
in_shape = cat_info['in_shape']
if module_masks.output_mask is None:
# First visit to this cat node
# initialize the mask based on
# the number of the output channel.
output_mask = CoarseMask(num_dim=len(out_shape))
for dim, _ in enumerate(out_shape):
if dim == cat_dim:
if mask.mask_index[dim] is None:
continue
device = mask.mask_index[dim].device
# calculate the offset of the mask
pos = in_order.index(last_visited)
offsets = [in_shape[i][cat_dim]
for i, _ in enumerate(in_shape)]
offset = 0
for i in range(pos):
offset += offsets[i]
_tmp_mask = (mask.mask_index[dim] + offset).to(device)
output_mask.mask_index[dim] = _tmp_mask
else:
# directly copy the mask
if mask.mask_index[dim] is not None:
output_mask.mask_index[dim] = mask.mask_index[dim].data.clone(
)
module_masks.set_output_mask(output_mask)
return module_masks.output_mask
# If this cat node has already been visited, we need
# to validate that the mask is legal: for the cat operation,
# the masks on the 'cat_dim' dimension should be stitched
# together, and on the other dimensions the masks should be
# the same, else the mask is not legal.
for dim, _ in enumerate(out_shape):
if dim == cat_dim:
if mask.mask_index[dim] is None:
continue
pos = in_order.index(last_visited)
offsets = [in_shape[i][cat_dim] for i, _ in enumerate(in_shape)]
offset = 0
for i in range(pos):
offset += offsets[i]
device = mask.mask_index[dim].device
new_mask = mask.mask_index[dim] + offset
module_masks.output_mask.mask_index[dim] = CoarseMask.merge_index(
module_masks.output_mask.mask_index[dim], new_mask).to(device)
else:
assert module_masks.output_mask.eq_on_dim(mask, dim)
return module_masks.output_mask
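A worked example of the offset computation above, reusing the hypothetical cat_info from earlier (conv1 contributes 16 channels, conv2 contributes 32): a channel mask arriving from conv2 must be shifted by the channel count of every input that precedes it in in_order.
```
cat_info = {'cat_dim': 1, 'out_shape': [1, 48, 32, 32],
            'in_order': ['conv1', 'conv2'],
            'in_shape': [[1, 16, 32, 32], [1, 32, 32, 32]]}
last_visited = 'conv2'          # the mask comes from conv2
mask_index = [0, 5, 9]          # surviving channels within conv2

pos = cat_info['in_order'].index(last_visited)              # 1
offset = sum(s[cat_info['cat_dim']]
             for s in cat_info['in_shape'][:pos])           # 16
out_index = [i + offset for i in mask_index]                # [16, 21, 25]
```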
def add_inshape(module_masks, mask):
"""
Infer the output mask of the add operation from the
input mask.
"""
assert isinstance(mask, CoarseMask)
if module_masks.input_mask is None:
module_masks.set_input_mask(mask)
module_masks.set_output_mask(mask)
# module_masks.input_mask = mask
return mask
# If already visited, validate whether there is a conflict:
# if the mask is different from the previous input_mask,
# then there is a mask conflict.
if mask != module_masks.input_mask:
raise Exception('Mask conflict happens!')
return None
def batchnorm2d_inshape(module_masks, mask):
"""
We assume only the second dimension has coarse grained mask
@@ -199,6 +393,7 @@ def batchnorm2d_inshape(module_masks, mask):
module_masks.set_param_masks('bias', weight_cmask)
return mask
def linear_inshape(module_masks, mask):
"""
Coarse grained input mask does not change the shape of weights and output tensor
@@ -221,6 +416,7 @@ def linear_inshape(module_masks, mask):
module_masks.set_input_mask(mask)
return None
def view_inshape(module_masks, mask, shape):
"""
This is a limited support
@@ -246,7 +442,8 @@ def view_inshape(module_masks, mask, shape):
assert shape['in_shape'][0] == shape['out_shape'][0]
assert len(shape['in_shape']) == 4
assert len(shape['out_shape']) == 2
assert shape['out_shape'][1] == shape['in_shape'][1] * \
shape['in_shape'][2]*shape['in_shape'][3]
assert isinstance(mask, CoarseMask)
assert mask.mask_index[1] is not None
@@ -260,7 +457,7 @@ def view_inshape(module_masks, mask, shape):
step_size = shape['in_shape'][2] * shape['in_shape'][3]
for loc in mask.mask_index[1]:
index.extend([loc * step_size + i for i in range(step_size)])
output_cmask.add_index_mask(dim=1, index=torch.tensor(index))  # pylint: disable=not-callable
module_masks.set_output_mask(output_cmask)
return output_cmask
@@ -271,6 +468,28 @@ def size_inshape(module_masks, mask):
"""
return None
def mean_inshape(module_masks, mask, shape):
"""
Similar to the view operation, currently mask inference only supports
the mean operation on the 3rd and 4th dimensions.
"""
assert shape['in_shape'][0] == shape['out_shape'][0]
assert shape['out_shape'][1] == shape['in_shape'][1]
assert len(shape['in_shape']) == 4
assert len(shape['out_shape']) == 2
assert isinstance(mask, CoarseMask)
assert mask.mask_index[1] is not None
assert mask.mask_index[0] is None
assert mask.mask_index[2] is None
assert mask.mask_index[3] is None
module_masks.set_input_mask(mask)
output_cmask = CoarseMask(num_dim=2)
output_cmask.add_index_mask(dim=1, index=mask.mask_index[1])
module_masks.set_output_mask(output_cmask)
return output_cmask
def maxpool2d_inshape(module_masks, mask):
"""
Assume only the second dimension is masked
@@ -292,11 +511,14 @@ def maxpool2d_inshape(module_masks, mask):
assert mask.mask_index[0] is None
assert mask.mask_index[2] is None
assert mask.mask_index[3] is None
if module_masks.input_mask is not None:
assert module_masks.input_mask <= mask
# assert module_masks.input_mask is None
module_masks.set_input_mask(mask)
module_masks.set_output_mask(mask)
return mask
def relu_inshape(module_masks, mask):
"""
Parameters
@@ -313,11 +535,17 @@ def relu_inshape(module_masks, mask):
""" """
assert isinstance(mask, CoarseMask) assert isinstance(mask, CoarseMask)
# TODO: double check this assert, is it possible that a module is passed twice # TODO: double check this assert, is it possible that a module is passed twice
assert module_masks.input_mask is None, "A relu op can only be processed once" if module_masks.input_mask is not None:
# check if has a mask conflict
assert module_masks.input_mask == mask
# No need to pass the mask again
return None
# assert module_masks.input_mask is None, "A relu op can only be processed once"
module_masks.set_input_mask(mask)
module_masks.set_output_mask(mask)
return mask
def batchnorm2d_mask(module_masks, mask):
"""
Infer input and output shape from weight mask
@@ -353,6 +581,7 @@ def batchnorm2d_mask(module_masks, mask):
module_masks.set_output_mask(output_cmask)
return input_cmask, output_cmask
def conv2d_mask(module_masks, mask):
"""
Infer input and output shape from weight mask
@@ -429,6 +658,7 @@ def conv2d_mask(module_masks, mask):
module_masks.output_mask.merge(output_cmask)
return None, module_masks.output_mask
def conv2d_inshape(module_masks, mask):
"""
Shape change of input tensor does not affect the shape of its output tensor
@@ -446,10 +676,16 @@ def conv2d_inshape(module_masks, mask):
The mask of its output tensor
"""
assert isinstance(mask, CoarseMask)
if module_masks.input_mask is None:
module_masks.set_input_mask(mask)
else:
# the same conv layer may be visited more
# than once, e.g., when its output feeds a concat operation.
assert module_masks.input_mask <= mask
module_masks.input_mask.merge(mask)
return None
def conv2d_outshape(module_masks, mask):
"""
Assume only the second dimension is masked
@@ -487,4 +723,3 @@ def conv2d_outshape(module_masks, mask):
module_masks.set_param_masks('bias', bias_cmask)
# input shape is not changed
return None
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import logging
import torch
import numpy as np
from .shape_dependency import ChannelDependency, GroupDependency, CatPaddingDependency
# logging.basicConfig(level = logging.DEBUG)
_logger = logging.getLogger('FixMaskConflict')
def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
"""
Fix the mask conflicts for the channel dependencies
and the group dependencies.
Parameters
----------
masks : dict/str
A dict object that stores the masks or the path of the mask file
model : torch.nn.Module
model to fix the mask conflict
dummy_input : torch.Tensor
input example to trace the model
traced : torch._C.torch.jit.TopLevelTracedModule
the traced model of the target model; if this parameter is not None,
we do not use the model and dummy_input to get the trace graph.
"""
if isinstance(masks, str):
# if the input is the path of the mask_file
assert os.path.exists(masks)
masks = torch.load(masks)
# if the user provides the model and dummy_input to trace the model, we
# should get the traced model in advance, so that we only trace the
# model once; GroupMaskConflict and ChannelMaskConflict will reuse
# this traced model.
if traced is None:
assert model is not None and dummy_input is not None
with torch.onnx.set_training(model, False):
# We need to trace the model in this way, else it will have problems
traced = torch.jit.trace(model, dummy_input)
fix_group_mask = GroupMaskConflict(masks, model, dummy_input, traced)
masks = fix_group_mask.fix_mask()
fix_channel_mask = ChannelMaskConflict(masks, model, dummy_input, traced)
masks = fix_channel_mask.fix_mask()
padding_cat_mask = CatMaskPadding(masks, model, dummy_input, traced)
masks = padding_cat_mask.fix_mask()
return masks
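A usage sketch of this pipeline (net and dummy_input are placeholders for the user's model and example input): the model is traced once and the traced module is shared by all three fixers.
```
import torch
from nni.compression.torch.utils.mask_conflict import fix_mask_conflict

# let fix_mask_conflict trace the model itself:
fixed = fix_mask_conflict('./resnet18_mask', model=net, dummy_input=dummy_input)

# or trace once up front and reuse the traced module:
with torch.onnx.set_training(net, False):
    traced = torch.jit.trace(net, dummy_input)
fixed = fix_mask_conflict('./resnet18_mask', traced=traced)
```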
class MaskFix:
def __init__(self, masks, model=None, dummy_input=None, traced=None):
# check if the parameters are valid
parameter_valid = False
if traced is not None:
parameter_valid = True
elif (model is not None) and (dummy_input is not None):
parameter_valid = True
if not parameter_valid:
raise Exception('The input parameters are invalid!')
self.model = model
self.dummy_input = dummy_input
self.traced = traced
self.masks = masks
def fix_mask(self):
raise NotImplementedError
def export(self, path):
"""
Export the masks after fixing the conflict to file.
"""
torch.save(self.masks, path)
class CatMaskPadding(MaskFix):
def __init__(self, masks, model, dummy_input=None, traced=None):
"""
CatMaskPadding finds the layers whose output tensors are passed
to the same cat operation. The cat operation concatenates the
masks of the input tensors as the output mask, so when some
of the input layers of the cat operation are not pruned, we still
need to pass the masks of these non-pruned layers (the masks are
all ones) to the cat operation to ensure the shape of the output
mask is right.
Parameters
----------
masks : dict
a dict object that stores the masks
model : torch.nn.Module
model to fix the mask conflict
dummy_input : torch.Tensor
input example to trace the model
traced : torch._C.torch.jit.TopLevelTracedModule
the traced model of the target model; if this parameter is not None,
we do not use the model and dummy_input to get the trace graph.
"""
super(CatMaskPadding, self).__init__(masks, model, dummy_input, traced)
def fix_mask(self):
cat_padding_depen = CatPaddingDependency(self.model, self.dummy_input, self.traced)
name_to_module = {}
for name, module in self.model.named_modules():
name_to_module[name] = module
depen = cat_padding_depen.dependency_sets
for layers in depen:
device = None
count = 0
for layer in layers:
if layer in self.masks:
count += 1
if device is None:
device = self.masks[layer]['weight'].device
if count == 0:
# no layer is pruned
continue
elif count == len(layers):
# all the layers have been pruned
continue
# pad the mask for the non-pruned layers
for layer in layers:
if layer in self.masks:
continue
module = name_to_module[layer]
w_shape = module.weight.data.size()
w_mask = torch.ones(w_shape).to(device)
b_mask = None
if hasattr(module, 'bias') and module.bias is not None:
b_shape = module.bias.data.size()
b_mask = torch.ones(b_shape).to(device)
self.masks[layer] = {'weight': w_mask, 'bias': b_mask}
return self.masks
class GroupMaskConflict(MaskFix):
def __init__(self, masks, model=None, dummy_input=None, traced=None):
"""
GroupMaskConflict fixes the mask conflicts between layers that
have group dependencies on each other.
Parameters
----------
masks : dict
a dict object that stores the masks
model : torch.nn.Module
model to fix the mask conflict
dummy_input : torch.Tensor
input example to trace the model
traced : torch._C.torch.jit.TopLevelTracedModule
the traced model of the target model; if this parameter is not None,
we do not use the model and dummy_input to get the trace graph.
"""
super(GroupMaskConflict, self).__init__(masks, model, dummy_input, traced)
def fix_mask(self):
"""
Fix the mask conflict before the mask inference for the layers that
have group dependencies. This function should be called before the
mask inference of the 'speedup' module.
"""
group_depen = GroupDependency(self.model, self.dummy_input, self.traced)
depens = group_depen.dependency
_logger.info(depens)
for layername in depens:
group = depens[layername]
if layername not in self.masks:
# this layer not pruned
continue
w_mask = self.masks[layername]['weight']
shape = w_mask.size()
count = np.prod(shape[1:])
all_ones = (w_mask.flatten(1).sum(-1) == count).nonzero().squeeze(1).tolist()
all_zeros = (w_mask.flatten(1).sum(-1) == 0).nonzero().squeeze(1).tolist()
if len(all_ones) + len(all_zeros) < w_mask.size(0):
# In fine-grained pruning, skip this layer
_logger.info('Layer %s uses fine-grained pruning', layername)
continue
assert shape[0] % group == 0
# Find the minimum number of masked filters across the groups (mini_masked).
# Because the filters that remain after pruning must still
# be divisible into the same number of groups, we can only
# prune mini_masked filters from each group.
step = shape[0] // group
group_masked = []
for i in range(group):
_start = step * i
_end = step * (i+1)
_tmp_list = list(filter(lambda x: _start <= x and x < _end, all_zeros))
group_masked.append(_tmp_list)
mini_masked = min([len(x) for x in group_masked])
for gm in group_masked:
for i in range(mini_masked, len(gm)):
# To keep the number of output channels divisible by the
# group count, we reset the masks of the extra filters to ones.
pos = gm[i]
self.masks[layername]['weight'][pos] = torch.ones(shape[1:])
if 'bias' in self.masks[layername] and self.masks[layername]['bias'] is not None:
self.masks[layername]['bias'][pos] = 1
return self.masks
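A numeric sketch of the evening-out step above: with 8 filters in 2 groups, where filters 1, 2, 3 (group 0) and 6 (group 1) are fully masked, only min(3, 1) = 1 filter per group may stay pruned, so the masks of filters 2 and 3 are reset to ones.
```
all_zeros = [1, 2, 3, 6]        # fully masked filters
group, num_filters = 2, 8
step = num_filters // group
group_masked = []
for i in range(group):
    start, end = step * i, step * (i + 1)
    group_masked.append([x for x in all_zeros if start <= x < end])
# group_masked == [[1, 2, 3], [6]]
mini_masked = min(len(g) for g in group_masked)     # 1
for gm in group_masked:
    for pos in gm[mini_masked:]:
        print('restore filter', pos)                # filters 2 and 3
```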
class ChannelMaskConflict(MaskFix):
def __init__(self, masks, model=None, dummy_input=None, traced=None):
""" """
MaskConflict fix the mask conflict between the layers that ChannelMaskConflict fix the mask conflict between the layers that
has channel dependecy with each other. has channel dependecy with each other.
Parameters Parameters
---------- ----------
masks : dict
a dict object that stores the masks
model : torch.nn.Module model : torch.nn.Module
model to fix the mask conflict model to fix the mask conflict
dummy_input : torch.Tensor dummy_input : torch.Tensor
input example to trace the model input example to trace the model
mask_file : str graph : torch._C.torch.jit.TopLevelTracedModule
the path of the original mask file
graph : torch._C.Graph
the traced graph of the target model, is this parameter is not None, the traced graph of the target model, is this parameter is not None,
we donnot use the model and dummpy_input to get the trace graph. we donnot use the model and dummpy_input to get the trace graph.
""" """
# check if the parameters are valid super(ChannelMaskConflict, self).__init__(masks, model, dummy_input, traced)
parameter_valid = False
if graph is not None:
parameter_valid = True
elif (model is not None) and (dummy_input is not None):
parameter_valid = True
if not parameter_valid:
raise Exception('The input parameters is invalid!')
self.model = model
self.dummy_input = dummy_input
self.graph = graph
self.mask_file = mask_file
self.masks = torch.load(self.mask_file)
def fix_mask(self):
"""
Fix the mask conflict before the mask inference for the layers that
have shape dependencies. This function should be called before the
mask inference of the 'speedup' module.
"""
channel_depen = ChannelDependency(self.model, self.dummy_input, self.traced)
depen_sets = channel_depen.dependency_sets
for dset in depen_sets:
if len(dset) == 1:
@@ -53,11 +233,18 @@ class MaskConflict:
continue
channel_remain = set()
fine_grained = False
out_channels = None
# A flag that represents if all the layers in
# the dependency set are pruned
all_pruned = True
for name in dset:
if name not in self.masks:
# this layer is not pruned
all_pruned = False
continue
w_mask = self.masks[name]['weight']
if out_channels is None:
out_channels = w_mask.size(0)
shape = w_mask.size()
count = np.prod(shape[1:])
all_ones = (w_mask.flatten(1).sum(-1) == count).nonzero().squeeze(1).tolist()
@@ -74,8 +261,19 @@ class MaskConflict:
# Update the masks for the layers in the dependency set
if fine_grained:
continue
if not all_pruned:
# if some layers are not pruned at all,
# then none of the layers in this dependency set
# can be pruned, due to the shape dependency.
channel_remain.update(range(out_channels))
ori_channels = 0
for name in dset:
if name not in self.masks:
# this layer is not pruned at all
# in this case, all_pruned is False
# and the other layers in the same dset
# will not be pruned either.
continue
mask = self.masks[name]
w_shape = mask['weight'].size()
ori_channels = w_shape[0]
@@ -88,9 +286,3 @@ class MaskConflict:
pruned_filters = set(range(ori_channels)) - channel_remain
_logger.info(str(sorted(pruned_filters)))
return self.masks
@@ -6,6 +6,7 @@ import logging
from nni._graph_utils import TorchModuleGraph
__all__ = ['ChannelDependency', 'GroupDependency', 'CatPaddingDependency']
CONV_TYPE = 'aten::_convolution'
ADD_TYPES = ['aten::add', 'aten::add_']
@@ -13,7 +14,27 @@ CAT_TYPE = 'aten::cat'
logger = logging.getLogger('Shape_Dependency')
class Dependency:
def __init__(self, model=None, dummy_input=None, traced_model=None):
"""
Build the graph for the model.
"""
# check if the input is legal
if traced_model is None:
# the user should provide model & dummy_input to trace
# the model, or an already traced model
assert model is not None and dummy_input is not None
self.graph = TorchModuleGraph(model, dummy_input, traced_model)
self.dependency = dict()
self.build_dependency()
def build_dependency(self):
raise NotImplementedError
def export(self, filepath):
raise NotImplementedError
class ChannelDependency(Dependency):
def __init__(self, model=None, dummy_input=None, traced_model=None):
"""
This model analyzes the channel dependencies between the conv
@@ -29,13 +50,7 @@ class ChannelDependency:
if we already have the traced graph of the target model, we do not
need to trace the model again.
"""
super(ChannelDependency, self).__init__(model, dummy_input, traced_model)
def _get_parent_layers(self, node):
"""
@@ -66,7 +81,7 @@ class ChannelDependency:
queue.append(parent)
return parent_layers
def build_dependency(self):
"""
Build the channel dependency for the conv layers
in the model.
@@ -119,7 +134,7 @@ class ChannelDependency:
Set 2,layer1.0.conv1
Set 3,layer1.1.conv1
"""
header = ['Dependency Set', 'Layers']
setid = 0
visited = set()
with open(filepath, 'w') as csvf:
@@ -166,3 +181,200 @@ class ChannelDependency:
tmp_set.add(other)
d_sets.append(tmp_set)
return d_sets
class CatPaddingDependency(ChannelDependency):
def __init__(self, model=None, dummy_input=None, traced_model=None):
super(CatPaddingDependency, self).__init__(model, dummy_input, traced_model)
def build_dependency(self):
"""
Build the cat padding dependencies.
If the output features of several layers are stitched together
by a cat operation, then these layers have a cat padding dependency.
This is because when inferring the cat mask, we need all the input
masks for the cat operation. At this time we need to know the source
of all input vectors of a cat operation.
"""
for node in self.graph.nodes_py.nodes_op:
parent_layers = []
if node.op_type == CAT_TYPE:
parent_layers = self._get_parent_layers(node)
dependency_set = set(parent_layers)
# merge the dependencies
for parent in parent_layers:
if parent in self.dependency:
dependency_set.update(self.dependency[parent])
# save the dependencies
for _node in dependency_set:
self.dependency[_node] = dependency_set
@property
def dependency_sets(self):
d_sets = []
visited = set()
for nodename in self.dependency:
if nodename in visited:
continue
visited.update(self.dependency[nodename])
d_sets.append(self.dependency[nodename])
return d_sets
def export(self, filepath):
"""
Export the dependencies into a file.
In the output file, each line contains a set of layers
whose output features are stitched together by the cat
operation.
output example:
Dependency Set, Layers
set1, Conv1, Conv2
set2, Conv3, Conv4
"""
header = ['Dependency Set', 'Layers']
setid = 0
with open(filepath, 'w') as csvf:
csv_w = csv.writer(csvf, delimiter=',')
csv_w.writerow(header)
for layers in self.dependency_sets:
setid += 1
row = ['Set %d' % setid]
row.extend(list(layers))
csv_w.writerow(row)
class GroupDependency(Dependency):
def __init__(self, model=None, dummy_input=None, traced_model=None):
"""
This model analyzes the group dependencies between the conv
layers in a model.
Parameters
----------
model : torch.nn.Module
The model to be analyzed.
data : torch.Tensor
The example input data to trace the network architecture.
traced_model : torch._C.Graph
if we already have the traced graph of the target model, we do not
need to trace the model again.
"""
super(GroupDependency, self).__init__(model, dummy_input, traced_model)
def _get_parent_convs(self, node):
"""
Find the nearest parent conv layers for the target node.
Parameters
---------
node : torch._C.Node
target node.
Returns
-------
parent_layers : list
nearest parent conv layers for the target node. Because the group
dependency only exists between conv layers, we only find
the parent conv layers.
"""
parent_layers = []
# the input node is a Conv node
predecessors = self.graph.find_predecessors(node.unique_name)
predecessors = [self.graph.name_to_node[x] for x in predecessors]
queue = predecessors
while queue:
curnode = queue.pop(0)
if curnode.op_type == 'Conv2d':
# stop at the first conv encountered
parent_layers.append(curnode.name)
continue
parents = self.graph.find_predecessors(curnode.unique_name)
parents = [self.graph.name_to_node[name] for name in parents]
for parent in parents:
queue.append(parent)
return parent_layers
def _get_conv_groups(self, node_group):
"""
Get the number of groups for a convolutional layer.
Parameters
----------
node_group : NodePyGroup
target node.
Returns
-------
group : int
the number of the groups of the target conv layer.
"""
cpp_conv = list(filter(lambda x: x.kind() == CONV_TYPE, node_group.node_cpps))
assert len(cpp_conv) == 1
cpp_conv = cpp_conv[0]
inputs = list(cpp_conv.inputs())
# get the number of the group from the input parameters
group = inputs[8].toIValue()
return group
def build_dependency(self):
"""
Build the group dependency for the conv layers
in the model. This function returns the group number
of each conv layer. Note that the group count of a
conv layer here may be larger than its original group
number. This is because the input channels are also
grouped for group conv layers. To make this clear, assume we
have two group conv layers: conv1(group=2), conv2(group=4).
conv2 takes the output features of conv1 as input.
Then we have to ensure that the filters of conv1 can still be
divided into 4 groups after filter pruning, because
the input channels of conv2 should be divided into
4 groups.
Returns
-------
self.dependency : dict
key: the name of the conv layer; value: the number that its
filter count should be divisible by.
"""
for node in self.graph.nodes_py.nodes_op:
if node.op_type == 'Conv2d':
group = self._get_conv_groups(node)
if node.name in self.dependency:
# a conv layer whose group count is larger than 1 requires
# its number of output channels to be divisible by the group count.
self.dependency[node.name] = max(self.dependency[node.name], group)
else:
self.dependency[node.name] = group
if group > 1:
# a conv layer whose group count is larger than 1 also requires the number
# of output channels of its parent conv layers to be divisible by the group count.
parent_convs = self._get_parent_convs(node)
for parent in parent_convs:
if parent in self.dependency:
self.dependency[parent] = max(self.dependency[parent], group)
else:
self.dependency[parent] = group
return self.dependency
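Following the docstring's example, a sketch of the resulting dependency dict for conv1(groups=2) feeding conv2(groups=4): conv2's group count propagates to its parent, so both filter counts must stay divisible by 4 after pruning.
```
# hypothetical result of build_dependency for conv1(groups=2) -> conv2(groups=4)
dependency = {
    'conv1': 4,   # max(own groups 2, child conv2's groups 4)
    'conv2': 4,
}
```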
def export(self, filepath):
"""
Export the group dependency to a csv file.
Each line describes a convolution layer: the
first part of each line is the Pytorch module
name of the conv layer, and the second part of each
line is the group count of the filters in this layer.
Note that the group count may be larger than this
layer's original group number.
output example:
Conv layer, Groups
Conv1, 1
Conv2, 2
Conv3, 4
"""
header = ['Conv Layer Name', 'Group']
with open(filepath, 'w') as csvf:
csv_w = csv.writer(csvf, delimiter=',')
csv_w.writerow(header)
for name in self.dependency:
group = self.dependency[name]
csv_w.writerow([name, group])
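For reference, a minimal usage sketch of `GroupDependency` (the model choice, dummy-input shape, and output path are illustrative; `build_dependency` may already be invoked by the base-class constructor, in which case the explicit call is simply redundant):
```
import torch
import torchvision.models as models
from nni.compression.torch.utils.shape_dependency import GroupDependency

# mobilenet_v2 contains depthwise (grouped) convolutions, so the exported
# CSV will show group counts larger than 1 for those layers
net = models.mobilenet_v2()
dummy_input = torch.ones(1, 3, 224, 224)
group_dep = GroupDependency(net, dummy_input)
group_dep.build_dependency()
group_dep.export('./mobilenet_v2_group_dependency.csv')
```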
...@@ -11,13 +11,13 @@ import numpy as np ...@@ -11,13 +11,13 @@ import numpy as np
from nni.compression.torch import L1FilterPruner from nni.compression.torch import L1FilterPruner
from nni.compression.torch.utils.shape_dependency import ChannelDependency from nni.compression.torch.utils.shape_dependency import ChannelDependency
from nni.compression.torch.utils.mask_conflict import MaskConflict from nni.compression.torch.utils.mask_conflict import fix_mask_conflict
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
prefix = 'analysis_test' prefix = 'analysis_test'
model_names = ['alexnet', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg19', model_names = ['alexnet', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg19',
'resnet18', 'resnet34', 'squeezenet1_1', 'resnet18', 'resnet34', 'squeezenet1_1',
'shufflenet_v2_x1_0', 'mobilenet_v2', 'wide_resnet50_2'] 'mobilenet_v2', 'wide_resnet50_2']
channel_dependency_ground_truth = { channel_dependency_ground_truth = {
'resnet18': [{'layer1.0.conv2', 'layer1.1.conv2', 'conv1'}, 'resnet18': [{'layer1.0.conv2', 'layer1.1.conv2', 'conv1'},
...@@ -49,8 +49,12 @@ channel_dependency_ground_truth = { ...@@ -49,8 +49,12 @@ channel_dependency_ground_truth = {
'vgg13': [], 'vgg13': [],
'vgg19': [], 'vgg19': [],
'squeezenet1_1': [], 'squeezenet1_1': [],
'googlenet': [], 'googlenet': []
'shufflenet_v2_x1_0': [] # comment out shufflenet temporarily
# because it contains the ListUnpack operation, which
# currently leads to a graph construction error.
# ListUnpack will be supported in the next release.
# 'shufflenet_v2_x1_0': []
} }
unittest.TestLoader.sortTestMethodsUsing = None unittest.TestLoader.sortTestMethodsUsing = None
...@@ -111,9 +115,8 @@ class AnalysisUtilsTest(TestCase): ...@@ -111,9 +115,8 @@ class AnalysisUtilsTest(TestCase):
pruner.export_model(ck_file, mask_file) pruner.export_model(ck_file, mask_file)
pruner._unwrap_model() pruner._unwrap_model()
# Fix the mask conflict # Fix the mask conflict
mf = MaskConflict(mask_file, net, dummy_input) fixed_mask = fix_mask_conflict(mask_file, net, dummy_input)
fixed_mask = mf.fix_mask_conflict()
mf.export(os.path.join(outdir, '%s_fixed_mask' % name))
# use the channel dependency ground truth to check if # use the channel dependency ground truth to check if
# the mask conflict is fixed successfully # the mask conflict is fixed successfully
for dset in channel_dependency_ground_truth[name]: for dset in channel_dependency_ground_truth[name]:
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import os import os
import numpy as np import numpy as np
import torch import torch
import torchvision.models as models
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torchvision.models.vgg import vgg16 from torchvision.models.vgg import vgg16
...@@ -13,7 +14,17 @@ from unittest import TestCase, main ...@@ -13,7 +14,17 @@ from unittest import TestCase, main
from nni.compression.torch import L1FilterPruner, apply_compression_results, ModelSpeedup from nni.compression.torch import L1FilterPruner, apply_compression_results, ModelSpeedup
torch.manual_seed(0) torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 2
# the allowed relative difference between outputs
RELATIVE_THRESHOLD = 0.01
# Due to floating-point precision, small differences between the original
# output tensors (without speedup) and the output tensors of the speedup
# model are expected. When the output tensor itself is small, such
# differences may exceed the relative threshold, so we also use an
# absolute threshold to decide whether the final result is correct.
# The error must satisfy either RELATIVE_THRESHOLD or ABSOLUTE_THRESHOLD.
ABSOLUTE_THRESHOLD = 0.0001
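# Illustrative check (made-up numbers): with ori_sum = 1e-5 and
# speeded_sum = 1.5e-5, the relative error is 0.5 (far above
# RELATIVE_THRESHOLD), but the absolute error 5e-6 is below
# ABSOLUTE_THRESHOLD, so the comparison still passes.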
class BackboneModel1(nn.Module): class BackboneModel1(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
...@@ -72,6 +83,27 @@ def prune_model_l1(model): ...@@ -72,6 +83,27 @@ def prune_model_l1(model):
pruner.compress() pruner.compress()
pruner.export_model(model_path=MODEL_FILE, mask_path=MASK_FILE) pruner.export_model(model_path=MODEL_FILE, mask_path=MASK_FILE)
def generate_random_sparsity(model):
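# assign a random sparsity in [0.5, 0.99) to every Conv2d layer, so the
# integration test exercises the speedup logic under varied prune ratios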
cfg_list = []
for name, module in model.named_modules():
if isinstance(module, nn.Conv2d):
sparsity = np.random.uniform(0.5, 0.99)
cfg_list.append({'op_types': ['Conv2d'], 'op_names': [name],
'sparsity': sparsity})
return cfg_list
def zero_bn_bias(model):
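# zero out the BatchNorm bias and running_mean so that pruned (masked)
# channels contribute nothing to the output; this makes the outputs of
# the original and speedup models directly comparable (our reading of
# the helper's intent, inferred from its use in the test below)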
with torch.no_grad():
for name, module in model.named_modules():
if isinstance(module, nn.BatchNorm2d) \
or isinstance(module, nn.BatchNorm3d) \
or isinstance(module, nn.BatchNorm1d):
shape = module.bias.data.size()
device = module.bias.device
module.bias.data = torch.zeros(shape).to(device)
shape = module.running_mean.data.size()
module.running_mean = torch.zeros(shape).to(device)
class SpeedupTestCase(TestCase): class SpeedupTestCase(TestCase):
def test_speedup_vgg16(self): def test_speedup_vgg16(self):
prune_model_l1(vgg16()) prune_model_l1(vgg16())
...@@ -85,10 +117,6 @@ class SpeedupTestCase(TestCase): ...@@ -85,10 +117,6 @@ class SpeedupTestCase(TestCase):
assert model.features[2].out_channels == int(orig_model.features[2].out_channels * SPARSITY) assert model.features[2].out_channels == int(orig_model.features[2].out_channels * SPARSITY)
assert model.classifier[0].in_features == int(orig_model.classifier[0].in_features * SPARSITY) assert model.classifier[0].in_features == int(orig_model.classifier[0].in_features * SPARSITY)
#def test_speedup_resnet(self):
#TODO support resnet
#model = resnet18()
def test_speedup_bigmodel(self): def test_speedup_bigmodel(self):
prune_model_l1(BigModel()) prune_model_l1(BigModel())
model = BigModel() model = BigModel()
...@@ -116,6 +144,36 @@ class SpeedupTestCase(TestCase): ...@@ -116,6 +144,36 @@ class SpeedupTestCase(TestCase):
assert model.backbone2.conv2.out_channels == int(orig_model.backbone2.conv2.out_channels * SPARSITY) assert model.backbone2.conv2.out_channels == int(orig_model.backbone2.conv2.out_channels * SPARSITY)
assert model.backbone2.fc1.in_features == int(orig_model.backbone2.fc1.in_features * SPARSITY) assert model.backbone2.fc1.in_features == int(orig_model.backbone2.fc1.in_features * SPARSITY)
def test_speedup_integration(self):
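# end-to-end check: prune each torchvision model with random sparsities,
# run ModelSpeedup, then compare the outputs of the original (masked)
# model and the speedup model under the thresholds defined above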
for model_name in ['resnet18', 'squeezenet1_1', 'mobilenet_v2']:
Model = getattr(models, model_name)
net = Model(pretrained=True, progress=False).to(device)
net.eval() # this line is necessary: eval mode makes BatchNorm use running statistics, so the outputs are deterministic and comparable
# random generate the prune config for the pruner
cfgs = generate_random_sparsity(net)
pruner = L1FilterPruner(net, cfgs)
pruner.compress()
pruner.export_model(MODEL_FILE, MASK_FILE)
pruner._unwrap_model()
speedup_model = Model().to(device)
speedup_model.eval()
state_dict = torch.load(MODEL_FILE)
speedup_model.load_state_dict(state_dict)
zero_bn_bias(net)
zero_bn_bias(speedup_model)
data = torch.ones(BATCH_SIZE, 3, 224, 224).to(device)
ms = ModelSpeedup(speedup_model, data, MASK_FILE)
ms.speedup_model()
ori_out = net(data)
speeded_out = speedup_model(data)
ori_sum = torch.sum(ori_out).item()
speeded_sum = torch.sum(speeded_out).item()
print('Sum of the output of %s (before speedup):'%model_name, ori_sum)
print('Sum of the output of %s (after speedup):'%model_name, speeded_sum)
assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
(abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
def tearDown(self): def tearDown(self):
os.remove(MODEL_FILE) os.remove(MODEL_FILE)
os.remove(MASK_FILE) os.remove(MASK_FILE)
......