Commit c7d58033 authored by chicm-ms, committed by GitHub

Fix pruners for DataParallel support (#2003)

parent 4e21e721
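
The thrust of this commit: pruner bookkeeping that previously lived in plain Python attributes (self.if_init_list, self.mask_calculated_ops) moves into tensors registered with register_buffer on the module wrappers, and the hunks below touch the example script, PrunerModuleWrapper, the AGP and weight-rank pruners, and the tests accordingly. Registered buffers, unlike ad-hoc attributes, are part of a module's state: torch.nn.DataParallel broadcasts them to every replica and they appear in state_dict(). Below is a minimal sketch of that mechanism; MaskedWrapper and its logic are illustrative stand-ins, not NNI's actual PrunerModuleWrapper.

```python
import torch
import torch.nn as nn

class MaskedWrapper(nn.Module):
    """Illustrative stand-in for a pruner wrapper; not NNI's implementation."""
    def __init__(self, module):
        super().__init__()
        self.module = module
        # A registered buffer is part of module state: DataParallel broadcasts
        # it to every replica, and it is saved/loaded with state_dict().
        self.register_buffer('weight_mask', torch.ones_like(module.weight))

    def forward(self, x):
        # Apply the mask in place before delegating to the wrapped layer.
        self.module.weight.data = self.module.weight.data.mul_(self.weight_mask)
        return self.module(x)

wrapped = MaskedWrapper(nn.Conv2d(1, 20, 5, 1))
print('weight_mask' in dict(wrapped.named_buffers()))  # True
```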
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 from torchvision import datasets, transforms
 from nni.compression.torch import FPGMPruner
@@ -6,10 +7,10 @@ from nni.compression.torch import FPGMPruner
 class Mnist(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
-        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
-        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
-        self.fc2 = torch.nn.Linear(500, 10)
+        self.conv1 = nn.Conv2d(1, 20, 5, 1)
+        self.conv2 = nn.Conv2d(20, 50, 5, 1)
+        self.fc1 = nn.Linear(4 * 4 * 50, 500)
+        self.fc2 = nn.Linear(500, 10)

     def forward(self, x):
         x = F.relu(self.conv1(x))
@@ -27,8 +28,14 @@ class Mnist(torch.nn.Module):
         return num_zero_filters, num_filters, float(num_zero_filters)/num_filters

     def print_conv_filter_sparsity(self):
-        conv1_data = self._get_conv_weight_sparsity(self.conv1)
-        conv2_data = self._get_conv_weight_sparsity(self.conv2)
+        if isinstance(self.conv1, nn.Conv2d):
+            conv1_data = self._get_conv_weight_sparsity(self.conv1)
+            conv2_data = self._get_conv_weight_sparsity(self.conv2)
+        else:
+            # self.conv1 is wrapped as PrunerModuleWrapper
+            conv1_data = self._get_conv_weight_sparsity(self.conv1.module)
+            conv2_data = self._get_conv_weight_sparsity(self.conv2.module)

         print('conv1: num zero filters: {}, num filters: {}, sparsity: {:.4f}'.format(conv1_data[0], conv1_data[1], conv1_data[2]))
         print('conv2: num zero filters: {}, num filters: {}, sparsity: {:.4f}'.format(conv2_data[0], conv2_data[1], conv2_data[2]))
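
After compression, self.conv1 is no longer a bare Conv2d: the wrapper keeps the original layer in its .module attribute, which is why the example unwraps before reading weights. A hedged generic version of the same unwrapping (the helper name is mine; the .module attribute matches the wrapper shown in this diff):

```python
import torch.nn as nn

def unwrap_conv(layer):
    """Return the underlying Conv2d whether or not the layer is wrapped."""
    if isinstance(layer, nn.Conv2d):
        return layer
    # wrapped case: the pruner wrapper keeps the original layer in .module
    return layer.module
```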
@@ -246,7 +246,7 @@ class PrunerModuleWrapper(torch.nn.Module):
             self.module.weight.data = self.module.weight.data.mul_(self.weight_mask)
         # apply mask to bias
         if hasattr(self.module, 'bias') and self.module.bias is not None:
-            if mask is not None:
+            if mask is not None and 'bias' in mask:
                 self.bias_mask.copy_(mask['bias'])
             self.module.bias.data = self.module.bias.data.mul_(self.bias_mask)
         return self.module(*inputs)
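
The tightened guard matters when calc_mask returns a dict without a 'bias' entry, for example a filter pruner that only masks weights: indexing mask['bias'] unconditionally would raise KeyError. A toy illustration of the two guards (the mask dict here is hypothetical):

```python
import torch

mask = {'weight': torch.ones(4, 3)}  # hypothetical weight-only mask

# old guard: `if mask is not None:` would then index mask['bias'] -> KeyError
# new guard: skip the copy and keep the previous bias_mask instead
if mask is not None and 'bias' in mask:
    print('updating bias_mask')
else:
    print('no bias entry; bias_mask left as-is')
```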
@@ -565,4 +565,3 @@ def _check_weight(module):
         return isinstance(module.weight.data, torch.Tensor)
     except AttributeError:
         return False
\ No newline at end of file
@@ -83,17 +83,20 @@ class AGP_Pruner(Pruner):
         super().__init__(model, config_list)
         self.now_epoch = 0
-        self.if_init_list = {}
+        self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
-        Calculate the mask of given layer
+        Calculate the mask of given layer.
+        Scale factors with the smallest absolute value in the BN layer are masked.

         Parameters
         ----------
         layer : LayerInfo
             the layer to instrument the compression operation
         config : dict
             layer's pruning config
+        kwargs: dict
+            buffers registered in __init__ function

         Returns
         -------
         dict
@@ -101,24 +104,26 @@ class AGP_Pruner(Pruner):
         """
         weight = layer.module.weight.data
         op_name = layer.name
         start_epoch = config.get('start_epoch', 0)
         freq = config.get('frequency', 1)
-        if self.now_epoch >= start_epoch and self.if_init_list.get(op_name, True) \
-                and (self.now_epoch - start_epoch) % freq == 0:
-            mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})
+        if_calculated = kwargs["if_calculated"]
+        if if_calculated:
+            return None
+        if not (self.now_epoch >= start_epoch and (self.now_epoch - start_epoch) % freq == 0):
+            return None
+        mask = {'weight': torch.ones(weight.shape).type_as(weight)}
         target_sparsity = self.compute_target_sparsity(config)
         k = int(weight.numel() * target_sparsity)
         if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
             return mask
         # if we want to generate a new mask, we should update the weight first
-        w_abs = weight.abs() * mask
+        w_abs = weight.abs()
         threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
         new_mask = {'weight': torch.gt(w_abs, threshold).type_as(weight)}
-        self.mask_dict.update({op_name: new_mask})
-        self.if_init_list.update({op_name: False})
-        else:
-            new_mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})
+        if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
         return new_mask

     def compute_target_sparsity(self, config):
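
For context, compute_target_sparsity implements the cubic AGP schedule of Zhu & Gupta (2017), driven by the same config keys seen above (start_epoch, frequency) plus initial_sparsity, final_sparsity, and end_epoch. A standalone sketch of that schedule; the min() clamp is my addition and the arithmetic may differ in detail from NNI's:

```python
def agp_target_sparsity(now_epoch, initial_sparsity=0.0, final_sparsity=0.9,
                        start_epoch=0, end_epoch=10, frequency=1):
    """Cubic AGP schedule: sparsity ramps quickly at first, then flattens
    as it approaches final_sparsity."""
    span = ((end_epoch - start_epoch - 1) // frequency) * frequency
    assert span > 0, 'end_epoch must leave room for at least one pruning step'
    progress = min(1.0, (now_epoch - start_epoch) / span)
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** 3
```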
@@ -164,9 +169,8 @@ class AGP_Pruner(Pruner):
         if epoch > 0:
             self.now_epoch = epoch
-            for k in self.if_init_list.keys():
-                self.if_init_list[k] = True
+            for wrapper in self.get_modules_wrapper():
+                wrapper.registered_buffers['if_calculated'].copy_(torch.tensor(0)) # pylint: disable=not-callable

 class SlimPruner(Pruner):
     """
@@ -27,7 +27,7 @@ class WeightRankFilterPruner(Pruner):
         """
         super().__init__(model, config_list)
-        self.register_buffer("if_calculated", torch.tensor(False)) # pylint: disable=not-callable
+        self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable

     def get_mask(self, base_mask, weight, num_prune):
         raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))
@@ -69,7 +69,7 @@ class WeightRankFilterPruner(Pruner):
                 return mask
             mask = self.get_mask(mask, weight, num_prune)
         finally:
-            if_calculated.copy_(torch.tensor(True)) # pylint: disable=not-callable
+            if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
         return mask
@@ -257,4 +257,5 @@ class FPGMPruner(WeightRankFilterPruner):
         return x.sum()

     def update_epoch(self, epoch):
-        self.mask_calculated_ops = set()
+        for wrapper in self.get_modules_wrapper():
+            wrapper.registered_buffers['if_calculated'].copy_(torch.tensor(0)) # pylint: disable=not-callable
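
FPGM ranks filters by their distance to the layer's geometric median: filters whose weights lie closest to the median are the most replaceable and are pruned first. A hedged sketch of that criterion, using the common shortcut of summed pairwise distances rather than NNI's internal distance helper shown above:

```python
import torch

def fpgm_scores(weight):
    """weight: (out_channels, in_channels, k, k). Lower score = closer to the
    geometric median = more redundant, so prune the lowest scores first."""
    flat = weight.view(weight.size(0), -1)
    dist = torch.cdist(flat, flat, p=2)  # pairwise distances between filters
    return dist.sum(dim=1)

scores = fpgm_scores(torch.randn(10, 5, 3, 3))
print(scores.argsort()[:2])  # indices of the two most prunable filters
```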
@@ -138,7 +138,6 @@ class CompressorTestCase(TestCase):
         masks = pruner.calc_mask(layer, config_list[0], if_calculated=torch.tensor(0))
         assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))

         pruner.update_epoch(1)
-        model.conv2.weight.data = torch.tensor(w).float()
         masks = pruner.calc_mask(layer, config_list[1], if_calculated=torch.tensor(0))
         assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.]))
@@ -159,7 +158,6 @@ class CompressorTestCase(TestCase):
         assert all(masks.sum((1)) == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))

         pruner.update_epoch(1)
-        model.layers[2].set_weights([weights[0], weights[1].numpy()])
         masks = pruner.calc_mask(layer, config_list[1]).numpy()
         masks = masks.reshape((-1, masks.shape[-1])).transpose([1, 0])
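
End to end, the point of the change is that a compressed model can now be wrapped in DataParallel. A hedged usage sketch against the Mnist example at the top of this diff; the config format follows NNI's pruner config lists, but exact API details may vary by NNI version:

```python
import torch
import torch.nn as nn
from nni.compression.torch import FPGMPruner

model = Mnist()  # the example model defined above
config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]  # prune half the filters
pruner = FPGMPruner(model, config_list)
model = pruner.compress()

# masks and flags now live in registered buffers, so replication is safe
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
```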