Unverified commit 4e21e721, authored by Cjkkkk, committed by GitHub

update level pruner to adapt to pruner dataparallel refactor (#1993)

parent d452a166
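
The refactor this commit adapts to moves per-layer pruning state out of plain Python containers (self.mask_calculated_ops, self.mask_dict) and into module wrappers plus tensors registered with register_buffer, so the state travels with the model when torch.nn.DataParallel replicates it. A minimal usage sketch of the intended calling pattern (the toy model, sparsity, and config values are illustrative, not taken from this commit):

import torch
import torch.nn as nn
from nni.compression.torch import LevelPruner

# Toy model and config; the {'sparsity', 'op_types'} format is NNI's standard one.
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
configure_list = [{'sparsity': 0.5, 'op_types': ['default']}]

pruner = LevelPruner(model, configure_list)
model = pruner.compress()            # wraps pruned modules; masks and flags live on the wrappers
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)   # registered buffers (e.g. if_calculated) replicate with the model
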
@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torchvision import datasets, transforms
-from nni.compression.torch import L1FilterPruner
+from nni.compression.torch import ActivationMeanRankFilterPruner
 from models.cifar10.vgg import VGG
@@ -96,7 +96,7 @@ def main():
     # Prune model and test accuracy without fine tuning.
     print('=' * 10 + 'Test on the pruned model before fine tune' + '=' * 10)
-    pruner = L1FilterPruner(model, configure_list)
+    pruner = ActivationMeanRankFilterPruner(model, configure_list)
     model = pruner.compress()
     if args.parallel:
         if torch.cuda.device_count() > 1:
@@ -32,7 +32,7 @@ class ActivationRankFilterPruner(Pruner):
         """
         super().__init__(model, config_list)
-        self.register_buffer("if_calculated", torch.tensor(False)) # pylint: disable=not-callable
+        self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable
         self.statistics_batch_num = statistics_batch_num
         self.collected_activation = {}
         self.hooks = {}
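
Here and in the pruners below, the flag buffer changes from torch.tensor(False) to torch.tensor(0). The buffer mechanics themselves are plain PyTorch; the following standalone sketch (not NNI code) shows why the flag is kept as a registered buffer rather than a Python attribute: it appears in state_dict, moves with .to(device), is replicated by nn.DataParallel, and an in-place copy_ updates it without re-registering.

import torch
import torch.nn as nn

class Flagged(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("if_calculated", torch.tensor(0))

    def forward(self, x):
        if not self.if_calculated:                      # tensor(0) is falsy, tensor(1) is truthy
            self.if_calculated.copy_(torch.tensor(1))   # in-place update keeps the buffer registered
        return x

m = Flagged()
print(m.state_dict())         # contains 'if_calculated': tensor(0)
m(torch.zeros(1))
print(bool(m.if_calculated))  # True after one forward pass
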
@@ -48,16 +48,23 @@ class ActivationRankFilterPruner(Pruner):
         """
         Compress the model, register a hook for collecting activations.
         """
+        if self.modules_wrapper is not None:
+            # already compressed
+            return self.bound_model
+        else:
+            self.modules_wrapper = []
+
         modules_to_compress = self.detect_modules_to_compress()
         for layer, config in modules_to_compress:
-            self._instrument_layer(layer, config)
+            wrapper = self._wrap_modules(layer, config)
+            self.modules_wrapper.append(wrapper)
             self.collected_activation[layer.name] = []

             def _hook(module_, input_, output, name=layer.name):
                 if len(self.collected_activation[name]) < self.statistics_batch_num:
                     self.collected_activation[name].append(self.activation(output.detach().cpu()))
-            layer.module.register_forward_hook(_hook)
+            wrapper.module.register_forward_hook(_hook)
+        self._wrap_model()
         return self.bound_model

     def get_mask(self, base_mask, activations, num_prune):
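
The hook registered in the hunk above is ordinary PyTorch. A standalone sketch (not NNI code) of the same collection pattern, with torch.relu standing in for self.activation and made-up layer names and sizes:

import torch
import torch.nn as nn

statistics_batch_num = 2
collected_activation = {'fc': []}
layer = nn.Linear(8, 4)

def _hook(module_, input_, output, name='fc'):
    # Stop collecting once enough batches have been seen, as in the pruner above.
    if len(collected_activation[name]) < statistics_batch_num:
        collected_activation[name].append(torch.relu(output.detach().cpu()))

handle = layer.register_forward_hook(_hook)
for _ in range(5):
    layer(torch.randn(3, 8))
handle.remove()
print(len(collected_activation['fc']))  # 2, not 5
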
@@ -103,7 +110,7 @@ class ActivationRankFilterPruner(Pruner):
                 mask = self.get_mask(mask, self.collected_activation[layer.name], num_prune)
         finally:
             if len(self.collected_activation[layer.name]) == self.statistics_batch_num:
-                if_calculated.copy_(torch.tensor(True)) # pylint: disable=not-callable
+                if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
         return mask
@@ -89,7 +89,7 @@ class Compressor:
         """
         if self.modules_wrapper is not None:
             # already compressed
-            return
+            return self.bound_model
         else:
             self.modules_wrapper = []
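
With this return value fixed, compress() can be called more than once and the assignment idiom from the example file keeps working; the second call hits the "already compressed" branch and still hands back the bound model. A two-line usage sketch, reusing the hypothetical pruner from the sketch above:

model = pruner.compress()  # first call: wraps modules and returns the bound model
model = pruner.compress()  # second call: short-circuits, but still returns the model
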
@@ -27,9 +27,9 @@ class LevelPruner(Pruner):
         """
         super().__init__(model, config_list)
-        self.mask_calculated_ops = set()
+        self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Calculate the mask of given layer

         Parameters
@@ -45,8 +45,9 @@ class LevelPruner(Pruner):
         """
         weight = layer.module.weight.data
         op_name = layer.name
-        if op_name not in self.mask_calculated_ops:
+        if_calculated = kwargs["if_calculated"]
+        if not if_calculated:
             w_abs = weight.abs()
             k = int(weight.numel() * config['sparsity'])
             if k == 0:
@@ -54,12 +55,10 @@
             threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
             mask_weight = torch.gt(w_abs, threshold).type_as(weight)
             mask = {'weight': mask_weight}
-            self.mask_dict.update({op_name: mask})
-            self.mask_calculated_ops.add(op_name)
+            if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
+            return mask
         else:
-            assert op_name in self.mask_dict, "op_name not in the mask_dict"
-            mask = self.mask_dict[op_name]
-        return mask
+            return None

 class AGP_Pruner(Pruner):
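
The thresholding math in the LevelPruner hunk above is unchanged by this commit: keep a weight only if its magnitude exceeds the k-th smallest absolute value, where k = numel * sparsity. A tiny worked example (not NNI code) with made-up numbers:

import torch

weight = torch.tensor([[0.1, -2.0], [0.3, 4.0]])
sparsity = 0.5
w_abs = weight.abs()
k = int(weight.numel() * sparsity)                                 # k = 2
threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()  # 0.3
mask_weight = torch.gt(w_abs, threshold).type_as(weight)
print(mask_weight)  # tensor([[0., 1.], [0., 1.]]): the two smallest-magnitude weights are masked
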
@@ -197,7 +196,7 @@ class SlimPruner(Pruner):
         all_bn_weights = torch.cat(weight_list)
         k = int(all_bn_weights.shape[0] * config['sparsity'])
         self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()
-        self.register_buffer("if_calculated", torch.tensor(False)) # pylint: disable=not-callable
+        self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable

     def calc_mask(self, layer, config, **kwargs):
         """
@@ -232,7 +231,7 @@
             mask_weight = torch.gt(w_abs, self.global_threshold).type_as(weight)
             mask_bias = mask_weight.clone()
             mask = {'weight': mask_weight.detach(), 'bias': mask_bias.detach()}
-            if_calculated.copy_(torch.tensor(True)) # pylint: disable=not-callable
+            if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
             return mask

 class LotteryTicketPruner(Pruner):