Unverified Commit 3af2ad96 authored by Ningxin Zheng, committed by GitHub

Add support for convtranspose 2d (#3013)

This PR adds ConvTranspose2d support to the speedup module. Pruning of ConvTranspose2d itself may be supported in the future.
parent 622e3331
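For context, a minimal usage sketch of the feature this PR enables, modeled on the test case added at the bottom of this diff; the model, file names, and import path are illustrative, not part of the diff:

import torch
import torch.nn as nn
from nni.compression.torch import L1FilterPruner, ModelSpeedup

# Hypothetical model containing a ConvTranspose2d layer.
model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ConvTranspose2d(16, 8, 3))
dummy_input = torch.rand(1, 3, 16, 16)

# Prune only the Conv2d layers; pruning ConvTranspose2d itself is not supported yet.
pruner = L1FilterPruner(model, [{'sparsity': 0.5, 'op_types': ['Conv2d']}])
pruner.compress()
pruner.export_model('model.pth', 'mask.pth')
pruner._unwrap_model()

# The speedup pass can now propagate masks through the ConvTranspose2d layer.
ModelSpeedup(model, dummy_input, 'mask.pth').speedup_model()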
...@@ -10,6 +10,7 @@ _logger = logging.getLogger(__name__)
replace_module = {
    'BatchNorm2d': lambda module, mask: replace_batchnorm2d(module, mask),
    'Conv2d': lambda module, mask: replace_conv2d(module, mask),
    'ConvTranspose2d': lambda module, mask: replace_convtranspose2d(module, mask),
    'MaxPool2d': lambda module, mask: no_replace(module, mask),
    'AvgPool2d': lambda module, mask: no_replace(module, mask),
    'AdaptiveAvgPool2d': lambda module, mask: no_replace(module, mask),
...@@ -22,6 +23,7 @@ replace_module = {
    'Dropout3d': lambda module, mask: no_replace(module, mask)
}
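Presumably the speedup engine consumes this table by dispatching on the module's class name; a sketch of that pattern (the wrapper function here is illustrative, not part of this diff):

def replace_submodule(module, mask):
    # Dispatch on the class name; module types without a replace rule are unsupported.
    module_type = type(module).__name__
    if module_type not in replace_module:
        raise RuntimeError(f'unsupported module type: {module_type}')
    return replace_module[module_type](module, mask)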
def no_replace(module, mask):
    """
    No need to replace
...@@ -29,6 +31,7 @@ def no_replace(module, mask):
    _logger.debug("no need to replace")
    return module
def replace_linear(linear, mask):
    """
    Parameters
...@@ -54,11 +57,13 @@ def replace_linear(linear, mask):
                                 out_features=linear.out_features,
                                 bias=linear.bias is not None)
    new_linear.to(linear.weight.device)
    new_linear.weight.data = torch.index_select(
        linear.weight.data, -1, index.to(linear.weight.device))
    if linear.bias is not None:
        new_linear.bias.data.copy_(linear.bias.data)
    return new_linear
def replace_batchnorm2d(norm, mask):
    """
    Parameters
...@@ -87,10 +92,13 @@ def replace_batchnorm2d(norm, mask):
    new_norm.weight.data = torch.index_select(norm.weight.data, 0, index)
    new_norm.bias.data = torch.index_select(norm.bias.data, 0, index)
    if norm.track_running_stats:
        new_norm.running_mean.data = torch.index_select(
            norm.running_mean.data, 0, index)
        new_norm.running_var.data = torch.index_select(
            norm.running_var.data, 0, index)
    return new_norm
def replace_conv2d(conv, mask):
    """
    Parameters
...@@ -121,7 +129,8 @@ def replace_conv2d(conv, mask):
        # remove groups for depthwise layers
        assert in_channels == out_channels
        groups = in_channels
    _logger.debug("replace conv2d %s with in_channels: %d, out_channels: %d",
                  mask.module_name, in_channels, out_channels)
    new_conv = torch.nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=conv.kernel_size,
...@@ -136,9 +145,11 @@ def replace_conv2d(conv, mask):
    tmp_weight_data = tmp_bias_data = None
    if mask.output_mask is not None:
        tmp_weight_data = torch.index_select(
            conv.weight.data, 0, out_channels_index)
        if conv.bias is not None:
            tmp_bias_data = torch.index_select(
                conv.bias.data, 0, out_channels_index)
    else:
        tmp_weight_data = conv.weight.data
    # For the convolutional layers that have more than one group
...@@ -152,24 +163,120 @@ def replace_conv2d(conv, mask):
        for groupid in range(conv.groups):
            start = groupid * input_step
            end = (groupid + 1) * input_step
            current_input_index = list(
                filter(lambda x: start <= x and x < end, in_channels_index.tolist()))
            if not current_input_index:
                # there is no kept channel in current group
                # TODO bug here: the number of groups is taken directly from conv.groups;
                # if a whole group is removed, the number of groups in new_conv also needs to change
                raise Exception(
                    "Removing a whole group of filters is not supported yet, except in depth-wise conv")
            # shift the global index into the group index
            current_input_index = [x-start for x in current_input_index]
            # if the groups is larger than 1, the input channels of each
            # group should be pruned evenly.
            assert len(current_input_index) == in_channels_group, \
                'Input channels of each group are not pruned evenly'
            current_input_index = torch.tensor(current_input_index).to(tmp_weight_data.device)  # pylint: disable=not-callable
            f_start = groupid * filter_step
            f_end = (groupid + 1) * filter_step
            new_conv.weight.data[f_start:f_end] = torch.index_select(
                tmp_weight_data[f_start:f_end], 1, current_input_index)
    else:
        new_conv.weight.data.copy_(tmp_weight_data)
    if conv.bias is not None:
        new_conv.bias.data.copy_(
            conv.bias.data if tmp_bias_data is None else tmp_bias_data)
    return new_conv
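To make the group-wise selection above concrete, a small numeric sketch (values are illustrative): with groups=2, 4 input channels, and kept input channels [0, 3], each group keeps one channel, and the kept global indices are shifted to group-local indices before indexing dim 1 of the matching filter slice:

import torch

groups, in_channels = 2, 4
in_channels_index = torch.tensor([0, 3])  # kept input channels (global indices)
input_step = in_channels // groups        # 2 input channels per group
for groupid in range(groups):
    start, end = groupid * input_step, (groupid + 1) * input_step
    local = [x - start for x in in_channels_index.tolist() if start <= x < end]
    print(groupid, local)                 # group 0 -> [0], group 1 -> [1]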
def replace_convtranspose2d(convtrans, mask):
    """
    We need another replace function for
    ConvTranspose2d, because the layout of
    its weight is different from that of traditional
    conv layers. The layout of the weight is [N_in, N_out, ksize_1, ksize_2]
    Parameters
    ----------
    convtrans : torch.nn.ConvTranspose2d
        The ConvTranspose2d module to be replaced
    mask : ModuleMasks
        The masks of this module
    Returns
    -------
    torch.nn.ConvTranspose2d
        The new ConvTranspose2d module
    """
    assert isinstance(mask, ModuleMasks)
    assert isinstance(convtrans, torch.nn.ConvTranspose2d)
    if mask.input_mask is None:
        in_channels = convtrans.in_channels
    else:
        in_channels_index = mask.input_mask.mask_index[1]
        in_channels = in_channels_index.size(0)
    if mask.output_mask is None:
        out_channels = convtrans.out_channels
    else:
        out_channels_index = mask.output_mask.mask_index[1]
        out_channels = out_channels_index.size(0)
    groups = convtrans.groups
    # check if can remove the whole group of filters
    if convtrans.in_channels == convtrans.out_channels == convtrans.groups:
        # remove groups for depthwise layers
        # this needs the group dependency to be fixed before the speedup
        assert in_channels == out_channels
        groups = in_channels
    _logger.debug('Replace convtranspose2d %s with in_channels:%d out_channels:%d',
                  mask.module_name, in_channels, out_channels)
    new_convtrans = torch.nn.ConvTranspose2d(in_channels=in_channels,
                                             out_channels=out_channels,
                                             kernel_size=convtrans.kernel_size,
                                             stride=convtrans.stride,
                                             padding=convtrans.padding,
                                             dilation=convtrans.dilation,
                                             groups=groups,
                                             bias=convtrans.bias is not None,
                                             padding_mode=convtrans.padding_mode)
    new_convtrans.to(convtrans.weight.device)
    tmp_weight_data = None
    if mask.input_mask is not None:
        # in convtranspose2d we need to select the input channel first
        tmp_weight_data = torch.index_select(
            convtrans.weight.data, 0, in_channels_index)
    else:
        tmp_weight_data = convtrans.weight.data
    # we need to handle the output channel group by group like the conv layer
    out_step = int(convtrans.out_channels / convtrans.groups)
    out_channel_group = int(out_channels/groups)
    new_in_per_group = int(in_channels/groups)
    if mask.output_mask is not None and not (in_channels == out_channels == groups):
        for groupid in range(convtrans.groups):
            start = groupid * out_step
            end = (groupid + 1) * out_step
            current_output_index = list(
                filter(lambda x: start <= x and x < end, out_channels_index.tolist()))
            # we need to shift the index into the group-wise
            current_output_index = [x-start for x in current_output_index]
            if not current_output_index:
                # No kept channel in the current group
                raise Exception(
                    "Removing a whole group of filters is not supported yet, except in depth-wise conv")
            assert len(current_output_index) == out_channel_group, \
                'Output channel of each group should be the same after pruning'
            current_output_index = torch.tensor(current_output_index).to(tmp_weight_data.device)  # pylint: disable=not-callable
            new_start = groupid * new_in_per_group
            new_end = (groupid + 1) * new_in_per_group
            new_convtrans.weight.data[new_start:new_end] = torch.index_select(
                tmp_weight_data[new_start:new_end], 1, current_output_index)
else:
new_convtrans.weight.data.copy_(tmp_weight_data)
if convtrans.bias is not None:
if mask.output_mask is not None:
new_convtrans.bias.data[:] = torch.index_select(
convtrans.bias.data, 0, out_channels_index)
else:
new_convtrans.bias.data.copy_(convtrans.bias.data)
return new_convtrans
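The key difference from replace_conv2d is the weight layout, which is standard PyTorch behavior and easy to verify:

import torch

conv = torch.nn.Conv2d(4, 8, 3)
deconv = torch.nn.ConvTranspose2d(4, 8, 3)
print(conv.weight.shape)    # torch.Size([8, 4, 3, 3]) -> [C_out, C_in, k, k]
print(deconv.weight.shape)  # torch.Size([4, 8, 3, 3]) -> [C_in, C_out, k, k]

# Hence for ConvTranspose2d, kept input channels are selected along dim 0 and
# kept output channels along dim 1, the reverse of the Conv2d case above.
kept_in = torch.tensor([0, 2])
pruned = torch.index_select(deconv.weight.data, 0, kept_in)  # shape [2, 8, 3, 3]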
...@@ -13,6 +13,7 @@ _logger = logging.getLogger(__name__)
conv_prune_dim = -1

def set_conv_prune_dim(dim):
    """
    Parameters:
...@@ -23,6 +24,7 @@ def set_conv_prune_dim(dim):
    global conv_prune_dim
    conv_prune_dim = dim

class CoarseMask:
    """
    Coarse grained mask for a given tensor, here tensor could be weights,
...@@ -228,6 +230,7 @@ Infer input and output shape of a module/function from its weight mask
infer_from_mask = {
    'BatchNorm2d': lambda module_masks, mask: batchnorm2d_mask(module_masks, mask),
    'Conv2d': lambda module_masks, mask: conv2d_mask(module_masks, mask),
    'ConvTranspose2d': lambda module_masks, mask: convtranspose2d_mask(module_masks, mask),
    'Linear': lambda module_masks, mask, shape: linear_mask(module_masks, mask, shape)
}
...@@ -246,6 +249,7 @@ infer_from_inshape = {
    'aten::relu_': lambda module_masks, mask: relu_inshape(module_masks, mask),
    'aten::sigmoid': lambda module_masks, mask: relu_inshape(module_masks, mask),
    'Conv2d': lambda module_masks, mask: conv2d_inshape(module_masks, mask),
    'ConvTranspose2d': lambda module_masks, mask: convtranspose2d_inshape(module_masks, mask),
    'MaxPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
    'aten::max_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
    'aten::avg_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
...@@ -277,6 +281,7 @@ Infer input and weight shape of a module/function from its output shape
"""
infer_from_outshape = {
    'Conv2d': lambda module_masks, mask: conv2d_outshape(module_masks, mask),
    'ConvTranspose2d': lambda module_masks, mask: convtranspose2d_outshape(module_masks, mask),
    'BatchNorm2d': lambda module_masks, mask: batchnorm2d_outshape(module_masks, mask),
    'MaxPool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask),
...@@ -306,6 +311,7 @@ infer_from_outshape = {
    'aten::dropout': lambda module_masks, mask: dropout_outshape(module_masks, mask)
}
def dropout_inshape(module_masks, mask):
    if module_masks.input_mask is None:
        module_masks.set_input_mask(mask)
...@@ -325,6 +331,7 @@ def dropout_inshape(module_masks, mask):
    module_masks.set_output_mask(mask)
    return module_masks.output_mask

def dropout_outshape(module_masks, mask):
    if module_masks.output_mask is None:
        module_masks.set_output_mask(mask)
...@@ -335,6 +342,7 @@ def dropout_outshape(module_masks, mask):
    return module_masks.output_mask
def cat_inshape(module_masks, mask, cat_info, last_visited):
    """
    Infer the output mask of the cat operation from the
...@@ -433,6 +441,7 @@ def add_inshape(module_masks, mask):
        raise Exception('Mask conflict happens!')
    return None

def add_outshape(module_masks, mask):
    """
    Infer the input mask of the add operation from the
...@@ -445,9 +454,11 @@ def add_outshape(module_masks, mask):
        module_masks.set_input_mask(mask)
        return mask
    else:
        assert all(
            module_masks.output_mask.mask_index[1] == mask.mask_index[1])
        return mask
def batchnorm2d_inshape(module_masks, mask):
    """
    We assume only the second dimension has coarse grained mask
...@@ -477,6 +488,7 @@ def batchnorm2d_inshape(module_masks, mask):
    module_masks.set_param_masks('bias', weight_cmask)
    return mask

def batchnorm2d_outshape(module_masks, mask):
    """
    We assume only the second dimension has coarse grained mask
...@@ -577,6 +589,7 @@ def view_inshape(module_masks, mask, shape):
    module_masks.set_output_mask(output_cmask)
    return output_cmask
def view_outshape(module_masks, mask, shape):
    """
    Parameters
...@@ -614,12 +627,14 @@ def view_outshape(module_masks, mask, shape):
    return input_cmask

def size_inshape(module_masks, mask):
    """
    No need to do anything for this ```size``` op
    """
    return None

def mean_inshape(module_masks, mask, shape):
    """
    Similar to view operation, currently mask inference only supports
...@@ -642,6 +657,7 @@ def mean_inshape(module_masks, mask, shape):
    module_masks.set_output_mask(output_cmask)
    return output_cmask

def mean_outshape(module_masks, mask, shape):
    """
    Similar to view operation, currently mask inference only supports
...@@ -662,6 +678,7 @@ def mean_outshape(module_masks, mask, shape):
    module_masks.set_input_mask(input_cmask)
    return input_cmask
def maxpool2d_inshape(module_masks, mask):
    """
    Assume only the second dimension is masked
...@@ -690,6 +707,7 @@ def maxpool2d_inshape(module_masks, mask):
    module_masks.set_output_mask(mask)
    return mask

def maxpool2d_outshape(module_masks, mask):
    """
    Assume only the second dimension is masked
...@@ -714,6 +732,7 @@ def maxpool2d_outshape(module_masks, mask):
    module_masks.set_output_mask(mask)
    return mask

def relu_inshape(module_masks, mask):
    """
    Parameters
...@@ -737,6 +756,7 @@ def relu_inshape(module_masks, mask):
    module_masks.set_output_mask(mask)
    return mask

def relu_outshape(module_masks, mask):
    """
    Parameters
...@@ -754,11 +774,13 @@ def relu_outshape(module_masks, mask):
    assert isinstance(mask, CoarseMask)
    if module_masks.output_mask is not None:
        # mask conflict should be solved before speedup
        assert all(
            module_masks.output_mask.mask_index[1] == mask.mask_index[1])
    module_masks.set_input_mask(mask)
    module_masks.set_output_mask(mask)
    return mask
def batchnorm2d_mask(module_masks, mask):
    """
    Infer input and output shape from weight mask
...@@ -792,6 +814,7 @@ def batchnorm2d_mask(module_masks, mask):
    module_masks.set_output_mask(output_cmask)
    return input_cmask, output_cmask

def linear_mask(module_masks, mask, shape):
    """
    Infer input and output shape from weight mask with limitations:
...@@ -825,6 +848,7 @@ def linear_mask(module_masks, mask, shape):
    module_masks.set_input_mask(input_cmask)
    return input_cmask, None

def conv2d_mask(module_masks, mask):
    """
    Infer input and output shape from weight mask
...@@ -863,8 +887,9 @@ def conv2d_mask(module_masks, mask):
        weight_mask = mask['weight']
        sum_idx = (1, 2, 3) if dim == 0 else (0, 2, 3)
        index = torch.nonzero(weight_mask.abs().sum(
            sum_idx) != 0, as_tuple=True)[0]
        if len(index) == weight_mask.shape[dim]:  # full mask
            index = None
        if index is None:
...@@ -882,7 +907,8 @@ def conv2d_mask(module_masks, mask):
            bias_cmask.add_index_mask(dim=0, index=bias_index)
        return index, weight_cmask, bias_cmask

    index, weight_cmask, bias_cmask = convert_to_coarse_mask(
        mask, dim=conv_prune_dim)
    if index is None:
        # TODO: fine grained mask speedup
...@@ -910,7 +936,8 @@ def conv2d_mask(module_masks, mask):
        module_masks.set_input_mask(io_cmask)
    else:
        assert module_masks.input_mask == io_cmask
    return module_masks.input_mask, None
def conv2d_inshape(module_masks, mask):
    """
...@@ -972,7 +999,8 @@ def conv2d_outshape(module_masks, mask):
        # mask conflict should be solved by fix_mask_conflict before speedup
        # mask and module_masks.output_mask may have different number of dimensions
        # since they could be passed by linear or conv2d
        assert all(
            module_masks.output_mask.mask_index[1] == mask.mask_index[1])
    weight_cmask = CoarseMask(num_dim=4)
    weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
...@@ -988,3 +1016,74 @@ def conv2d_outshape(module_masks, mask):
        module_masks.input_mask = mask
        return mask
    return None
def convtranspose2d_mask(module_masks, mask):
    # TODO support ConvTranspose2d pruning for the L1FilterPruner
    raise Exception(
        "The current filter pruners cannot prune ConvTranspose2d; pruning ConvTranspose2d will be supported later")
def convtranspose2d_inshape(module_masks, mask):
    """
    Shape change of input tensor does not affect the shape of its output tensor
    Parameters
    ----------
    module_masks : ModuleMasks
        The ModuleMasks instance of the ConvTranspose2d
    mask : CoarseMask
        The mask of its input tensor
    Returns
    -------
    CoarseMask
        The mask of its output tensor
    """
    assert isinstance(mask, CoarseMask)
    if module_masks.input_mask is None:
        module_masks.set_input_mask(mask)
    else:
        # the same conv layer may be accessed more
        # than once, such as a concat operation.
        # mask conflict should be solved by fix_mask_conflict before speedup
        assert module_masks.input_mask == mask
    # shape changes pass through depth-wise conv layers
    m = module_masks.module
    if m.in_channels == m.out_channels == m.groups:
        module_masks.output_mask = mask
        module_masks.input_mask = mask
        return mask
    return None
def convtranspose2d_outshape(module_masks, mask):
    assert isinstance(mask, CoarseMask)
    assert mask.mask_index[1] is not None
    assert mask.mask_index[0] is None
    assert mask.mask_index[2] is None
    assert mask.mask_index[3] is None
    if module_masks.output_mask is None:
        module_masks.output_mask = mask
    else:
        # mask conflict should be solved by fix_mask_conflict before speedup
        # mask and module_masks.output_mask may have different numbers of dimensions
        # since they could be passed by linear or conv2d
        assert all(
            module_masks.output_mask.mask_index[1] == mask.mask_index[1])
    weight_cmask = CoarseMask(num_dim=4)
    # Note: the memory layout of ConvTranspose2d is [C_in, C_out, k1, k2]
    weight_cmask.add_index_mask(dim=1, index=mask.mask_index[1])
    bias_cmask = CoarseMask(num_dim=1)
    bias_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
    module_masks.set_param_masks('weight', weight_cmask)
    module_masks.set_param_masks('bias', bias_cmask)
    # shape changes pass through depth-wise conv layers
    m = module_masks.module
    if m.in_channels == m.out_channels == m.groups:
        module_masks.output_mask = mask
        module_masks.input_mask = mask
        return mask
    return None
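The dim=1 weight mask is the only substantive difference from conv2d_outshape, which masks dim 0. A minimal sketch using the CoarseMask API shown in this diff (the kept-channel indices are illustrative):

kept = torch.tensor([0, 2, 5])             # kept output channels

conv_w = CoarseMask(num_dim=4)
conv_w.add_index_mask(dim=0, index=kept)   # Conv2d weight: [C_out, C_in, k, k]

deconv_w = CoarseMask(num_dim=4)
deconv_w.add_index_mask(dim=1, index=kept) # ConvTranspose2d weight: [C_in, C_out, k, k]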
...@@ -9,6 +9,7 @@ from .utils import get_module_by_name
# logging.basicConfig(level = logging.DEBUG)
_logger = logging.getLogger(__name__)

def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
    """
    MaskConflict fixes the mask conflicts caused by the channel dependencies
...@@ -50,6 +51,7 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
    masks = padding_cat_mask.fix_mask()
    return masks, fix_channel_mask.conv_prune_dim
class MaskFix:
    def __init__(self, masks, model=None, dummy_input=None, traced=None):
        # check if the parameters are valid
...@@ -74,6 +76,7 @@ class MaskFix:
        """
        torch.save(self.masks, path)

class CatMaskPadding(MaskFix):
    def __init__(self, masks, model, dummy_input=None, traced=None):
        """
...@@ -100,7 +103,8 @@ class CatMaskPadding(MaskFix):
        super(CatMaskPadding, self).__init__(masks, model, dummy_input, traced)

    def fix_mask(self):
        cat_padding_depen = CatPaddingDependency(
            self.model, self.dummy_input, self.traced)
        name_to_module = {}
        for name, module in self.model.named_modules():
            name_to_module[name] = module
...@@ -131,11 +135,10 @@ class CatMaskPadding(MaskFix):
                    # module.bias may be None
                    b_shape = module.bias.data.size()
                    b_mask = torch.ones(b_shape).to(device)
                self.masks[layer] = {'weight': w_mask, 'bias': b_mask}
        return self.masks
class GroupMaskConflict(MaskFix):
    def __init__(self, masks, model=None, dummy_input=None, traced=None):
        """
...@@ -154,8 +157,8 @@ class GroupMaskConflict(MaskFix):
            the traced model of the target model; if this parameter is not None,
            we do not use the model and dummy_input to get the trace graph.
        """
        super(GroupMaskConflict, self).__init__(
            masks, model, dummy_input, traced)

    def fix_mask(self):
        """
...@@ -163,7 +166,8 @@ class GroupMaskConflict(MaskFix):
        has group dependencies. This function should be called before the
        mask inference of the 'speedup' module.
        """
        group_depen = GroupDependency(
            self.model, self.dummy_input, self.traced)
        depens = group_depen.dependency
        _logger.info(depens)
        for layername in depens:
...@@ -174,8 +178,10 @@ class GroupMaskConflict(MaskFix):
            w_mask = self.masks[layername]['weight']
            shape = w_mask.size()
            count = np.prod(shape[1:])
            all_ones = (w_mask.flatten(1).sum(-1) ==
                        count).nonzero().squeeze(1).tolist()
            all_zeros = (w_mask.flatten(1).sum(-1) ==
                         0).nonzero().squeeze(1).tolist()
            if len(all_ones) + len(all_zeros) < w_mask.size(0):
                # In fine-grained pruning, skip this layer
                _logger.info('Layers %s using fine-grained pruning', layername)
...@@ -190,7 +196,8 @@ class GroupMaskConflict(MaskFix):
            for i in range(group):
                _start = step * i
                _end = step * (i+1)
                _tmp_list = list(
                    filter(lambda x: _start <= x and x < _end, all_zeros))
                group_masked.append(_tmp_list)
            mini_masked = min([len(x) for x in group_masked])
            for gm in group_masked:
...@@ -198,13 +205,13 @@ class GroupMaskConflict(MaskFix):
                    # To keep the output channel number still being divisible to
                    # groups, we set the masks of following filters to be zero.
                    pos = gm[i]
                    self.masks[layername]['weight'][pos] = torch.ones(
                        shape[1:])
                    if 'bias' in self.masks[layername] and self.masks[layername]['bias'] is not None:
                        self.masks[layername]['bias'][pos] = 1
        return self.masks
class ChannelMaskConflict(MaskFix):
    def __init__(self, masks, model=None, dummy_input=None, traced=None):
        """
...@@ -223,7 +230,8 @@ class ChannelMaskConflict(MaskFix):
            the traced graph of the target model; if this parameter is not None,
            we do not use the model and dummy_input to get the trace graph.
        """
        super(ChannelMaskConflict, self).__init__(
            masks, model, dummy_input, traced)
        self.conv_prune_dim = detect_mask_prune_dim(masks, model)
        _logger.info('detected conv prune dim: %s', self.conv_prune_dim)
...@@ -235,9 +243,11 @@ class ChannelMaskConflict(MaskFix):
        are supported.
        """
        if self.conv_prune_dim == 0:
            channel_depen = ChannelDependency(
                self.model, self.dummy_input, self.traced)
        else:
            channel_depen = InputChannelDependency(
                self.model, self.dummy_input, self.traced)
        depen_sets = channel_depen.dependency_sets
        sum_idx = (1, 2, 3) if self.conv_prune_dim == 0 else (0, 2, 3)
        for dset in depen_sets:
...@@ -262,17 +272,29 @@ class ChannelMaskConflict(MaskFix):
                        channel_masks.append((mask.abs().sum(0) != 0).int())
                    elif type(m).__name__ == 'BatchNorm2d':
                        channel_masks.append(mask.int())
                    elif type(m).__name__ == 'ConvTranspose2d':
                        # ConvTranspose2d has a different memory layout, so we
                        # need a separate tmp_sum_idx for it
                        tmp_sum_idx = (
                            0, 2, 3) if self.conv_prune_dim == 0 else (1, 2, 3)
                        channel_mask = (mask.abs().sum(tmp_sum_idx) != 0).int()
                        channel_masks.append(channel_mask)
                        if (channel_mask.sum() * (mask.numel() / mask.shape[1-self.conv_prune_dim])).item() != (mask > 0).sum().item():
                            fine_grained = True
                    else:
                        raise RuntimeError(
                            f'unsupported module type: {type(m).__name__}')
                else:
                    # no mask means not pruned, equivalent to full masks
                    channel_masks.append(None)
            if fine_grained:
                _logger.info(
                    'fine-grained mask detected, skip solving conflict for this set: %s', dset)
                continue
            if all(x is None for x in channel_masks):
                continue
            num_channels_list = [len(x)
                                 for x in channel_masks if x is not None]
            # number of channels in same set should be identical
            assert len(set(num_channels_list)) == 1
            num_channels = num_channels_list[0]
...@@ -284,7 +306,8 @@ class ChannelMaskConflict(MaskFix):
            # merge masks with 'or'
            merged_channel_mask = channel_masks[0].clone()
            for i in range(1, len(channel_masks)):
                merged_channel_mask = (
                    (merged_channel_mask + channel_masks[i]) != 0).int()
            merged_index = torch.nonzero(merged_channel_mask, as_tuple=True)[0]
...@@ -305,16 +328,19 @@ class ChannelMaskConflict(MaskFix):
                elif type(m).__name__ == 'BatchNorm2d':
                    new_mask = merged_index.type_as(orig_mask)
                else:
                    raise RuntimeError(
                        f'unsupported module type: {type(m).__name__}')
                self.masks[name]['weight'] = new_mask
                if 'bias' in self.masks[name] and self.masks[name]['bias'] is not None:
                    if type(m).__name__ == 'Conv2d':
                        assert self.conv_prune_dim == 0
                    self.masks[name]['bias'] = merged_channel_mask.type_as(
                        self.masks[name]['bias'])
        return self.masks
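To illustrate the tmp_sum_idx branch above: with conv_prune_dim == 0 (output-channel pruning), the output-channel dimension of a ConvTranspose2d weight is dim 1, so the reduction runs over (0, 2, 3). A small numeric check with illustrative values:

import torch

w_mask = torch.zeros(4, 6, 3, 3)  # ConvTranspose2d weight mask: [C_in, C_out, k, k]
w_mask[:, :3] = 1.                # first 3 of 6 output channels kept
channel_mask = (w_mask.abs().sum((0, 2, 3)) != 0).int()
print(channel_mask)               # tensor([1, 1, 1, 0, 0, 0]), length C_out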
def detect_mask_prune_dim(masks, model):
    """
    Detect how the masks of convolutional layers are pruned.
...@@ -358,7 +384,8 @@ def detect_mask_prune_dim(masks, model):
        _logger.warning('no multi-dimension masks found.')
        return 0

    dim0_sparsity, dim1_sparsity = 1. - dim0_preserved / \
        dim0_num, 1. - dim1_preserved / dim1_num
    _logger.info('dim0 sparsity: %f', dim0_sparsity)
    _logger.info('dim1 sparsity: %f', dim1_sparsity)
...
...@@ -4,13 +4,16 @@
import csv
import logging

__all__ = ['ChannelDependency', 'GroupDependency',
           'CatPaddingDependency', 'InputChannelDependency']

CONV_TYPE = 'aten::_convolution'
ADD_TYPES = ['aten::add', 'aten::add_']
CAT_TYPE = 'aten::cat'
logger = logging.getLogger('Shape_Dependency')
RESHAPE_OPS = [CAT_TYPE, 'aten::view',
               'aten::reshape', 'aten::flatten', 'aten::mean']

class Dependency:
    def __init__(self, model=None, dummy_input=None, traced_model=None):
...@@ -34,6 +37,7 @@ class Dependency:
    def export(self, filepath):
        raise NotImplementedError
class ChannelDependency(Dependency):
    def __init__(self, model=None, dummy_input=None, traced_model=None):
        """
...@@ -50,7 +54,8 @@ class ChannelDependency(Dependency):
            if we already have the traced graph of the target model, we do not
            need to trace the model again.
        """
        super(ChannelDependency, self).__init__(
            model, dummy_input, traced_model)

    def _get_parent_layers(self, node):
        """
...@@ -71,7 +76,7 @@ class ChannelDependency(Dependency):
        queue.append(node)
        while queue:
            curnode = queue.pop(0)
            if curnode.op_type == 'Conv2d' or curnode.op_type == 'Linear' or curnode.op_type == 'ConvTranspose2d':
                # find the first met conv
                parent_layers.append(curnode.name)
                continue
...@@ -119,7 +124,6 @@ class ChannelDependency(Dependency):
            for _node in dependency_set:
                self.dependency[_node] = dependency_set

    def export(self, filepath):
        """
        export the channel dependencies as a csv file.
...@@ -185,6 +189,7 @@ class ChannelDependency(Dependency):
            d_sets.append(tmp_set)
        return d_sets
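For reference, the dependency classes are presumably usable standalone along these lines (a hedged sketch; the model and file name are illustrative):

dep = ChannelDependency(model=net, dummy_input=torch.rand(1, 3, 224, 224))
dep.export('channel_dependency.csv')  # one row per set of channel-dependent layers
for dset in dep.dependency_sets:
    print(dset)                       # layers whose channels must be pruned together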
def reshape_break_channel_dependency(op_node):
    """
    The reshape operations such as (reshape, view, flatten) may break
...@@ -213,6 +218,7 @@ def reshape_break_channel_dependency(op_node):
    out_channel = out_shape[1]
    return in_channel != out_channel
class InputChannelDependency(ChannelDependency):
    """
    Some pruners may prune the input channel of the convolutional
...@@ -242,7 +248,8 @@ class InputChannelDependency(ChannelDependency):
            if we already have the traced graph of the target model, we do not
            need to trace the model again.
        """
        super(InputChannelDependency, self).__init__(
            model, dummy_input, traced_model)
    def _get_following_convs(self, tensor):
        queue = []
...@@ -250,14 +257,14 @@ class InputChannelDependency(ChannelDependency):
        queue.extend(self.graph.input_to_node[tensor])
        while queue:
            curnode = queue.pop(0)
            if curnode.op_type == 'Conv2d' or curnode.op_type == 'Linear' or curnode.op_type == 'ConvTranspose2d':
                # find the first met conv
                key_layers.append(curnode.name)
                continue
            elif curnode.op_type in RESHAPE_OPS:
                # check if the reshape operation will break the channel dependency
                if reshape_break_channel_dependency(curnode):
                    # reshape operations also break the dependency relationship
                    continue
            successors = self.graph.find_successors(curnode.unique_name)
            successors = [self.graph.name_to_node[name] for name in successors]
...@@ -290,7 +297,8 @@ class InputChannelDependency(ChannelDependency):

class CatPaddingDependency(ChannelDependency):
    def __init__(self, model=None, dummy_input=None, traced_model=None):
        super(CatPaddingDependency, self).__init__(
            model, dummy_input, traced_model)

    def build_dependency(self):
        """
...@@ -347,6 +355,7 @@ class CatPaddingDependency(ChannelDependency):
            row.extend(list(layers))
            csv_w.writerow(row)
class GroupDependency(Dependency):
    def __init__(self, model=None, dummy_input=None, traced_model=None):
        """
...@@ -388,7 +397,7 @@ class GroupDependency(Dependency):
        queue = predeessors
        while queue:
            curnode = queue.pop(0)
            if curnode.op_type == 'Conv2d' or curnode.op_type == 'ConvTranspose2d':
                # find the first met conv
                parent_layers.append(curnode.name)
                continue
...@@ -412,7 +421,8 @@ class GroupDependency(Dependency):
        group : int
            the number of the groups of the target conv layer.
        """
        cpp_conv = list(filter(lambda x: x.kind() ==
                               CONV_TYPE, node_group.node_cpps))
        assert len(cpp_conv) == 1
        cpp_conv = cpp_conv[0]
        inputs = list(cpp_conv.inputs())
...@@ -442,12 +452,14 @@ class GroupDependency(Dependency):
            filters should be divisible to.
        """
        for node in self.graph.nodes_py.nodes_op:
            if node.op_type == 'Conv2d' or node.op_type == 'ConvTranspose2d':
                group = self._get_conv_groups(node)
                if node.name in self.dependency:
                    # a conv layer whose group count is larger than 1 requires
                    # its number of output channels to be divisible by the number of groups.
                    self.dependency[node.name] = max(
                        self.dependency[node.name], group)
                else:
                    self.dependency[node.name] = group
                if group > 1:
...@@ -456,7 +468,8 @@ class GroupDependency(Dependency):
                    parent_convs = self._get_parent_convs(node)
                    for parent in parent_convs:
                        if parent in self.dependency:
                            self.dependency[parent] = max(
                                self.dependency[parent], group)
                        else:
                            self.dependency[parent] = group
        return self.dependency
...@@ -484,6 +497,7 @@ class GroupDependency(Dependency):
            for name in self.dependency:
                group = self.dependency[name]
                csv_w.writerow([name, group])

    @property
    def dependency_sets(self):
        return self.dependency
...@@ -30,13 +30,17 @@ RELATIVE_THRESHOLD = 0.01
# an absolute threshold to determine whether the final result is correct.
# The error should meet the RELATIVE_THRESHOLD or the ABSOLUTE_THRESHOLD.
ABSOLUTE_THRESHOLD = 0.0001

class BackboneModel1(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 1, 1, 1)

    def forward(self, x):
        return self.conv1(x)
class BackboneModel2(torch.nn.Module):
    def __init__(self):
        super().__init__()
...@@ -53,32 +57,58 @@ class BackboneModel2(torch.nn.Module):
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

class BigModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone1 = BackboneModel1()
        self.backbone2 = BackboneModel2()
        self.fc3 = nn.Sequential(
            nn.Linear(10, 10),
            nn.BatchNorm1d(10),
            nn.ReLU(inplace=True),
            nn.Linear(10, 2)
        )

    def forward(self, x):
        x = self.backbone1(x)
        x = self.backbone2(x)
        x = self.fc3(x)
        return x
class TransposeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 20, 5)
        self.conv2 = nn.ConvTranspose2d(20, 50, 5, groups=2)
        self.bn1 = nn.BatchNorm2d(self.conv1.out_channels)
        self.bn2 = nn.BatchNorm2d(self.conv2.out_channels)
        self.fc1 = nn.Linear(8 * 8 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        # x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        # x = F.max_pool2d(x, 2, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
dummy_input = torch.randn(2, 1, 28, 28)
SPARSITY = 0.5
MODEL_FILE, MASK_FILE = './11_model.pth', './l1_mask.pth'

def prune_model_l1(model):
    config_list = [{
        'sparsity': SPARSITY,
...@@ -88,6 +118,7 @@ def prune_model_l1(model):
    pruner.compress()
    pruner.export_model(model_path=MODEL_FILE, mask_path=MASK_FILE)
def generate_random_sparsity(model):
    cfg_list = []
    for name, module in model.named_modules():
...@@ -97,18 +128,20 @@ def generate_random_sparsity(model):
                             'sparsity': sparsity})
    return cfg_list

def zero_bn_bias(model):
    with torch.no_grad():
        for name, module in model.named_modules():
            if isinstance(module, nn.BatchNorm2d) \
                    or isinstance(module, nn.BatchNorm3d) \
                    or isinstance(module, nn.BatchNorm1d):
                shape = module.bias.data.size()
                device = module.bias.device
                module.bias.data = torch.zeros(shape).to(device)
                shape = module.running_mean.data.size()
                module.running_mean = torch.zeros(shape).to(device)
class L1ChannelMasker(WeightMasker):
    def __init__(self, model, pruner):
        self.model = model
...@@ -143,21 +176,27 @@ class L1ChannelMasker(WeightMasker):
        w_abs = weight.abs()
        if wrapper.type == 'Conv2d':
            w_abs_structured = w_abs.sum((0, 2, 3))
            threshold = torch.topk(
                w_abs_structured, num_prune, largest=False)[0].max()
            mask_weight = torch.gt(w_abs_structured, threshold)[
                None, :, None, None].expand_as(weight).type_as(weight)
            return {'weight_mask': mask_weight.detach()}
        else:
            # Linear
            assert wrapper.type == 'Linear'
            w_abs_structured = w_abs.sum((0))
            threshold = torch.topk(
                w_abs_structured, num_prune, largest=False)[0].max()
            mask_weight = torch.gt(w_abs_structured, threshold)[
                None, :].expand_as(weight).type_as(weight)
            return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias}

class L1ChannelPruner(_StructuredFilterPruner):
    def __init__(self, model, config_list, optimizer=None, dependency_aware=False, dummy_input=None):
        super().__init__(model, config_list, pruning_algorithm='l1', optimizer=optimizer,
                         dependency_aware=dependency_aware, dummy_input=dummy_input)

    def validate_config(self, model, config_list):
        pass
...@@ -177,6 +216,7 @@ def channel_prune(model):
    pruner.compress()
    pruner.export_model(model_path=MODEL_FILE, mask_path=MASK_FILE)
class SpeedupTestCase(TestCase):
    def test_speedup_vgg16(self):
        prune_model_l1(vgg16())
...@@ -187,8 +227,10 @@ class SpeedupTestCase(TestCase):
        orig_model = vgg16()
        assert model.training
        assert model.features[2].out_channels == int(
            orig_model.features[2].out_channels * SPARSITY)
        assert model.classifier[0].in_features == int(
            orig_model.classifier[0].in_features * SPARSITY)

    def test_speedup_bigmodel(self):
        prune_model_l1(BigModel())
...@@ -205,23 +247,55 @@ class SpeedupTestCase(TestCase):
        model.eval()
        speedup_out = model(dummy_input)
        if not torch.allclose(mask_out, speedup_out, atol=1e-07):
            print('input:', dummy_input.size(),
                  torch.abs(dummy_input).sum((2, 3)))
            print('mask_out:', mask_out)
            print('speedup_out:', speedup_out)
            raise RuntimeError('model speedup inference result is incorrect!')

        orig_model = BigModel()
        assert model.backbone2.conv1.out_channels == int(
            orig_model.backbone2.conv1.out_channels * SPARSITY)
        assert model.backbone2.conv2.in_channels == int(
            orig_model.backbone2.conv2.in_channels * SPARSITY)
        assert model.backbone2.conv2.out_channels == int(
            orig_model.backbone2.conv2.out_channels * SPARSITY)
        assert model.backbone2.fc1.in_features == int(
            orig_model.backbone2.fc1.in_features * SPARSITY)
    def test_convtranspose_model(self):
        ori_model = TransposeModel()
        dummy_input = torch.rand(1, 3, 8, 8)
        config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
        pruner = L1FilterPruner(ori_model, config_list)
        pruner.compress()
        ori_model(dummy_input)
        pruner.export_model(MODEL_FILE, MASK_FILE)
        pruner._unwrap_model()
        new_model = TransposeModel()
        state_dict = torch.load(MODEL_FILE)
        new_model.load_state_dict(state_dict)
        ms = ModelSpeedup(new_model, dummy_input, MASK_FILE)
        ms.speedup_model()
        zero_bn_bias(ori_model)
        zero_bn_bias(new_model)
        ori_out = ori_model(dummy_input)
        new_out = new_model(dummy_input)
        ori_sum = torch.sum(ori_out)
        speeded_sum = torch.sum(new_out)
        print('Transpose Speedup Test: ori_sum={} speedup_sum={}'.format(ori_sum, speeded_sum))
        assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
            (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
    # FIXME: This test case might fail randomly, no idea why
    # Example: https://msrasrg.visualstudio.com/NNIOpenSource/_build/results?buildId=16282
    def test_speedup_integration(self):
        for model_name in ['resnet18', 'squeezenet1_1',
                           'mobilenet_v2', 'densenet121',
                           # 'inception_v3' inception is too large and may fail the pipeline
                           'densenet169', 'resnet50']:
            kwargs = {
                'pretrained': True
            }
...@@ -235,7 +309,7 @@ class SpeedupTestCase(TestCase):
            Model = getattr(models, model_name)
            net = Model(**kwargs).to(device)
            speedup_model = Model(**kwargs).to(device)
            net.eval()  # this line is necessary
            speedup_model.eval()
            # random generate the prune config for the pruner
            cfgs = generate_random_sparsity(net)
...@@ -258,8 +332,10 @@ class SpeedupTestCase(TestCase):
            speeded_out = speedup_model(data)
            ori_sum = torch.sum(ori_out).item()
            speeded_sum = torch.sum(speeded_out).item()
            print('Sum of the output of %s (before speedup):' %
                  model_name, ori_sum)
            print('Sum of the output of %s (after speedup):' %
                  model_name, speeded_sum)
            assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
                (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
...@@ -296,5 +372,6 @@ class SpeedupTestCase(TestCase):
        os.remove(MODEL_FILE)
        os.remove(MASK_FILE)

if __name__ == '__main__':
    main()