"test/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "69de8c4bdffee9e4ab94b78b570d2c8b1095ace4"
Commit c7d58033 authored by chicm-ms, committed by GitHub

Fix pruners for DataParallel support (#2003)

parent 4e21e721
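
Background for this change (editor's note, a minimal sketch under assumed behavior): torch.nn.DataParallel replicates module parameters and registered buffers to each GPU replica on every forward pass, but plain Python attributes, such as the dict-based bookkeeping the pruners used before, are not replicated, so masks and flags kept there never reach the replicas. Storing state via register_buffer, as this commit does throughout, keeps it visible on all replicas:

    import torch
    import torch.nn as nn

    class Sketch(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 4)
            # Replicated to every GPU copy by DataParallel.
            self.register_buffer("weight_mask", torch.ones(4, 4))
            # Plain attribute: NOT replicated; replicas never see updates.
            self.mask_dict = {}

        def forward(self, x):
            return nn.functional.linear(x, self.fc.weight * self.weight_mask, self.fc.bias)

    # model = nn.DataParallel(Sketch().cuda())  # buffers follow the replicas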
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 from torchvision import datasets, transforms
 from nni.compression.torch import FPGMPruner

@@ -6,10 +7,10 @@ from nni.compression.torch import FPGMPruner
 class Mnist(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
-        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
-        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
-        self.fc2 = torch.nn.Linear(500, 10)
+        self.conv1 = nn.Conv2d(1, 20, 5, 1)
+        self.conv2 = nn.Conv2d(20, 50, 5, 1)
+        self.fc1 = nn.Linear(4 * 4 * 50, 500)
+        self.fc2 = nn.Linear(500, 10)

     def forward(self, x):
         x = F.relu(self.conv1(x))

@@ -27,8 +28,14 @@ class Mnist(torch.nn.Module):
         return num_zero_filters, num_filters, float(num_zero_filters)/num_filters

     def print_conv_filter_sparsity(self):
-        conv1_data = self._get_conv_weight_sparsity(self.conv1)
-        conv2_data = self._get_conv_weight_sparsity(self.conv2)
+        if isinstance(self.conv1, nn.Conv2d):
+            conv1_data = self._get_conv_weight_sparsity(self.conv1)
+            conv2_data = self._get_conv_weight_sparsity(self.conv2)
+        else:
+            # self.conv1 is wrapped as PrunerModuleWrapper
+            conv1_data = self._get_conv_weight_sparsity(self.conv1.module)
+            conv2_data = self._get_conv_weight_sparsity(self.conv2.module)

         print('conv1: num zero filters: {}, num filters: {}, sparsity: {:.4f}'.format(conv1_data[0], conv1_data[1], conv1_data[2]))
         print('conv2: num zero filters: {}, num filters: {}, sparsity: {:.4f}'.format(conv2_data[0], conv2_data[1], conv2_data[2]))
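
A hedged usage sketch of the updated example (the config values below are illustrative): after pruner.compress(), each conv layer is replaced by a PrunerModuleWrapper, which is why print_conv_filter_sparsity now unwraps .module in the else branch:

    model = Mnist()
    pruner = FPGMPruner(model, [{'sparsity': 0.5, 'op_types': ['Conv2d']}])
    pruner.compress()
    # conv1/conv2 are now PrunerModuleWrapper instances, so the isinstance
    # check fails and the .module path is taken when reporting sparsity.
    model.print_conv_filter_sparsity()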
@@ -246,7 +246,7 @@ class PrunerModuleWrapper(torch.nn.Module):
         self.module.weight.data = self.module.weight.data.mul_(self.weight_mask)
         # apply mask to bias
         if hasattr(self.module, 'bias') and self.module.bias is not None:
-            if mask is not None:
+            if mask is not None and 'bias' in mask:
                 self.bias_mask.copy_(mask['bias'])
             self.module.bias.data = self.module.bias.data.mul_(self.bias_mask)
         return self.module(*inputs)

@@ -565,4 +565,3 @@ def _check_weight(module):
         return isinstance(module.weight.data, torch.Tensor)
     except AttributeError:
-        return False
\ No newline at end of file
+        return False
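
Why the added 'bias' in mask guard matters (illustrative sketch, not from the commit): a pruner may return a mask dict that holds only a 'weight' entry, and for a module that does have a bias the unguarded code would index mask['bias'] and raise a KeyError:

    import torch
    mask = {'weight': torch.ones(20, 1, 5, 5)}   # pruner produced no 'bias' entry
    bias_mask = torch.ones(20)
    if mask is not None and 'bias' in mask:      # guard added by this commit
        bias_mask.copy_(mask['bias'])            # skipped here; old bias mask is kept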
@@ -83,17 +83,20 @@ class AGP_Pruner(Pruner):
         super().__init__(model, config_list)
         self.now_epoch = 0
-        self.if_init_list = {}
+        self.register_buffer("if_calculated", torch.tensor(0))  # pylint: disable=not-callable

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
-        Calculate the mask of given layer
+        Calculate the mask of given layer.
+        Scale factors with the smallest absolute value in the BN layer are masked.

         Parameters
         ----------
         layer : LayerInfo
             the layer to instrument the compression operation
         config : dict
             layer's pruning config
+        kwargs: dict
+            buffers registered in __init__ function

         Returns
         -------
         dict

@@ -101,24 +104,26 @@ class AGP_Pruner(Pruner):
         """
         weight = layer.module.weight.data
-        op_name = layer.name
         start_epoch = config.get('start_epoch', 0)
         freq = config.get('frequency', 1)
-        if self.now_epoch >= start_epoch and self.if_init_list.get(op_name, True) \
-                and (self.now_epoch - start_epoch) % freq == 0:
-            mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})
+        if_calculated = kwargs["if_calculated"]
+        if if_calculated:
+            return None
+        if not (self.now_epoch >= start_epoch and (self.now_epoch - start_epoch) % freq == 0):
+            return None
+        mask = {'weight': torch.ones(weight.shape).type_as(weight)}
         target_sparsity = self.compute_target_sparsity(config)
         k = int(weight.numel() * target_sparsity)
         if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
             return mask
         # if we want to generate a new mask, we should update the weight first
-        w_abs = weight.abs() * mask
+        w_abs = weight.abs()
         threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
         new_mask = {'weight': torch.gt(w_abs, threshold).type_as(weight)}
-        self.mask_dict.update({op_name: new_mask})
-        self.if_init_list.update({op_name: False})
-        else:
-            new_mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})
+        if_calculated.copy_(torch.tensor(1))  # pylint: disable=not-callable
         return new_mask

@@ -164,9 +169,8 @@ class AGP_Pruner(Pruner):
         if epoch > 0:
             self.now_epoch = epoch
-            for k in self.if_init_list.keys():
-                self.if_init_list[k] = True
+            for wrapper in self.get_modules_wrapper():
+                wrapper.registered_buffers['if_calculated'].copy_(torch.tensor(0))  # pylint: disable=not-callable

 class SlimPruner(Pruner):
     """
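
A hedged sketch of the AGP training loop under the new buffer-based bookkeeping (assuming AGP's usual config keys; epoch counts and sparsity targets are placeholders): update_epoch resets every wrapper's if_calculated buffer so calc_mask recomputes masks in the new epoch.

    pruner = AGP_Pruner(model, [{'initial_sparsity': 0., 'final_sparsity': 0.8,
                                 'start_epoch': 0, 'end_epoch': 10, 'frequency': 1,
                                 'op_types': ['default']}])
    model = pruner.compress()
    for epoch in range(10):
        # ... train one epoch ...
        pruner.update_epoch(epoch)  # clears if_calculated on every wrapper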
@@ -27,7 +27,7 @@ class WeightRankFilterPruner(Pruner):
         """
         super().__init__(model, config_list)
-        self.register_buffer("if_calculated", torch.tensor(False))  # pylint: disable=not-callable
+        self.register_buffer("if_calculated", torch.tensor(0))  # pylint: disable=not-callable

     def get_mask(self, base_mask, weight, num_prune):
         raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))

@@ -69,7 +69,7 @@ class WeightRankFilterPruner(Pruner):
                 return mask
             mask = self.get_mask(mask, weight, num_prune)
         finally:
-            if_calculated.copy_(torch.tensor(True))  # pylint: disable=not-callable
+            if_calculated.copy_(torch.tensor(1))  # pylint: disable=not-callable
         return mask

@@ -257,4 +257,5 @@ class FPGMPruner(WeightRankFilterPruner):
         return x.sum()

     def update_epoch(self, epoch):
-        self.mask_calculated_ops = set()
+        for wrapper in self.get_modules_wrapper():
+            wrapper.registered_buffers['if_calculated'].copy_(torch.tensor(0))  # pylint: disable=not-callable
@@ -138,7 +138,6 @@ class CompressorTestCase(TestCase):
         masks = pruner.calc_mask(layer, config_list[0], if_calculated=torch.tensor(0))
         assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))
-        pruner.update_epoch(1)
         model.conv2.weight.data = torch.tensor(w).float()
         masks = pruner.calc_mask(layer, config_list[1], if_calculated=torch.tensor(0))
         assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.]))

@@ -159,7 +158,6 @@ class CompressorTestCase(TestCase):
         assert all(masks.sum((1)) == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))
-        pruner.update_epoch(1)
         model.layers[2].set_weights([weights[0], weights[1].numpy()])
         masks = pruner.calc_mask(layer, config_list[1]).numpy()
         masks = masks.reshape((-1, masks.shape[-1])).transpose([1, 0])
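
The removed pruner.update_epoch(1) calls are no longer needed here because the tests hand calc_mask a fresh flag tensor on each call; a minimal sketch of the flag's contract as implied by the diff:

    flag = torch.tensor(0)  # 0 = mask not yet computed for this wrapper
    masks = pruner.calc_mask(layer, config_list[0], if_calculated=flag)
    assert flag.item() == 1  # calc_mask flips the flag once the mask is built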