Unverified Commit 8bc74a2c authored by chicm-ms, committed by GitHub

Speedup supports channel pruning (#2906)

parent f43719a8
......@@ -426,6 +426,36 @@ class TorchModuleGraph(TorchGraph):
cat_info['in_shape'] = input_shapes
return cat_info
def _extract_linear_shape_info(self, node_group):
"""
Extract the input/output tensor shape info of a linear module from its aten::addmm op.
Parameters
----------
node_group : NodePyGroup
NodePyGroup object associated with the linear module.
Returns
-------
dict
The shapes of the input and output tensors
"""
for cpp_node in node_group.node_cpps:
if cpp_node.kind() == 'aten::addmm':
# https://github.com/pytorch/pytorch/blob/1.6/torch/nn/functional.py#L1682
# inputs of aten::addmm:
# inputs[0] is bias
# inputs[1] is input data
# inputs[2] is weight
t_input = list(cpp_node.inputs())[1]
t_output = cpp_node.output()
assert isinstance(t_input.type(), torch._C.TensorType)
assert isinstance(t_output.type(), torch._C.TensorType)
in_shape = t_input.type().sizes()
out_shape = t_output.type().sizes()
return {'in_shape': in_shape, 'out_shape': out_shape}
return None
def _extract_shape_info(self, node):
"""
Extract the shape information of ```aten::view``` node
......@@ -701,6 +731,8 @@ class TorchModuleGraph(TorchGraph):
cpp_node = list(filter(lambda x: x.kind() == node_group.op_type,
node_group.node_cpps))[0]
node_group.auxiliary = self._extract_shape_info(cpp_node)
elif node_group.op_type == 'Linear':
node_group.auxiliary = self._extract_linear_shape_info(node_group)
elif node_group.op_type == CAT_KIND:
# get the detail information for cat func
cpp_node = list(filter(lambda x: x.kind() == node_group.op_type,
......
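The shapes read via `type().sizes()` come from the traced graph. A minimal standalone sketch of the same inspection, assuming a PyTorch build (the hunk references 1.6) that lowers `Linear` to `aten::addmm` during tracing and retains tensor shapes on graph values:

```python
import torch

# Trace a Linear layer and locate its aten::addmm node, mirroring
# _extract_linear_shape_info above.
model = torch.nn.Linear(16, 4)
traced = torch.jit.trace(model, torch.randn(2, 16))
for node in traced.graph.nodes():
    if node.kind() == 'aten::addmm':
        t_input = list(node.inputs())[1]  # inputs[0] is bias, inputs[2] is weight
        t_output = node.output()
        print(t_input.type().sizes(), t_output.type().sizes())  # [2, 16] [2, 4]
```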
......@@ -2,7 +2,7 @@
# Licensed under the MIT license.
import logging
from schema import And, Optional
from schema import And, Optional, SchemaError
from nni._graph_utils import TorchModuleGraph
from nni.compression.torch.utils.shape_dependency import ChannelDependency, GroupDependency
from .constants import MASKER_DICT
......@@ -186,12 +186,16 @@ class _StructuredFilterPruner(OneshotPruner):
def validate_config(self, model, config_list):
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
'op_types': ['Conv2d'],
Optional('op_names'): [str]
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
Optional('op_types'): ['Conv2d'],
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)
schema.validate(config_list)
for config in config_list:
if 'exclude' not in config and 'sparsity' not in config:
raise SchemaError('Either sparsity or exclude must be specified!')
def _dependency_calc_mask(self, wrappers, channel_dsets, wrappers_idx=None):
"""
......
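With the relaxed schema, every config entry must carry either `sparsity` or `exclude`. An illustrative `config_list` (the same pattern is exercised by the tests further down):

```python
# Either 'sparsity' or 'exclude' must be present in each entry.
config_list = [
    {'sparsity': 0.5, 'op_types': ['Conv2d']},  # prune all Conv2d layers to 50%
    {'op_names': ['conv1'], 'exclude': True},   # but leave conv1 untouched
]
```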
......@@ -116,15 +116,19 @@ def replace_conv2d(conv, mask):
else:
out_channels_index = mask.output_mask.mask_index[1]
out_channels = out_channels_index.size()[0]
_logger.debug("replace conv2d with in_channels: %d, out_channels: %d", in_channels, out_channels)
groups = conv.groups
if conv.in_channels == conv.out_channels == conv.groups:
# remove groups for depthwise layers
assert in_channels == out_channels
groups = in_channels
_logger.debug("replace conv2d %s with in_channels: %d, out_channels: %d", mask.module_name, in_channels, out_channels)
new_conv = torch.nn.Conv2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
dilation=conv.dilation,
groups=conv.groups,
groups=groups,
bias=conv.bias is not None,
padding_mode=conv.padding_mode)
......@@ -142,13 +146,16 @@ def replace_conv2d(conv, mask):
# channel is also divided into several groups and each group's
# filters may have different input channel indexes.
input_step = int(conv.in_channels / conv.groups)
in_channels_group = int(in_channels / conv.groups)
filter_step = int(out_channels / conv.groups)
if mask.input_mask is not None:
in_channels_group = int(in_channels / groups)
filter_step = int(out_channels / groups)
if mask.input_mask is not None and not (in_channels == out_channels == groups):
for groupid in range(conv.groups):
start = groupid * input_step
end = (groupid + 1) * input_step
current_input_index = list(filter(lambda x: start <= x and x < end, in_channels_index.tolist()))
if not current_input_index:
# there is no kept channel in current group
continue
# shift the global index into the group index
current_input_index = [x-start for x in current_input_index]
# if the groups is larger than 1, the input channels of each
......
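The loop above shifts kept-channel indices from the global index space into each group's local space and skips groups where no channel survives. The arithmetic in isolation, with hypothetical numbers:

```python
# 8 input channels split into 2 groups; channels 1, 3 and 6 are kept.
in_channels_index = [1, 3, 6]
groups, input_step = 2, 4  # input_step = in_channels / groups
for groupid in range(groups):
    start, end = groupid * input_step, (groupid + 1) * input_step
    current = [x - start for x in in_channels_index if start <= x < end]
    print(groupid, current)  # group 0 -> [1, 3], group 1 -> [2]
```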
......@@ -4,34 +4,13 @@
import logging
import torch
from nni.compression.torch.utils.mask_conflict import fix_mask_conflict
from nni.compression.torch.utils.utils import get_module_by_name
from .compress_modules import replace_module
from .infer_shape import ModuleMasks, infer_from_mask, infer_from_inshape, infer_from_outshape
from .infer_shape import ModuleMasks, infer_from_mask, infer_from_inshape, infer_from_outshape, set_conv_prune_dim
_logger = logging.getLogger(__name__)
def get_module_by_name(model, module_name):
"""
Get a module specified by its module name
Parameters
----------
model : pytorch model
the pytorch model from which to get its module
module_name : str
the name of the required module
Returns
-------
module, module
the parent module of the required module, the required module
"""
name_list = module_name.split(".")
for name in name_list[:-1]:
model = getattr(model, name)
leaf_module = getattr(model, name_list[-1])
return model, leaf_module
class ModelSpeedup:
"""
This class is to speedup the model with provided weight mask
......@@ -87,7 +66,8 @@ class ModelSpeedup:
if module_name in self.inferred_masks:
module_masks = self.inferred_masks[module_name]
else:
module_masks = ModuleMasks(module_name)
_, m = get_module_by_name(self.bound_model, module_name)
module_masks = ModuleMasks(module_name, m)
self.inferred_masks[module_name] = module_masks
m_type = self.torch_graph.name_to_node[module_name].op_type
......@@ -98,7 +78,12 @@ class ModelSpeedup:
raise RuntimeError(
"Has not supported infering input/output shape from mask for module/function: `{}`, {}"
.format(m_type, module_name))
input_cmask, output_cmask = infer_from_mask[m_type](module_masks, mask)
if m_type in ['Linear']:
input_cmask, output_cmask = infer_from_mask[m_type](
module_masks, mask, self.torch_graph.name_to_node[module_name].auxiliary
)
else:
input_cmask, output_cmask = infer_from_mask[m_type](module_masks, mask)
if in_shape is not None:
_logger.debug("in_shape is not None")
if not m_type in infer_from_inshape:
......@@ -124,7 +109,10 @@ class ModelSpeedup:
raise RuntimeError(
"Has not supported infering input shape from output shape for module/function: `{}`, {}"
.format(m_type, module_name))
input_cmask = infer_from_outshape[m_type](module_masks, out_shape)
if m_type in ['aten::view', 'aten::flatten', 'aten::mean', 'aten::reshape']:
input_cmask = infer_from_outshape[m_type](module_masks, out_shape, self.torch_graph.name_to_node[module_name].auxiliary)
else:
input_cmask = infer_from_outshape[m_type](module_masks, out_shape)
if input_cmask:
predecessors = self.torch_graph.find_predecessors(module_name)
......@@ -178,7 +166,6 @@ class ModelSpeedup:
else:
raise RuntimeError("Unsupported node type: {}".format(g_node.type))
def speedup_model(self):
"""
There are basically two steps:
......@@ -187,8 +174,11 @@ class ModelSpeedup:
"""
training = self.bound_model.training
_logger.info("start to speed up the model")
_logger.info("fix the mask conflict of the interdependent layers")
fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)
_, conv_prune_dim = fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)
set_conv_prune_dim(conv_prune_dim)
_logger.info("infer module masks...")
self.infer_modules_masks()
_logger.info("replace compressed modules...")
......
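For reference, the entry point touched here is driven as follows; a usage sketch in which `MASK_FILE` is a hypothetical path to a mask checkpoint exported by a pruner:

```python
import torch
from torchvision.models import resnet18
from nni.compression.torch import ModelSpeedup

MASK_FILE = './mask.pth'          # hypothetical mask checkpoint
model = resnet18().eval()
dummy_input = torch.randn(1, 3, 224, 224)
ms = ModelSpeedup(model, dummy_input, MASK_FILE)
ms.speedup_model()                # fix conflicts, infer masks, replace modules
```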
......@@ -6,8 +6,22 @@ One is given output shape, infer its input shape and initialization parameters (
The other is given input shape, infer its output shape and initialization parameters (e.g., weight's shape)
"""
import logging
import torch
_logger = logging.getLogger(__name__)
conv_prune_dim = -1
def set_conv_prune_dim(dim):
"""
Parameters
----------
dim : int
0: filter pruning
1: channel pruning
"""
global conv_prune_dim
conv_prune_dim = dim
class CoarseMask:
"""
......@@ -160,7 +174,7 @@ class ModuleMasks:
The masks of a module, including the masks for weights, inputs, output
"""
def __init__(self, module_name):
def __init__(self, module_name, module=None):
"""
Parameters
----------
......@@ -168,6 +182,7 @@ class ModuleMasks:
The name of the module or function
"""
self.module_name = module_name
self.module = module
self.param_masks = dict()
self.input_mask = None
self.output_mask = None
......@@ -202,8 +217,8 @@ class ModuleMasks:
self.output_mask = mask
def __repr__(self):
return 'input_mask: {}, output_mask: {}, param_masks: {}'.format(
self.input_mask, self.output_mask, self.param_masks
return 'module_name: {}, input_mask: {}, output_mask: {}, param_masks: {}'.format(
self.module_name, self.input_mask, self.output_mask, self.param_masks
)
......@@ -212,7 +227,8 @@ Infer input and output shape of a module/function from its weight mask
"""
infer_from_mask = {
'BatchNorm2d': lambda module_masks, mask: batchnorm2d_mask(module_masks, mask),
'Conv2d': lambda module_masks, mask: conv2d_mask(module_masks, mask)
'Conv2d': lambda module_masks, mask: conv2d_mask(module_masks, mask),
'Linear': lambda module_masks, mask, shape: linear_mask(module_masks, mask, shape)
}
"""
......@@ -260,7 +276,34 @@ infer_from_inshape = {
Infer input and weight shape of a module/function from its output shape
"""
infer_from_outshape = {
'Conv2d': lambda module_masks, mask: conv2d_outshape(module_masks, mask)
'Conv2d': lambda module_masks, mask: conv2d_outshape(module_masks, mask),
'BatchNorm2d': lambda module_masks, mask: batchnorm2d_outshape(module_masks, mask),
'MaxPool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask),
'aten::max_pool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask),
'aten::avg_pool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask),
'aten::adaptive_avg_pool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask),
'AvgPool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask),
'AdaptiveAvgPool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask),
'ReLU': lambda module_masks, mask: relu_outshape(module_masks, mask),
'ReLU6': lambda module_masks, mask: relu_outshape(module_masks, mask),
'aten::relu': lambda module_masks, mask: relu_outshape(module_masks, mask),
'aten::tanh': lambda module_masks, mask: relu_outshape(module_masks, mask),
'aten::tanh_': lambda module_masks, mask: relu_outshape(module_masks, mask),
'aten::hardtanh': lambda module_masks, mask: relu_outshape(module_masks, mask),
'aten::hardtanh_': lambda module_masks, mask: relu_outshape(module_masks, mask),
'aten::relu_': lambda module_masks, mask: relu_outshape(module_masks, mask),
'aten::add_': lambda module_masks, mask: add_outshape(module_masks, mask),
'aten::add': lambda module_mask, mask: add_outshape(module_mask, mask),
'aten::flatten': lambda module_mask, mask, shape: view_outshape(module_mask, mask, shape),
'aten::view': lambda module_masks, mask, shape: view_outshape(module_masks, mask, shape),
'aten::reshape': lambda module_masks, mask, shape: view_outshape(module_masks, mask, shape),
'aten::mean': lambda module_masks, mask, shape: mean_outshape(module_masks, mask, shape),
'Dropout': lambda module_masks, mask: dropout_outshape(module_masks, mask),
'Dropout2d': lambda module_masks, mask: dropout_outshape(module_masks, mask),
'aten::dropout': lambda module_masks, mask: dropout_outshape(module_masks, mask)
}
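The table above is keyed by module/op type; as the speedup.py hunk shows, only the shape-changing ops receive the recorded in/out shapes as a third argument. A condensed sketch of that dispatch:

```python
# Shape-changing ops need the auxiliary shape dict recorded at graph build
# time; every other handler takes (module_masks, mask) only.
SHAPE_AWARE = {'aten::view', 'aten::flatten', 'aten::mean', 'aten::reshape'}

def dispatch_outshape(m_type, module_masks, mask, auxiliary=None):
    handler = infer_from_outshape[m_type]
    if m_type in SHAPE_AWARE:
        return handler(module_masks, mask, auxiliary)
    return handler(module_masks, mask)
```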
def dropout_inshape(module_masks, mask):
......@@ -282,7 +325,15 @@ def dropout_inshape(module_masks, mask):
module_masks.set_output_mask(mask)
return module_masks.output_mask
def dropout_outshape(module_masks, mask):
if module_masks.output_mask is None:
module_masks.set_output_mask(mask)
module_masks.set_input_mask(mask)
return module_masks.input_mask
# if already visited
assert all(module_masks.output_mask.mask_index[1] == mask.mask_index[1])
return module_masks.output_mask
def cat_inshape(module_masks, mask, cat_info, last_visited):
"""
......@@ -382,6 +433,20 @@ def add_inshape(module_masks, mask):
raise Exception('Mask conflict happens!')
return None
def add_outshape(module_masks, mask):
"""
Infer the input mask of the add operation from the output mask.
"""
assert isinstance(mask, CoarseMask)
if module_masks.output_mask is None:
module_masks.set_output_mask(mask)
module_masks.set_input_mask(mask)
return mask
else:
assert all(module_masks.output_mask.mask_index[1] == mask.mask_index[1])
return mask
def batchnorm2d_inshape(module_masks, mask):
"""
......@@ -412,6 +477,34 @@ def batchnorm2d_inshape(module_masks, mask):
module_masks.set_param_masks('bias', weight_cmask)
return mask
def batchnorm2d_outshape(module_masks, mask):
"""
We assume only the second dimension has coarse grained mask
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the batchnorm2d
mask : CoarseMask
The mask of its output tensor
Returns
-------
CoarseMask
The mask of its input tensor
"""
assert isinstance(mask, CoarseMask)
assert len(mask.mask_index) in [2, 4]
assert mask.mask_index[1] is not None
assert mask.mask_index[0] is None
module_masks.set_input_mask(mask)
module_masks.set_output_mask(mask)
weight_cmask = CoarseMask(num_dim=1)
weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
module_masks.set_param_masks('weight', weight_cmask)
module_masks.set_param_masks('bias', weight_cmask)
return mask
def linear_inshape(module_masks, mask):
"""
......@@ -484,6 +577,42 @@ def view_inshape(module_masks, mask, shape):
module_masks.set_output_mask(output_cmask)
return output_cmask
def view_outshape(module_masks, mask, shape):
"""
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the view/flatten/reshape op
mask : CoarseMask
The mask of its output tensor
shape : dict
Original shapes of its input and output tensors
Returns
-------
CoarseMask
The mask of its input tensor
"""
# NOTE: only the case constrained by the following four asserts is supported
assert shape['in_shape'][0] == shape['out_shape'][0]
assert len(shape['in_shape']) == 4
assert len(shape['out_shape']) == 2
assert shape['out_shape'][1] == shape['in_shape'][1] * \
shape['in_shape'][2]*shape['in_shape'][3]
assert isinstance(mask, CoarseMask)
assert mask.mask_index[1] is not None
assert mask.mask_index[0] is None
module_masks.set_output_mask(mask)
input_cmask = CoarseMask(num_dim=4)
index = []
step_size = shape['in_shape'][2] * shape['in_shape'][3]
for loc in mask.mask_index[1]:
# map each kept flattened feature back to its input channel
index.append(loc // step_size)
index = sorted(set(index))
input_cmask.add_index_mask(dim=1, index=torch.tensor(index)) # pylint: disable=not-callable
module_masks.set_input_mask(input_cmask)
return input_cmask
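A worked instance of the feature-to-channel collapse, under the shapes enforced by the asserts:

```python
# in_shape = [N, 4, 2, 2] gives step_size = 4; keeping flattened features
# 4..7 and 12..15 keeps input channels 1 and 3.
step_size = 2 * 2
kept_features = [4, 5, 6, 7, 12, 13, 14, 15]
print(sorted({loc // step_size for loc in kept_features}))  # [1, 3]
```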
def size_inshape(module_masks, mask):
"""
......@@ -513,6 +642,26 @@ def mean_inshape(module_masks, mask, shape):
module_masks.set_output_mask(output_cmask)
return output_cmask
def mean_outshape(module_masks, mask, shape):
"""
Similar to the view operation; currently mask inference only supports
the mean operation over the 3rd and 4th dimensions ([N, C, H, W] -> [N, C]).
"""
assert shape['in_shape'][0] == shape['out_shape'][0]
assert shape['out_shape'][1] == shape['in_shape'][1]
assert len(shape['in_shape']) == 4
assert len(shape['out_shape']) == 2
assert isinstance(mask, CoarseMask)
assert mask.mask_index[1] is not None
assert mask.mask_index[0] is None
module_masks.set_output_mask(mask)
input_cmask = CoarseMask(num_dim=4)
input_cmask.add_index_mask(dim=1, index=mask.mask_index[1])
module_masks.set_input_mask(input_cmask)
return input_cmask
def maxpool2d_inshape(module_masks, mask):
"""
Assume only the second dimension is masked
......@@ -541,6 +690,29 @@ def maxpool2d_inshape(module_masks, mask):
module_masks.set_output_mask(mask)
return mask
def maxpool2d_outshape(module_masks, mask):
"""
Assume only the second dimension is masked
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the maxpool2d
mask : CoarseMask
The mask of its output tensor
Returns
-------
CoarseMask
The mask of its input tensor
"""
assert isinstance(mask, CoarseMask)
assert mask.mask_index[1] is not None
assert mask.mask_index[0] is None
module_masks.set_input_mask(mask)
module_masks.set_output_mask(mask)
return mask
def relu_inshape(module_masks, mask):
"""
......@@ -558,25 +730,44 @@ def relu_inshape(module_masks, mask):
"""
assert isinstance(mask, CoarseMask)
if module_masks.input_mask is not None:
# check if has a mask conflict
# mask conflict should be solved before speedup
assert module_masks.input_mask <= mask
# assert module_masks.input_mask is None, "A relu op can only be processed once"
module_masks.set_input_mask(mask)
module_masks.set_output_mask(mask)
return mask
def relu_outshape(module_masks, mask):
"""
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the relu
mask : CoarseMask
The mask of its output tensor
Returns
-------
CoarseMask
The mask of its input tensor
"""
assert isinstance(mask, CoarseMask)
if module_masks.output_mask is not None:
# mask conflict should be solved before speedup
assert all(module_masks.output_mask.mask_index[1] == mask.mask_index[1])
module_masks.set_input_mask(mask)
module_masks.set_output_mask(mask)
return mask
def batchnorm2d_mask(module_masks, mask):
"""
Infer input and output shape from weight mask
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the batchnorm2d
mask : dict
The mask of its weights, from the user provided mask file
Returns
-------
CoarseMask, CoarseMask
......@@ -601,6 +792,38 @@ def batchnorm2d_mask(module_masks, mask):
module_masks.set_output_mask(output_cmask)
return input_cmask, output_cmask
def linear_mask(module_masks, mask, shape):
"""
Infer input and output shape from the weight mask, with a limitation:
only the input mask can be inferred.
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the Linear
mask : dict
The mask of its weights, from the user provided mask file
shape: dict
Shape of its input and output tensors
Returns
-------
CoarseMask, CoarseMask
The mask of its input tensor, the mask of its output tensor
"""
assert 'weight' in mask
num_input_dim = len(shape['in_shape'])
# Input data of a Linear module can have multiple dimensions;
# here we only infer a coarse mask on the last dimension (the input features).
nonzero_index = torch.nonzero(mask['weight'].sum(0), as_tuple=True)[0]
# infer shape of input tensor
input_cmask = CoarseMask(num_dim=num_input_dim)
input_cmask.add_index_mask(dim=num_input_dim-1, index=nonzero_index)
module_masks.set_input_mask(input_cmask)
return input_cmask, None
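To make the inference concrete, a small sketch of how the kept input features fall out of a hypothetical channel-pruned Linear weight mask:

```python
import torch

# 2x3 weight mask (out_features=2, in_features=3); input feature 1 is pruned.
w_mask = torch.tensor([[1., 0., 1.],
                       [1., 0., 1.]])
nonzero_index = torch.nonzero(w_mask.sum(0), as_tuple=True)[0]
print(nonzero_index)  # tensor([0, 2]) -> features 0 and 2 are kept
```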
def conv2d_mask(module_masks, mask):
"""
......@@ -618,12 +841,15 @@ def conv2d_mask(module_masks, mask):
CoarseMask, CoarseMask
The mask of its input tensor, the mask of its output tensor
"""
def convert_to_coarse_mask(mask):
def convert_to_coarse_mask(mask, dim=0):
"""
Parameters
----------
mask : dict
Weight mask from user provided mask file
dim: int
0: filter pruning
1: channel pruning
Returns
-------
......@@ -632,64 +858,69 @@ def conv2d_mask(module_masks, mask):
"""
assert 'weight' in mask
assert isinstance(mask['weight'], torch.Tensor)
assert dim in [0, 1]
weight_mask = mask['weight']
shape = weight_mask.size()
ones = torch.ones(shape[1:]).to(weight_mask.device)
zeros = torch.zeros(shape[1:]).to(weight_mask.device)
index = []
for i in range(shape[0]):
if torch.all(torch.eq(weight_mask[i], ones)):
index.append(i)
elif torch.all(torch.eq(weight_mask[i], zeros)):
continue
else:
index = None
break
sum_idx = (1, 2, 3) if dim == 0 else (0, 2, 3)
index = torch.nonzero(weight_mask.abs().sum(sum_idx) != 0, as_tuple=True)[0]
if len(index) == weight_mask.shape[dim]: # full mask
index = None
if index is None:
return None, None, None
else:
index = torch.LongTensor(index).to(weight_mask.device)
weight_cmask = CoarseMask(num_dim=4)
weight_cmask.add_index_mask(dim=0, index=index)
weight_cmask.add_index_mask(dim=dim, index=index)
bias_cmask = None
if 'bias' in mask and mask['bias'] is not None:
if dim == 0 and 'bias' in mask and mask['bias'] is not None:
bias_index = torch.nonzero(mask['bias'], as_tuple=True)[0]
assert torch.all(torch.eq(index, bias_index)), \
"bias mask should be consistent with weight mask"
bias_cmask = CoarseMask(num_dim=1)
bias_cmask.add_index_mask(dim=0, index=bias_index)
return index, weight_cmask, bias_cmask
index, weight_cmask, bias_cmask = convert_to_coarse_mask(mask)
index, weight_cmask, bias_cmask = convert_to_coarse_mask(mask, dim=conv_prune_dim)
if index is None:
# TODO: fine grained mask speedup
return None, None
# deal with coarse grain mask
# mask conflict should be solved by fix_mask_conflict before speedup
if 'weight' in module_masks.param_masks:
module_masks.param_masks['weight'].merge(weight_cmask)
module_masks.param_masks['bias'].merge(bias_cmask)
assert module_masks.param_masks['weight'] == weight_cmask
else:
module_masks.set_param_masks('weight', weight_cmask)
module_masks.set_param_masks('bias', bias_cmask)
output_cmask = CoarseMask(num_dim=4)
output_cmask.add_index_mask(dim=1, index=index)
if module_masks.output_mask is None:
module_masks.set_output_mask(output_cmask)
else:
module_masks.output_mask.merge(output_cmask)
return None, module_masks.output_mask
if conv_prune_dim == 0:
module_masks.set_param_masks('bias', bias_cmask)
io_cmask = CoarseMask(num_dim=4)
io_cmask.add_index_mask(dim=1, index=index)
if conv_prune_dim == 0:
if module_masks.output_mask is None:
module_masks.set_output_mask(io_cmask)
else:
assert module_masks.output_mask == io_cmask
return None, module_masks.output_mask
else:
if module_masks.input_mask is None:
module_masks.set_input_mask(io_cmask)
else:
assert module_masks.input_mask == io_cmask
return module_masks.input_mask, None
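The rewritten `convert_to_coarse_mask` finds kept indices by reducing the 4-d weight mask over every dimension except `dim`. The reduction in isolation, on a synthetic mask:

```python
import torch

# Filter pruning (dim=0): filters 1 and 3 of a 4x3x3x3 mask are zeroed out.
w_mask = torch.ones(4, 3, 3, 3)
w_mask[[1, 3]] = 0.
sum_idx = (1, 2, 3)  # would be (0, 2, 3) for channel pruning (dim=1)
index = torch.nonzero(w_mask.abs().sum(sum_idx) != 0, as_tuple=True)[0]
print(index)  # tensor([0, 2])
```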
def conv2d_inshape(module_masks, mask):
"""
Shape change of input tensor does not affect the shape of its output tensor
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the conv2d
mask : CoarseMask
The mask of its input tensor
Returns
-------
CoarseMask
......@@ -701,8 +932,15 @@ def conv2d_inshape(module_masks, mask):
else:
# the same conv layer may be accessed more
# than once, such as a concat operation.
assert module_masks.input_mask <= mask
module_masks.input_mask.merge(mask)
# mask conflict should be solved by fix_mask_conflict before speedup
assert module_masks.input_mask == mask
# shape changes pass through depthwise conv layers
m = module_masks.module
if m.in_channels == m.out_channels == m.groups:
module_masks.output_mask = mask
module_masks.input_mask = mask
return mask
return None
......@@ -728,18 +966,25 @@ def conv2d_outshape(module_masks, mask):
assert mask.mask_index[2] is None
assert mask.mask_index[3] is None
if module_masks.output_mask is not None:
assert isinstance(module_masks.output_mask, CoarseMask)
# set shape of output
mask = module_masks.output_mask.merge(mask)
else:
if module_masks.output_mask is None:
module_masks.output_mask = mask
# infer shape of parameters
else:
# mask conflict should be solved by fix_mask_conflict before speedup
# mask and module_masks.output_mask may have different number of dimensions
# since they could be passed by linear or conv2d
assert all(module_masks.output_mask.mask_index[1] == mask.mask_index[1])
weight_cmask = CoarseMask(num_dim=4)
weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
bias_cmask = CoarseMask(num_dim=1)
bias_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
module_masks.set_param_masks('weight', weight_cmask)
module_masks.set_param_masks('bias', bias_cmask)
# input shape is not changed
# shape changes pass through depthwise conv layers
m = module_masks.module
if m.in_channels == m.out_channels == m.groups:
module_masks.output_mask = mask
module_masks.input_mask = mask
return mask
return None
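Both `conv2d_inshape` and `conv2d_outshape` now let masks pass straight through depthwise convolutions. The depthwise test in isolation:

```python
import torch

# A conv is depthwise when in_channels == out_channels == groups; pruning a
# channel then prunes the matching filter, so the mask passes through.
conv = torch.nn.Conv2d(32, 32, kernel_size=3, groups=32)
print(conv.in_channels == conv.out_channels == conv.groups)  # True
```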
......@@ -4,9 +4,10 @@ import os
import logging
import torch
import numpy as np
from .shape_dependency import ChannelDependency, GroupDependency, CatPaddingDependency
from .shape_dependency import ChannelDependency, GroupDependency, CatPaddingDependency, InputChannelDependency
from .utils import get_module_by_name
# logging.basicConfig(level = logging.DEBUG)
_logger = logging.getLogger('FixMaskConflict')
_logger = logging.getLogger(__name__)
def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
"""
......@@ -45,7 +46,7 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
masks = fix_channel_mask.fix_mask()
padding_cat_mask = CatMaskPadding(masks, model, dummy_input, traced)
masks = padding_cat_mask.fix_mask()
return masks
return masks, fix_channel_mask.conv_prune_dim
class MaskFix:
def __init__(self, masks, model=None, dummy_input=None, traced=None):
......@@ -221,74 +222,148 @@ class ChannelMaskConflict(MaskFix):
we do not use the model and dummy_input to get the trace graph.
"""
super(ChannelMaskConflict, self).__init__(masks, model, dummy_input, traced)
self.conv_prune_dim = detect_mask_prune_dim(masks, model)
_logger.info('detected conv prune dim: %s', self.conv_prune_dim)
def fix_mask(self):
"""
Fix the mask conflict before the mask inference for the layers that
has shape dependencies. This function should be called before the
mask inference of the 'speedup' module.
mask inference of the 'speedup' module. Only structured pruning masks
are supported.
"""
channel_depen = ChannelDependency(self.model, self.dummy_input, self.traced)
if self.conv_prune_dim == 0:
channel_depen = ChannelDependency(self.model, self.dummy_input, self.traced)
else:
channel_depen = InputChannelDependency(self.model, self.dummy_input, self.traced)
depen_sets = channel_depen.dependency_sets
sum_idx = (1, 2, 3) if self.conv_prune_dim == 0 else (0, 2, 3)
for dset in depen_sets:
if len(dset) == 1:
# This layer has no channel dependency with other layers
if len(dset) <= 1:
continue
channel_remain = set()
# channel_masks is a list whose elements are either None or a 0/1 vector,
# e.g. [[0, 1, 1, 0, 0], [0, 0, 1, 1, 0], None]; None means no channel
# of that layer is pruned.
channel_masks = []
fine_grained = False
out_channels = None
# A flag that represents if all the layers in
# the dependency set are pruned
all_pruned = True
for name in dset:
if name not in self.masks:
# this layer is not pruned
all_pruned = False
continue
w_mask = self.masks[name]['weight']
if out_channels is None:
out_channels = w_mask.size(0)
shape = w_mask.size()
count = np.prod(shape[1:])
all_ones = (w_mask.flatten(1).sum(-1) == count).nonzero().squeeze(1).tolist()
all_zeros = (w_mask.flatten(1).sum(-1) == 0).nonzero().squeeze(1).tolist()
if len(all_ones) + len(all_zeros) < w_mask.size(0):
# In fine-grained pruning, there is no need to check
# the shape conflict
_logger.info('Layers %s using fine-grained pruning', ','.join(dset))
fine_grained = True
break
channel_remain.update(all_ones)
_logger.debug('Layer: %s ', name)
_logger.debug('Original pruned filters: %s', str(all_zeros))
# Update the masks for the layers in the dependency set
if fine_grained or out_channels is None:
# if use the fine-grained pruner or all the layers in
# this dependency set are not pruned
if name in self.masks:
_, m = get_module_by_name(self.model, name)
assert m is not None
mask = self.masks[name]['weight']
if type(m).__name__ == 'Conv2d':
channel_mask = (mask.abs().sum(sum_idx) != 0).int()
channel_masks.append(channel_mask)
if (channel_mask.sum() * (mask.numel() / mask.shape[self.conv_prune_dim])).item() != (mask > 0).sum().item():
fine_grained = True
elif type(m).__name__ == 'Linear':
channel_masks.append((mask.abs().sum(0) != 0).int())
elif type(m).__name__ == 'BatchNorm2d':
channel_masks.append(mask.int())
else:
raise RuntimeError(f'unsupported module type: {type(m).__name__}')
else:
# no mask means not pruned, equivalent to a full mask
channel_masks.append(None)
if fine_grained:
_logger.info('fine-grained mask detected, skip solving conflict for this set: %s', dset)
continue
if not all_pruned:
# if some layers are not pruned at all
# then all the layers in this dependency set
# cannot be pruned due to the shape dependency.
channel_remain.update(range(out_channels))
ori_channels = 0
if all(x is None for x in channel_masks):
continue
num_channels_list = [len(x) for x in channel_masks if x is not None]
# number of channels in same set should be identical
assert len(set(num_channels_list)) == 1
num_channels = num_channels_list[0]
for i, dim_mask in enumerate(channel_masks):
if dim_mask is None:
channel_masks[i] = torch.ones(num_channels).int()
# merge masks with 'or'
merged_channel_mask = channel_masks[0].clone()
for i in range(1, len(channel_masks)):
merged_channel_mask = ((merged_channel_mask + channel_masks[i]) != 0).int()
merged_index = torch.nonzero(merged_channel_mask, as_tuple=True)[0]
for name in dset:
if name not in self.masks:
# this layer is not pruned at all
# in this case, all_pruned is False
# and the other layers in the same dset
# will not be pruned either.
assert all(merged_channel_mask)
continue
mask = self.masks[name]
w_shape = mask['weight'].size()
ori_channels = w_shape[0]
for i in channel_remain:
mask['weight'][i] = torch.ones(w_shape[1:])
if 'bias' in mask and mask['bias'] is not None:
mask['bias'][i] = 1
_logger.info(','.join(dset))
_logger.info('Pruned Filters after fixing conflict:')
pruned_filters = set(list(range(ori_channels)))-channel_remain
_logger.info(str(sorted(pruned_filters)))
orig_mask = self.masks[name]['weight']
_, m = get_module_by_name(self.model, name)
new_mask = torch.zeros_like(orig_mask)
if type(m).__name__ == 'Conv2d':
if self.conv_prune_dim == 0:
new_mask[merged_index, :, :, :] = 1.
else:
new_mask[:, merged_index, :, :] = 1.
elif type(m).__name__ == 'Linear':
new_mask[:, merged_index] = 1.
elif type(m).__name__ == 'BatchNorm2d':
new_mask = merged_channel_mask.type_as(orig_mask)
else:
raise RuntimeError(f'unsupported module type: {type(m).__name__}')
self.masks[name]['weight'] = new_mask
if 'bias' in self.masks[name] and self.masks[name]['bias'] is not None:
if type(m).__name__ == 'Conv2d':
assert self.conv_prune_dim == 0
self.masks[name]['bias'] = merged_channel_mask.type_as(self.masks[name]['bias'])
return self.masks
def detect_mask_prune_dim(masks, model):
"""
Detect how the masks of convolutional layers are pruned.
Parameters
----------
masks: dict
A dict object that stores the masks.
model: nn.Module
Model object which the mask can be applied on.
Returns
-------
int
How the masks of convolutional layers are pruned. This depends on the
pruning algorithm: it returns 1 for masks generated by AMCPruner and
0 for masks generated by the other NNI built-in pruners.
0: filter pruning, which prunes filters of the weight so that channels
of the output feature maps are pruned.
1: channel pruning, which prunes the kernels corresponding to input
channels so that channels of the input feature maps are pruned.
"""
dim0_preserved, dim1_preserved = 0., 0.
dim0_num, dim1_num = 0., 0.
for module_name in masks:
_, m = get_module_by_name(model, module_name)
if m is None or type(m).__name__ != 'Conv2d':
continue
mask = masks[module_name]['weight'].clone()
assert (mask >= 0).sum() == mask.numel(), \
"mask values should be greater than or equal to 0."
mask = (mask > 0).int()
mask = mask.view(mask.shape[0], mask.shape[1], -1)
dim0_mask = (mask.sum((1, 2)) > 0).int()
dim1_mask = (mask.sum((0, 2)) > 0).int()
dim0_preserved += dim0_mask.sum().item()
dim1_preserved += dim1_mask.sum().item()
dim0_num += len(dim0_mask)
dim1_num += len(dim1_mask)
if dim0_num == 0 or dim1_num == 0:
_logger.warning('no multi-dimension masks found.')
return 0
dim0_sparsity, dim1_sparsity = 1. - dim0_preserved / dim0_num, 1. - dim1_preserved / dim1_num
_logger.info('dim0 sparsity: %f', dim0_sparsity)
_logger.info('dim1 sparsity: %f', dim1_sparsity)
if dim0_sparsity == dim1_sparsity == 0.:
_logger.warning('nothing masked.')
if dim0_sparsity > 0 and dim1_sparsity > 0:
_logger.warning('both dim0 and dim1 masks found.')
return 0 if dim0_sparsity >= dim1_sparsity else 1
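A quick check of the heuristic on a synthetic channel-pruned mask (dim 1 partially pruned, dim 0 intact):

```python
import torch

mask = torch.ones(8, 4, 3, 3)
mask[:, [1, 3]] = 0.                 # prune input channels 1 and 3
m = (mask > 0).int().view(8, 4, -1)
dim0_sparsity = 1. - (m.sum((1, 2)) > 0).int().sum().item() / 8
dim1_sparsity = 1. - (m.sum((0, 2)) > 0).int().sum().item() / 4
print(dim0_sparsity, dim1_sparsity)  # 0.0 0.5 -> detected prune dim: 1
```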
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
def get_module_by_name(model, module_name):
"""
Get a module specified by its module name
Parameters
----------
model : pytorch model
the pytorch model from which to get its module
module_name : str
the name of the required module
Returns
-------
module, module
the parent module of the required module, the required module
"""
name_list = module_name.split(".")
for name in name_list[:-1]:
if hasattr(model, name):
model = getattr(model, name)
else:
return None, None
if hasattr(model, name_list[-1]):
leaf_module = getattr(model, name_list[-1])
return model, leaf_module
else:
return None, None
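Usage sketch for the relocated helper, which now returns `(None, None)` on a miss instead of raising:

```python
from torchvision.models import resnet18

model = resnet18()
parent, leaf = get_module_by_name(model, 'layer1.0.conv1')
print(type(parent).__name__, type(leaf).__name__)  # BasicBlock Conv2d
print(get_module_by_name(model, 'no.such.module'))  # (None, None)
```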
......@@ -115,7 +115,7 @@ class AnalysisUtilsTest(TestCase):
pruner.export_model(ck_file, mask_file)
pruner._unwrap_model()
# Fix the mask conflict
fixed_mask = fix_mask_conflict(mask_file, net, dummy_input)
fixed_mask, _ = fix_mask_conflict(mask_file, net, dummy_input)
# use the channel dependency ground truth to check if
# the mask conflict is fixed successfully
......
......@@ -12,6 +12,8 @@ from torchvision.models.resnet import resnet18
from unittest import TestCase, main
from nni.compression.torch import L1FilterPruner, apply_compression_results, ModelSpeedup
from nni.compression.torch.pruning.weight_masker import WeightMasker
from nni.compression.torch.pruning.one_shot import _StructuredFilterPruner
torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
......@@ -104,6 +106,74 @@ def zero_bn_bias(model):
shape = module.running_mean.data.size()
module.running_mean = torch.zeros(shape).to(device)
class L1ChannelMasker(WeightMasker):
def __init__(self, model, pruner):
self.model = model
self.pruner = pruner
def calc_mask(self, sparsity, wrapper, wrapper_idx=None):
msg = 'module type {} is not supported!'.format(wrapper.type)
assert wrapper.type in ['Conv2d', 'Linear'], msg
weight = wrapper.module.weight.data
bias = None
if hasattr(wrapper.module, 'bias') and wrapper.module.bias is not None:
bias = wrapper.module.bias.data
if wrapper.weight_mask is None:
mask_weight = torch.ones(weight.size()).type_as(weight).detach()
else:
mask_weight = wrapper.weight_mask.clone()
if bias is not None:
if wrapper.bias_mask is None:
mask_bias = torch.ones(bias.size()).type_as(bias).detach()
else:
mask_bias = wrapper.bias_mask.clone()
else:
mask_bias = None
base_mask = {'weight_mask': mask_weight, 'bias_mask': mask_bias}
num_total = weight.size(1)
num_prune = int(num_total * sparsity)
if num_total < 2 or num_prune < 1:
return base_mask
w_abs = weight.abs()
if wrapper.type == 'Conv2d':
w_abs_structured = w_abs.sum((0, 2, 3))
threshold = torch.topk(w_abs_structured, num_prune, largest=False)[0].max()
mask_weight = torch.gt(w_abs_structured, threshold)[None, :, None, None].expand_as(weight).type_as(weight)
return {'weight_mask': mask_weight.detach()}
else:
# Linear
assert wrapper.type == 'Linear'
w_abs_structured = w_abs.sum((0))
threshold = torch.topk(w_abs_structured, num_prune, largest=False)[0].max()
mask_weight = torch.gt(w_abs_structured, threshold)[None, :].expand_as(weight).type_as(weight)
return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias}
class L1ChannelPruner(_StructuredFilterPruner):
def __init__(self, model, config_list, optimizer=None, dependency_aware=False, dummy_input=None):
super().__init__(model, config_list, pruning_algorithm='l1', optimizer=optimizer,
dependency_aware=dependency_aware, dummy_input=dummy_input)
def validate_config(self, model, config_list):
pass
def channel_prune(model):
config_list = [{
'sparsity': SPARSITY,
'op_types': ['Conv2d', 'Linear']
}, {
'op_names': ['conv1'],
'exclude': True
}]
pruner = L1ChannelPruner(model, config_list)
masker = L1ChannelMasker(model, pruner)
pruner.masker = masker
pruner.compress()
pruner.export_model(model_path=MODEL_FILE, mask_path=MASK_FILE)
class SpeedupTestCase(TestCase):
def test_speedup_vgg16(self):
prune_model_l1(vgg16())
......@@ -145,10 +215,20 @@ class SpeedupTestCase(TestCase):
assert model.backbone2.fc1.in_features == int(orig_model.backbone2.fc1.in_features * SPARSITY)
def test_speedup_integration(self):
for model_name in ['resnet18', 'squeezenet1_1', 'mobilenet_v2', 'densenet121', 'densenet169', 'inception_v3']:
for model_name in ['resnet18', 'squeezenet1_1', 'mobilenet_v2', 'densenet121', 'densenet169', 'inception_v3', 'resnet50']:
kwargs = {
'pretrained': True
}
if model_name == 'resnet50':
# testing multiple groups
kwargs = {
'pretrained': False,
'groups': 4
}
Model = getattr(models, model_name)
net = Model(pretrained=True, progress=False).to(device)
speedup_model = Model().to(device)
net = Model(**kwargs).to(device)
speedup_model = Model(**kwargs).to(device)
net.eval() # this line is necessary
speedup_model.eval()
# random generate the prune config for the pruner
......@@ -165,6 +245,9 @@ class SpeedupTestCase(TestCase):
data = torch.ones(BATCH_SIZE, 3, 224, 224).to(device)
ms = ModelSpeedup(speedup_model, data, MASK_FILE)
ms.speedup_model()
speedup_model.eval()
ori_out = net(data)
speeded_out = speedup_model(data)
ori_sum = torch.sum(ori_out).item()
......@@ -174,6 +257,35 @@ class SpeedupTestCase(TestCase):
assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
(abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
def test_channel_prune(self):
orig_net = resnet18(num_classes=10).to(device)
channel_prune(orig_net)
state_dict = torch.load(MODEL_FILE)
orig_net = resnet18(num_classes=10).to(device)
orig_net.load_state_dict(state_dict)
apply_compression_results(orig_net, MASK_FILE)
orig_net.eval()
net = resnet18(num_classes=10).to(device)
net.load_state_dict(state_dict)
net.eval()
data = torch.randn(BATCH_SIZE, 3, 224, 224).to(device)
ms = ModelSpeedup(net, data, MASK_FILE)
ms.speedup_model()
ms.bound_model(data)
net.eval()
ori_sum = orig_net(data).abs().sum().item()
speeded_sum = net(data).abs().sum().item()
print(ori_sum, speeded_sum)
assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
(abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
def tearDown(self):
os.remove(MODEL_FILE)
os.remove(MASK_FILE)
......