Unverified commit 3af2ad96, authored by Ningxin Zheng and committed by GitHub

Add support for convtranspose 2d (#3013)

This PR adds ConvTranspose2d support to the speedup module. Pruning of ConvTranspose2d may be supported in the future.
parent 622e3331
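For context, here is a minimal sketch of how the new ConvTranspose2d path is exercised end to end. This snippet is not part of the commit; it mirrors the TransposeModel test added below, the mask file is assumed to have been exported by a pruner beforehand, and the ModelSpeedup import path may differ across NNI versions.

import torch
import torch.nn as nn
from nni.compression.pytorch import ModelSpeedup  # import path varies across NNI versions

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 20, 5)
        self.deconv = nn.ConvTranspose2d(20, 50, 5)

    def forward(self, x):
        return self.deconv(self.conv(x))

model = Net()
dummy_input = torch.rand(1, 3, 8, 8)
# './l1_mask.pth' is assumed to be a mask checkpoint exported by a pruner such as L1FilterPruner
ms = ModelSpeedup(model, dummy_input, './l1_mask.pth')
ms.speedup_model()  # ConvTranspose2d modules are replaced via replace_convtranspose2d below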
......@@ -10,6 +10,7 @@ _logger = logging.getLogger(__name__)
replace_module = {
'BatchNorm2d': lambda module, mask: replace_batchnorm2d(module, mask),
'Conv2d': lambda module, mask: replace_conv2d(module, mask),
'ConvTranspose2d': lambda module, mask: replace_convtranspose2d(module, mask),
'MaxPool2d': lambda module, mask: no_replace(module, mask),
'AvgPool2d': lambda module, mask: no_replace(module, mask),
'AdaptiveAvgPool2d': lambda module, mask: no_replace(module, mask),
......@@ -22,6 +23,7 @@ replace_module = {
'Dropout3d': lambda module, mask: no_replace(module, mask)
}
def no_replace(module, mask):
"""
No need to replace
......@@ -29,6 +31,7 @@ def no_replace(module, mask):
_logger.debug("no need to replace")
return module
def replace_linear(linear, mask):
"""
Parameters
......@@ -54,11 +57,13 @@ def replace_linear(linear, mask):
out_features=linear.out_features,
bias=linear.bias is not None)
new_linear.to(linear.weight.device)
new_linear.weight.data = torch.index_select(
linear.weight.data, -1, index.to(linear.weight.device))
if linear.bias is not None:
new_linear.bias.data.copy_(linear.bias.data)
return new_linear
def replace_batchnorm2d(norm, mask):
"""
Parameters
......@@ -87,10 +92,13 @@ def replace_batchnorm2d(norm, mask):
new_norm.weight.data = torch.index_select(norm.weight.data, 0, index)
new_norm.bias.data = torch.index_select(norm.bias.data, 0, index)
if norm.track_running_stats:
new_norm.running_mean.data = torch.index_select(
norm.running_mean.data, 0, index)
new_norm.running_var.data = torch.index_select(
norm.running_var.data, 0, index)
return new_norm
def replace_conv2d(conv, mask):
"""
Parameters
......@@ -121,7 +129,8 @@ def replace_conv2d(conv, mask):
# remove groups for depthwise layers
assert in_channels == out_channels
groups = in_channels
_logger.debug("replace conv2d %s with in_channels: %d, out_channels: %d", mask.module_name, in_channels, out_channels)
_logger.debug("replace conv2d %s with in_channels: %d, out_channels: %d",
mask.module_name, in_channels, out_channels)
new_conv = torch.nn.Conv2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=conv.kernel_size,
......@@ -136,9 +145,11 @@ def replace_conv2d(conv, mask):
tmp_weight_data = tmp_bias_data = None
if mask.output_mask is not None:
tmp_weight_data = torch.index_select(
conv.weight.data, 0, out_channels_index)
if conv.bias is not None:
tmp_bias_data = torch.index_select(
conv.bias.data, 0, out_channels_index)
else:
tmp_weight_data = conv.weight.data
# For the convolutional layers that have more than one group
......@@ -152,24 +163,120 @@ def replace_conv2d(conv, mask):
for groupid in range(conv.groups):
start = groupid * input_step
end = (groupid + 1) * input_step
current_input_index = list(
filter(lambda x: start <= x and x < end, in_channels_index.tolist()))
if not current_input_index:
# there is no kept channel in current group
continue
# TODO: bug here, groups is taken directly from conv.groups; if a whole
# group is removed, the number of groups in new_conv also needs to change
raise Exception(
"Removing a whole group of filters is not supported yet, except in depth-wise conv")
# shift the global index into the group index
current_input_index = [x-start for x in current_input_index]
# if groups is larger than 1, the input channels of each
# group should be pruned evenly.
assert len(current_input_index) == in_channels_group, \
'Input channels of each group are not pruned evenly'
current_input_index = torch.tensor(current_input_index).to(tmp_weight_data.device) # pylint: disable=not-callable
f_start = groupid * filter_step
f_end = (groupid + 1) * filter_step
new_conv.weight.data[f_start:f_end] = torch.index_select(
tmp_weight_data[f_start:f_end], 1, current_input_index)
else:
new_conv.weight.data.copy_(tmp_weight_data)
if conv.bias is not None:
new_conv.bias.data.copy_(
conv.bias.data if tmp_bias_data is None else tmp_bias_data)
return new_conv
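For intuition on the group handling above, here is a standalone sketch (with made-up channel numbers) of how globally kept input-channel indices are split per group and shifted into local indices; each group must end up with the same number of kept channels:

# grouped conv with in_channels=8 and groups=2 -> input_step of 4 channels per group
in_channels_index = [0, 2, 4, 6]  # globally kept input channels
groups, input_step = 2, 4
for groupid in range(groups):
    start, end = groupid * input_step, (groupid + 1) * input_step
    local = [x - start for x in in_channels_index if start <= x < end]
    print(groupid, local)
# prints "0 [0, 2]" and "1 [0, 2]": both groups keep 2 channels, so the prune is even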
def replace_convtranspose2d(convtrans, mask):
"""
We need anothor replace function for
convtranspose2d, because the layout of
the weight is different from traditional
conv layers. The layout of the weight is [N_in, N_out, ksize_1, ksize_2]
Parameters
----------
convtrans : torch.nn.ConvTranspose2d
The conv2d module to be replaced
mask : ModuleMasks
The masks of this module
Returns
-------
torch.nn.ConvTranspose2d
The new conv2d module
"""
assert isinstance(mask, ModuleMasks)
assert isinstance(convtrans, torch.nn.ConvTranspose2d)
if mask.input_mask is None:
in_channels = convtrans.in_channels
else:
in_channels_index = mask.input_mask.mask_index[1]
in_channels = in_channels_index.size(0)
if mask.output_mask is None:
out_channels = convtrans.out_channels
else:
out_channels_index = mask.output_mask.mask_index[1]
out_channels = out_channels_index.size(0)
groups = convtrans.groups
# check whether we can remove whole groups of filters
if convtrans.in_channels == convtrans.out_channels == convtrans.groups:
# remove groups for depthwise layers
# this needs the group dependency to be fixed before the speedup
assert in_channels == out_channels
groups = in_channels
_logger.debug('Replace convtranspose2d %s with in_channels:%d out_channels:%d',
mask.module_name, in_channels, out_channels)
new_convtrans = torch.nn.ConvTranspose2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=convtrans.kernel_size,
stride=convtrans.stride,
padding=convtrans.padding,
dilation=convtrans.dilation,
groups=groups,
bias=convtrans.bias is not None,
padding_mode=convtrans.padding_mode)
new_convtrans.to(convtrans.weight.device)
tmp_weight_data = None
if mask.input_mask is not None:
# in ConvTranspose2d we need to select the input channels first
tmp_weight_data = torch.index_select(
convtrans.weight.data, 0, in_channels_index)
else:
tmp_weight_data = convtrans.weight.data
# we need to handle the output channels group by group, as in the conv layer
out_step = int(convtrans.out_channels / convtrans.groups)
out_channel_group = int(out_channels/groups)
new_in_per_group = int(in_channels/groups)
if mask.output_mask is not None and not(in_channels == out_channels == groups):
for groupid in range(convtrans.groups):
start = groupid * out_step
end = (groupid + 1) * out_step
current_output_index = list(
filter(lambda x: start <= x and x < end, out_channels_index.tolist()))
# shift the global index into the group-wise (local) index
current_output_index = [x-start for x in current_output_index]
if not current_output_index:
# No kept channel in the current group
raise Exception(
"Removing a whole group of filters is not supported yet, except in depth-wise conv")
assert len(current_output_index) == out_channel_group, \
'Output channel of each group should be the same after pruning'
current_output_index = torch.tensor(current_output_index).to(tmp_weight_data.device) # pylint: disable=not-callable
new_start = groupid * new_in_per_group
new_end = (groupid + 1) * new_in_per_group
new_convtrans.weight.data[new_start:new_end] = torch.index_select(
tmp_weight_data[new_start:new_end], 1, current_output_index)
else:
new_convtrans.weight.data.copy_(tmp_weight_data)
if convtrans.bias is not None:
if mask.output_mask is not None:
new_convtrans.bias.data[:] = torch.index_select(
convtrans.bias.data, 0, out_channels_index)
else:
new_convtrans.bias.data.copy_(convtrans.bias.data)
return new_convtrans
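As the docstring above notes, ConvTranspose2d stores its weight as [C_in, C_out/groups, k1, k2] while Conv2d uses [C_out, C_in/groups, k1, k2], which is why the index_select dimensions here are swapped relative to replace_conv2d. A quick illustrative check in plain PyTorch (not part of the commit):

import torch

conv = torch.nn.Conv2d(8, 16, kernel_size=3)
deconv = torch.nn.ConvTranspose2d(8, 16, kernel_size=3)
print(conv.weight.shape)    # torch.Size([16, 8, 3, 3])  -> [C_out, C_in, k, k]
print(deconv.weight.shape)  # torch.Size([8, 16, 3, 3])  -> [C_in, C_out, k, k]

# pruning a ConvTranspose2d therefore selects kept input channels on dim 0
# and kept output channels on dim 1, the reverse of Conv2d
kept_in = torch.tensor([0, 2, 4, 6])   # keep 4 of 8 input channels
kept_out = torch.arange(8)             # keep 8 of 16 output channels
w = torch.index_select(deconv.weight.data, 0, kept_in)
w = torch.index_select(w, 1, kept_out)
print(w.shape)  # torch.Size([4, 8, 3, 3])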
......@@ -13,6 +13,7 @@ _logger = logging.getLogger(__name__)
conv_prune_dim = -1
def set_conv_prune_dim(dim):
"""
Parameters:
......@@ -23,6 +24,7 @@ def set_conv_prune_dim(dim):
global conv_prune_dim
conv_prune_dim = dim
class CoarseMask:
"""
Coarse grained mask for a given tensor, here tensor could be weights,
......@@ -228,6 +230,7 @@ Infer input and output shape of a module/function from its weight mask
infer_from_mask = {
'BatchNorm2d': lambda module_masks, mask: batchnorm2d_mask(module_masks, mask),
'Conv2d': lambda module_masks, mask: conv2d_mask(module_masks, mask),
'ConvTranspose2d': lambda module_masks, mask: convtranspose2d_mask(module_masks, mask),
'Linear': lambda module_masks, mask, shape: linear_mask(module_masks, mask, shape)
}
......@@ -246,6 +249,7 @@ infer_from_inshape = {
'aten::relu_': lambda module_masks, mask: relu_inshape(module_masks, mask),
'aten::sigmoid': lambda module_masks, mask: relu_inshape(module_masks, mask),
'Conv2d': lambda module_masks, mask: conv2d_inshape(module_masks, mask),
'ConvTranspose2d': lambda module_masks, mask: convtranspose2d_inshape(module_masks, mask),
'MaxPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
'aten::max_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
'aten::avg_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
......@@ -277,6 +281,7 @@ Infer input and weight shape of a module/function from its output shape
"""
infer_from_outshape = {
'Conv2d': lambda module_masks, mask: conv2d_outshape(module_masks, mask),
'ConvTranspose2d': lambda module_masks, mask: convtranspose2d_outshape(module_masks, mask),
'BatchNorm2d': lambda module_masks, mask: batchnorm2d_outshape(module_masks, mask),
'MaxPool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask),
......@@ -306,6 +311,7 @@ infer_from_outshape = {
'aten::dropout': lambda module_masks, mask: dropout_outshape(module_masks, mask)
}
def dropout_inshape(module_masks, mask):
if module_masks.input_mask is None:
module_masks.set_input_mask(mask)
......@@ -325,6 +331,7 @@ def dropout_inshape(module_masks, mask):
module_masks.set_output_mask(mask)
return module_masks.output_mask
def dropout_outshape(module_masks, mask):
if module_masks.output_mask is None:
module_masks.set_output_mask(mask)
......@@ -335,6 +342,7 @@ def dropout_outshape(module_masks, mask):
return module_masks.output_mask
def cat_inshape(module_masks, mask, cat_info, last_visited):
"""
Infer the output mask of the cat operation from the
......@@ -433,6 +441,7 @@ def add_inshape(module_masks, mask):
raise Exception('Mask conflict happens!')
return None
def add_outshape(module_masks, mask):
"""
Infer the input mask of the add operation from the
......@@ -445,9 +454,11 @@ def add_outshape(module_masks, mask):
module_masks.set_input_mask(mask)
return mask
else:
assert all(
module_masks.output_mask.mask_index[1] == mask.mask_index[1])
return mask
def batchnorm2d_inshape(module_masks, mask):
"""
We assume only the second dimension has a coarse-grained mask
......@@ -477,6 +488,7 @@ def batchnorm2d_inshape(module_masks, mask):
module_masks.set_param_masks('bias', weight_cmask)
return mask
def batchnorm2d_outshape(module_masks, mask):
"""
We assume only the second dimension has a coarse-grained mask
......@@ -577,6 +589,7 @@ def view_inshape(module_masks, mask, shape):
module_masks.set_output_mask(output_cmask)
return output_cmask
def view_outshape(module_masks, mask, shape):
"""
Parameters
......@@ -614,12 +627,14 @@ def view_outshape(module_masks, mask, shape):
return input_cmask
def size_inshape(module_masks, mask):
"""
No need to do anything for this ```size``` op
"""
return None
def mean_inshape(module_masks, mask, shape):
"""
Similar to view operation, currently mask inference only supports
......@@ -642,6 +657,7 @@ def mean_inshape(module_masks, mask, shape):
module_masks.set_output_mask(output_cmask)
return output_cmask
def mean_outshape(module_masks, mask, shape):
"""
Similar to view operation, currently mask inference only supports
......@@ -662,6 +678,7 @@ def mean_outshape(module_masks, mask, shape):
module_masks.set_input_mask(input_cmask)
return input_cmask
def maxpool2d_inshape(module_masks, mask):
"""
Assume only the second dimension is masked
......@@ -690,6 +707,7 @@ def maxpool2d_inshape(module_masks, mask):
module_masks.set_output_mask(mask)
return mask
def maxpool2d_outshape(module_masks, mask):
"""
Assume only the second dimension is masked
......@@ -714,6 +732,7 @@ def maxpool2d_outshape(module_masks, mask):
module_masks.set_output_mask(mask)
return mask
def relu_inshape(module_masks, mask):
"""
Parameters
......@@ -737,6 +756,7 @@ def relu_inshape(module_masks, mask):
module_masks.set_output_mask(mask)
return mask
def relu_outshape(module_masks, mask):
"""
Parameters
......@@ -754,11 +774,13 @@ def relu_outshape(module_masks, mask):
assert isinstance(mask, CoarseMask)
if module_masks.output_mask is not None:
# mask conflict should be solved before speedup
assert all(
module_masks.output_mask.mask_index[1] == mask.mask_index[1])
module_masks.set_input_mask(mask)
module_masks.set_output_mask(mask)
return mask
def batchnorm2d_mask(module_masks, mask):
"""
Infer input and output shape from weight mask
......@@ -792,6 +814,7 @@ def batchnorm2d_mask(module_masks, mask):
module_masks.set_output_mask(output_cmask)
return input_cmask, output_cmask
def linear_mask(module_masks, mask, shape):
"""
Infer input and output shape from weight mask with limitations:
......@@ -825,6 +848,7 @@ def linear_mask(module_masks, mask, shape):
module_masks.set_input_mask(input_cmask)
return input_cmask, None
def conv2d_mask(module_masks, mask):
"""
Infer input and output shape from weight mask
......@@ -863,8 +887,9 @@ def conv2d_mask(module_masks, mask):
weight_mask = mask['weight']
sum_idx = (1, 2, 3) if dim == 0 else (0, 2, 3)
index = torch.nonzero(weight_mask.abs().sum(
sum_idx) != 0, as_tuple=True)[0]
if len(index) == weight_mask.shape[dim]: # full mask
index = None
if index is None:
......@@ -882,7 +907,8 @@ def conv2d_mask(module_masks, mask):
bias_cmask.add_index_mask(dim=0, index=bias_index)
return index, weight_cmask, bias_cmask
index, weight_cmask, bias_cmask = convert_to_coarse_mask(
mask, dim=conv_prune_dim)
if index is None:
# TODO: fine grained mask speedup
......@@ -910,7 +936,8 @@ def conv2d_mask(module_masks, mask):
module_masks.set_input_mask(io_cmask)
else:
assert module_masks.input_mask == io_cmask
return module_masks.input_mask, None
def conv2d_inshape(module_masks, mask):
"""
......@@ -972,7 +999,8 @@ def conv2d_outshape(module_masks, mask):
# mask conflict should be solved by fix_mask_conflict before speedup
# mask and module_masks.output_mask may have different number of dimensions
# since they could be passed by linear or conv2d
assert all(
module_masks.output_mask.mask_index[1] == mask.mask_index[1])
weight_cmask = CoarseMask(num_dim=4)
weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
......@@ -988,3 +1016,74 @@ def conv2d_outshape(module_masks, mask):
module_masks.input_mask = mask
return mask
return None
def convtranspose2d_mask(module_masks, mask):
# TODO: support ConvTranspose2d pruning for the L1FilterPruner
raise Exception(
"The current filter pruner cannot prune ConvTranspose2d; pruning ConvTranspose2d will be supported later")
def convtranspose2d_inshape(module_masks, mask):
"""
A shape change of the input tensor does not affect the shape of its output tensor
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the ConvTranspose2d
mask : CoarseMask
The mask of its input tensor
Returns
-------
CoarseMask
The mask of its output tensor
"""
assert isinstance(mask, CoarseMask)
if module_masks.input_mask is None:
module_masks.set_input_mask(mask)
else:
# the same conv layer may be accessed more
# than once, e.g., by a concat operation.
# mask conflict should be solved by fix_mask_conflict before speedup
assert module_masks.input_mask == mask
# shape changes pass through depth-wise conv layers
m = module_masks.module
if m.in_channels == m.out_channels == m.groups:
module_masks.output_mask = mask
module_masks.input_mask = mask
return mask
return None
def convtranspose2d_outshape(module_masks, mask):
assert isinstance(mask, CoarseMask)
assert mask.mask_index[1] is not None
assert mask.mask_index[0] is None
assert mask.mask_index[2] is None
assert mask.mask_index[3] is None
if module_masks.output_mask is None:
module_masks.output_mask = mask
else:
# mask conflict should be solved by fix_mask_conflict before speedup
# mask and module_masks.output_mask may have different number of dimensions
# since they could be passed by linear or conv2d
assert all(
module_masks.output_mask.mask_index[1] == mask.mask_index[1])
weight_cmask = CoarseMask(num_dim=4)
# Note: the memory layout of ConvTranspose2d is [C_in, C_out, k1, k2]
weight_cmask.add_index_mask(dim=1, index=mask.mask_index[1])
bias_cmask = CoarseMask(num_dim=1)
bias_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
module_masks.set_param_masks('weight', weight_cmask)
module_masks.set_param_masks('bias', bias_cmask)
# shape changes pass through depth-wise conv layers
m = module_masks.module
if m.in_channels == m.out_channels == m.groups:
module_masks.output_mask = mask
module_masks.input_mask = mask
return mask
return None
......@@ -9,6 +9,7 @@ from .utils import get_module_by_name
# logging.basicConfig(level = logging.DEBUG)
_logger = logging.getLogger(__name__)
def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
"""
MaskConflict fixes the mask conflicts caused by channel dependencies
......@@ -50,6 +51,7 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
masks = padding_cat_mask.fix_mask()
return masks, fix_channel_mask.conv_prune_dim
class MaskFix:
def __init__(self, masks, model=None, dummy_input=None, traced=None):
# check if the parameters are valid
......@@ -74,6 +76,7 @@ class MaskFix:
"""
torch.save(self.masks, path)
class CatMaskPadding(MaskFix):
def __init__(self, masks, model, dummy_input=None, traced=None):
"""
......@@ -100,7 +103,8 @@ class CatMaskPadding(MaskFix):
super(CatMaskPadding, self).__init__(masks, model, dummy_input, traced)
def fix_mask(self):
cat_padding_depen = CatPaddingDependency(
self.model, self.dummy_input, self.traced)
name_to_module = {}
for name, module in self.model.named_modules():
name_to_module[name] = module
......@@ -131,11 +135,10 @@ class CatMaskPadding(MaskFix):
# module.bias may be None
b_shape = module.bias.data.size()
b_mask = torch.ones(b_shape).to(device)
self.masks[layer] = {'weight': w_mask, 'bias': b_mask}
return self.masks
class GroupMaskConflict(MaskFix):
def __init__(self, masks, model=None, dummy_input=None, traced=None):
"""
......@@ -154,8 +157,8 @@ class GroupMaskConflict(MaskFix):
the traced model of the target model; if this parameter is not None,
we do not use the model and dummy_input to get the trace graph.
"""
super(GroupMaskConflict, self).__init__(
masks, model, dummy_input, traced)
def fix_mask(self):
"""
......@@ -163,7 +166,8 @@ class GroupMaskConflict(MaskFix):
has group dependencies. This function should be called before the
mask inference of the 'speedup' module.
"""
group_depen = GroupDependency(
self.model, self.dummy_input, self.traced)
depens = group_depen.dependency
_logger.info(depens)
for layername in depens:
......@@ -174,8 +178,10 @@ class GroupMaskConflict(MaskFix):
w_mask = self.masks[layername]['weight']
shape = w_mask.size()
count = np.prod(shape[1:])
all_ones = (w_mask.flatten(1).sum(-1) ==
count).nonzero().squeeze(1).tolist()
all_zeros = (w_mask.flatten(1).sum(-1) ==
0).nonzero().squeeze(1).tolist()
if len(all_ones) + len(all_zeros) < w_mask.size(0):
# In fine-grained pruning, skip this layer
_logger.info('Layers %s using fine-grained pruning', layername)
......@@ -190,7 +196,8 @@ class GroupMaskConflict(MaskFix):
for i in range(group):
_start = step * i
_end = step * (i+1)
_tmp_list = list(
filter(lambda x: _start <= x and x < _end, all_zeros))
group_masked.append(_tmp_list)
mini_masked = min([len(x) for x in group_masked])
for gm in group_masked:
......@@ -198,13 +205,13 @@ class GroupMaskConflict(MaskFix):
# To keep the number of output channels divisible by the number of
# groups, we set the masks of the extra filters back to ones.
pos = gm[i]
self.masks[layername]['weight'][pos] = torch.ones(
shape[1:])
if 'bias' in self.masks[layername] and self.masks[layername]['bias'] is not None:
self.masks[layername]['bias'][pos] = 1
return self.masks
class ChannelMaskConflict(MaskFix):
def __init__(self, masks, model=None, dummy_input=None, traced=None):
"""
......@@ -223,7 +230,8 @@ class ChannelMaskConflict(MaskFix):
the traced graph of the target model; if this parameter is not None,
we do not use the model and dummy_input to get the trace graph.
"""
super(ChannelMaskConflict, self).__init__(
masks, model, dummy_input, traced)
self.conv_prune_dim = detect_mask_prune_dim(masks, model)
_logger.info('detected conv prune dim: %s', self.conv_prune_dim)
......@@ -235,9 +243,11 @@ class ChannelMaskConflict(MaskFix):
are supported.
"""
if self.conv_prune_dim == 0:
channel_depen = ChannelDependency(
self.model, self.dummy_input, self.traced)
else:
channel_depen = InputChannelDependency(
self.model, self.dummy_input, self.traced)
depen_sets = channel_depen.dependency_sets
sum_idx = (1, 2, 3) if self.conv_prune_dim == 0 else (0, 2, 3)
for dset in depen_sets:
......@@ -262,17 +272,29 @@ class ChannelMaskConflict(MaskFix):
channel_masks.append((mask.abs().sum(0) != 0).int())
elif type(m).__name__ == 'BatchNorm2d':
channel_masks.append(mask.int())
elif type(m).__name__ == 'ConvTranspose2d':
# ConvTranspose2d has a different memory layout, so we need to create
# a tmp_sum_idx for it
tmp_sum_idx = (
0, 2, 3) if self.conv_prune_dim == 0 else (1, 2, 3)
channel_mask = (mask.abs().sum(tmp_sum_idx) != 0).int()
channel_masks.append(channel_mask)
if (channel_mask.sum() * (mask.numel() / mask.shape[1-self.conv_prune_dim])).item() != (mask > 0).sum().item():
fine_grained = True
else:
raise RuntimeError(
f'unsupported module type: {type(m).__name__}')
else:
# no mask means not pruned, equivalent to full masks
channel_masks.append(None)
if fine_grained:
_logger.info(
'fine-grained mask detected, skip solving conflict for this set: %s', dset)
continue
if all(x is None for x in channel_masks):
continue
num_channels_list = [len(x)
for x in channel_masks if x is not None]
# number of channels in same set should be identical
assert len(set(num_channels_list)) == 1
num_channels = num_channels_list[0]
......@@ -284,7 +306,8 @@ class ChannelMaskConflict(MaskFix):
# merge masks with 'or'
merged_channel_mask = channel_masks[0].clone()
for i in range(1, len(channel_masks)):
merged_channel_mask = (
(merged_channel_mask + channel_masks[i]) != 0).int()
merged_index = torch.nonzero(merged_channel_mask, as_tuple=True)[0]
......@@ -305,16 +328,19 @@ class ChannelMaskConflict(MaskFix):
elif type(m).__name__ == 'BatchNorm2d':
new_mask = merged_index.type_as(orig_mask)
else:
raise RuntimeError(
f'unsupported module type: {type(m).__name__}')
self.masks[name]['weight'] = new_mask
if 'bias' in self.masks[name] and self.masks[name]['bias'] is not None:
if type(m).__name__ == 'Conv2d':
assert self.conv_prune_dim == 0
self.masks[name]['bias'] = merged_channel_mask.type_as(
self.masks[name]['bias'])
return self.masks
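To make the tmp_sum_idx handling above concrete, here is a standalone sketch of collapsing a 4-D weight mask into a per-channel 0/1 vector for both layouts (the helper name is made up for illustration; it is not NNI API):

import torch

def channel_mask_from_weight_mask(mask, layer_type, conv_prune_dim=0):
    if layer_type == 'Conv2d':
        # Conv2d weight layout: [C_out, C_in, k, k]
        sum_idx = (1, 2, 3) if conv_prune_dim == 0 else (0, 2, 3)
    else:
        # ConvTranspose2d weight layout: [C_in, C_out, k, k], so the dims swap
        sum_idx = (0, 2, 3) if conv_prune_dim == 0 else (1, 2, 3)
    return (mask.abs().sum(sum_idx) != 0).int()

w_mask = torch.ones(8, 16, 3, 3)  # mask of a ConvTranspose2d weight
w_mask[:, :4] = 0                 # first 4 output channels fully pruned
print(channel_mask_from_weight_mask(w_mask, 'ConvTranspose2d'))
# tensor([0, 0, 0, 0, 1, 1, ...], dtype=torch.int32) -- 4 zeros followed by 12 ones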
def detect_mask_prune_dim(masks, model):
"""
Detect how the masks of convolutional layers are pruned.
......@@ -358,7 +384,8 @@ def detect_mask_prune_dim(masks, model):
_logger.warning('no multi-dimension masks found.')
return 0
dim0_sparsity, dim1_sparsity = 1. - dim0_preserved / \
dim0_num, 1. - dim1_preserved / dim1_num
_logger.info('dim0 sparsity: %f', dim0_sparsity)
_logger.info('dim1 sparsity: %f', dim1_sparsity)
......
......@@ -4,13 +4,16 @@
import csv
import logging
__all__ = ['ChannelDependency', 'GroupDependency',
'CatPaddingDependency', 'InputChannelDependency']
CONV_TYPE = 'aten::_convolution'
ADD_TYPES = ['aten::add', 'aten::add_']
CAT_TYPE = 'aten::cat'
logger = logging.getLogger('Shape_Dependency')
RESHAPE_OPS = [CAT_TYPE, 'aten::view',
'aten::reshape', 'aten::flatten', 'aten::mean']
class Dependency:
def __init__(self, model=None, dummy_input=None, traced_model=None):
......@@ -34,6 +37,7 @@ class Dependency:
def export(self, filepath):
raise NotImplementedError
class ChannelDependency(Dependency):
def __init__(self, model=None, dummy_input=None, traced_model=None):
"""
......@@ -50,7 +54,8 @@ class ChannelDependency(Dependency):
if we already have the traced graph of the target model, we do not
need to trace the model again.
"""
super(ChannelDependency, self).__init__(
model, dummy_input, traced_model)
def _get_parent_layers(self, node):
"""
......@@ -71,7 +76,7 @@ class ChannelDependency(Dependency):
queue.append(node)
while queue:
curnode = queue.pop(0)
if curnode.op_type == 'Conv2d' or curnode.op_type == 'Linear' or curnode.op_type == 'ConvTranspose2d':
# stop at the first conv/linear encountered
parent_layers.append(curnode.name)
continue
......@@ -119,7 +124,6 @@ class ChannelDependency(Dependency):
for _node in dependency_set:
self.dependency[_node] = dependency_set
def export(self, filepath):
"""
export the channel dependencies as a csv file.
......@@ -185,6 +189,7 @@ class ChannelDependency(Dependency):
d_sets.append(tmp_set)
return d_sets
def reshape_break_channel_dependency(op_node):
"""
The reshape operations such as (reshape, view, flatten) may break
......@@ -213,6 +218,7 @@ def reshape_break_channel_dependency(op_node):
out_channel = out_shape[1]
return in_channel != out_channel
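A minimal illustration of this check in plain PyTorch (toy tensors, not the NNI graph utilities): the dependency is considered broken whenever the channel dimension changes across the reshape:

import torch

x = torch.randn(2, 8, 4, 4)   # NCHW input with 8 channels
y = x.view(x.size(0), -1)     # flatten to shape [2, 128]
in_channel, out_channel = x.shape[1], y.shape[1]
print(in_channel != out_channel)  # True -> this reshape breaks the channel dependency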
class InputChannelDependency(ChannelDependency):
"""
Some pruners may prune the input channel of the convolutional
......@@ -242,7 +248,8 @@ class InputChannelDependency(ChannelDependency):
if we already have the traced graph of the target model, we do not
need to trace the model again.
"""
super(InputChannelDependency, self).__init__(
model, dummy_input, traced_model)
def _get_following_convs(self, tensor):
queue = []
......@@ -250,14 +257,14 @@ class InputChannelDependency(ChannelDependency):
queue.extend(self.graph.input_to_node[tensor])
while queue:
curnode = queue.pop(0)
if curnode.op_type == 'Conv2d' or curnode.op_type == 'Linear' or curnode.op_type == 'ConvTranspose2d':
# stop at the first conv/linear encountered
key_layers.append(curnode.name)
continue
elif curnode.op_type in RESHAPE_OPS:
# check if the reshape operation will break the channel dependency
if reshape_break_channel_dependency(curnode):
# reshape operations also break the dependency relationship
continue
successors = self.graph.find_successors(curnode.unique_name)
successors = [self.graph.name_to_node[name] for name in successors]
......@@ -290,7 +297,8 @@ class InputChannelDependency(ChannelDependency):
class CatPaddingDependency(ChannelDependency):
def __init__(self, model=None, dummy_input=None, traced_model=None):
super(CatPaddingDependency, self).__init__(
model, dummy_input, traced_model)
def build_dependency(self):
"""
......@@ -347,6 +355,7 @@ class CatPaddingDependency(ChannelDependency):
row.extend(list(layers))
csv_w.writerow(row)
class GroupDependency(Dependency):
def __init__(self, model=None, dummy_input=None, traced_model=None):
"""
......@@ -388,7 +397,7 @@ class GroupDependency(Dependency):
queue = predeessors
while queue:
curnode = queue.pop(0)
if curnode.op_type == 'Conv2d' or curnode.op_type == 'ConvTranspose2d':
# stop at the first conv encountered
parent_layers.append(curnode.name)
continue
......@@ -412,7 +421,8 @@ class GroupDependency(Dependency):
group : int
the number of the groups of the target conv layer.
"""
cpp_conv = list(filter(lambda x: x.kind() ==
CONV_TYPE, node_group.node_cpps))
assert len(cpp_conv) == 1
cpp_conv = cpp_conv[0]
inputs = list(cpp_conv.inputs())
......@@ -442,12 +452,14 @@ class GroupDependency(Dependency):
filters should be divisible to.
"""
for node in self.graph.nodes_py.nodes_op:
if node.op_type == 'Conv2d' or node.op_type == 'ConvTranspose2d':
group = self._get_conv_groups(node)
if node.name in self.dependency:
# a conv layer whose group count is larger than 1 requires its
# number of output channels to be divisible by the number of groups.
self.dependency[node.name] = max(
self.dependency[node.name], group)
else:
self.dependency[node.name] = group
if group > 1:
......@@ -456,7 +468,8 @@ class GroupDependency(Dependency):
parent_convs = self._get_parent_convs(node)
for parent in parent_convs:
if parent in self.dependency:
self.dependency[parent] = max(
self.dependency[parent], group)
else:
self.dependency[parent] = group
return self.dependency
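The divisibility constraint recorded above comes straight from how grouped convolutions are defined; PyTorch itself rejects channel counts that do not divide evenly (a quick check with toy numbers, unrelated to the NNI code):

import torch

ok = torch.nn.Conv2d(in_channels=8, out_channels=6, kernel_size=3, groups=2)  # fine
try:
    bad = torch.nn.Conv2d(in_channels=8, out_channels=7, kernel_size=3, groups=2)
except ValueError as err:
    print(err)  # out_channels must be divisible by groups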
......@@ -484,6 +497,7 @@ class GroupDependency(Dependency):
for name in self.dependency:
group = self.dependency[name]
csv_w.writerow([name, group])
@property
def dependency_sets(self):
return self.dependency
......@@ -30,13 +30,17 @@ RELATIVE_THRESHOLD = 0.01
# an absolute threshold to determine whether the final result is correct.
# The error should meet the RELATIVE_THRESHOLD or the ABSOLUTE_THRESHOLD.
ABSOLUTE_THRESHOLD = 0.0001
class BackboneModel1(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 1, 1, 1)
def forward(self, x):
return self.conv1(x)
class BackboneModel2(torch.nn.Module):
def __init__(self):
super().__init__()
......@@ -53,32 +57,58 @@ class BackboneModel2(torch.nn.Module):
x = F.relu(self.bn2(self.conv2(x)))
x = F.max_pool2d(x, 2, 2)
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
class BigModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.backbone1 = BackboneModel1()
self.backbone2 = BackboneModel2()
self.fc3 = nn.Sequential(
nn.Linear(10, 10),
nn.BatchNorm1d(10),
nn.ReLU(inplace=True),
nn.Linear(10, 2)
)
def forward(self, x):
x = self.backbone1(x)
x = self.backbone2(x)
x = self.fc3(x)
return x
class TransposeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 20, 5)
self.conv2 = nn.ConvTranspose2d(20, 50, 5, groups=2)
self.bn1 = nn.BatchNorm2d(self.conv1.out_channels)
self.bn2 = nn.BatchNorm2d(self.conv2.out_channels)
self.fc1 = nn.Linear(8 * 8 * 50, 500)
self.fc2 = nn.Linear(500, 10)
def forward(self, x):
x = F.relu(self.bn1(self.conv1(x)))
# x = F.max_pool2d(x, 2, 2)
x = F.relu(self.bn2(self.conv2(x)))
# x = F.max_pool2d(x, 2, 2)
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
dummy_input = torch.randn(2, 1, 28, 28)
SPARSITY = 0.5
MODEL_FILE, MASK_FILE = './11_model.pth', './l1_mask.pth'
def prune_model_l1(model):
config_list = [{
'sparsity': SPARSITY,
......@@ -88,6 +118,7 @@ def prune_model_l1(model):
pruner.compress()
pruner.export_model(model_path=MODEL_FILE, mask_path=MASK_FILE)
def generate_random_sparsity(model):
cfg_list = []
for name, module in model.named_modules():
......@@ -97,18 +128,20 @@ def generate_random_sparsity(model):
'sparsity': sparsity})
return cfg_list
def zero_bn_bias(model):
with torch.no_grad():
for name, module in model.named_modules():
if isinstance(module, nn.BatchNorm2d) \
or isinstance(module, nn.BatchNorm3d) \
or isinstance(module, nn.BatchNorm1d):
shape = module.bias.data.size()
device = module.bias.device
module.bias.data = torch.zeros(shape).to(device)
shape = module.running_mean.data.size()
module.running_mean = torch.zeros(shape).to(device)
class L1ChannelMasker(WeightMasker):
def __init__(self, model, pruner):
self.model = model
......@@ -143,21 +176,27 @@ class L1ChannelMasker(WeightMasker):
w_abs = weight.abs()
if wrapper.type == 'Conv2d':
w_abs_structured = w_abs.sum((0, 2, 3))
threshold = torch.topk(
w_abs_structured, num_prune, largest=False)[0].max()
mask_weight = torch.gt(w_abs_structured, threshold)[
None, :, None, None].expand_as(weight).type_as(weight)
return {'weight_mask': mask_weight.detach()}
else:
# Linear
assert wrapper.type == 'Linear'
w_abs_structured = w_abs.sum((0))
threshold = torch.topk(
w_abs_structured, num_prune, largest=False)[0].max()
mask_weight = torch.gt(w_abs_structured, threshold)[
None, :].expand_as(weight).type_as(weight)
return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias}
class L1ChannelPruner(_StructuredFilterPruner):
def __init__(self, model, config_list, optimizer=None, dependency_aware=False, dummy_input=None):
super().__init__(model, config_list, pruning_algorithm='l1', optimizer=optimizer,
dependency_aware=dependency_aware, dummy_input=dummy_input)
def validate_config(self, model, config_list):
pass
......@@ -177,6 +216,7 @@ def channel_prune(model):
pruner.compress()
pruner.export_model(model_path=MODEL_FILE, mask_path=MASK_FILE)
class SpeedupTestCase(TestCase):
def test_speedup_vgg16(self):
prune_model_l1(vgg16())
......@@ -187,8 +227,10 @@ class SpeedupTestCase(TestCase):
orig_model = vgg16()
assert model.training
assert model.features[2].out_channels == int(
orig_model.features[2].out_channels * SPARSITY)
assert model.classifier[0].in_features == int(
orig_model.classifier[0].in_features * SPARSITY)
def test_speedup_bigmodel(self):
prune_model_l1(BigModel())
......@@ -205,23 +247,55 @@ class SpeedupTestCase(TestCase):
model.eval()
speedup_out = model(dummy_input)
if not torch.allclose(mask_out, speedup_out, atol=1e-07):
print('input:', dummy_input.size(),
torch.abs(dummy_input).sum((2, 3)))
print('mask_out:', mask_out)
print('speedup_out:', speedup_out)
raise RuntimeError('model speedup inference result is incorrect!')
orig_model = BigModel()
assert model.backbone2.conv1.out_channels == int(
orig_model.backbone2.conv1.out_channels * SPARSITY)
assert model.backbone2.conv2.in_channels == int(
orig_model.backbone2.conv2.in_channels * SPARSITY)
assert model.backbone2.conv2.out_channels == int(
orig_model.backbone2.conv2.out_channels * SPARSITY)
assert model.backbone2.fc1.in_features == int(
orig_model.backbone2.fc1.in_features * SPARSITY)
def test_convtranspose_model(self):
ori_model = TransposeModel()
dummy_input = torch.rand(1, 3, 8, 8)
config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
pruner = L1FilterPruner(ori_model, config_list)
pruner.compress()
ori_model(dummy_input)
pruner.export_model(MODEL_FILE, MASK_FILE)
pruner._unwrap_model()
new_model = TransposeModel()
state_dict = torch.load(MODEL_FILE)
new_model.load_state_dict(state_dict)
ms = ModelSpeedup(new_model, dummy_input, MASK_FILE)
ms.speedup_model()
zero_bn_bias(ori_model)
zero_bn_bias(new_model)
ori_out = ori_model(dummy_input)
new_out = new_model(dummy_input)
ori_sum = torch.sum(ori_out)
speeded_sum = torch.sum(new_out)
print('Transpose Speedup Test: ori_sum={} speedup_sum={}'.format(ori_sum, speeded_sum))
assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
(abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
# FIXME: This test case might fail randomly, no idea why
# Example: https://msrasrg.visualstudio.com/NNIOpenSource/_build/results?buildId=16282
def test_speedup_integration(self):
for model_name in ['resnet18', 'squeezenet1_1',
'mobilenet_v2', 'densenet121',
# 'inception_v3' is too large and may fail the pipeline
'densenet169', 'resnet50']:
kwargs = {
'pretrained': True
}
......@@ -235,7 +309,7 @@ class SpeedupTestCase(TestCase):
Model = getattr(models, model_name)
net = Model(**kwargs).to(device)
speedup_model = Model(**kwargs).to(device)
net.eval() # this line is necessary
speedup_model.eval()
# random generate the prune config for the pruner
cfgs = generate_random_sparsity(net)
......@@ -258,8 +332,10 @@ class SpeedupTestCase(TestCase):
speeded_out = speedup_model(data)
ori_sum = torch.sum(ori_out).item()
speeded_sum = torch.sum(speeded_out).item()
print('Sum of the output of %s (before speedup):' %
model_name, ori_sum)
print('Sum of the output of %s (after speedup):' %
model_name, speeded_sum)
assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
(abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
......@@ -296,5 +372,6 @@ class SpeedupTestCase(TestCase):
os.remove(MODEL_FILE)
os.remove(MASK_FILE)
if __name__ == '__main__':
main()