Unverified commit 4e21e721, authored by Cjkkkk, committed by GitHub
Browse files

update level pruner to adapt to pruner dataparallel refactor (#1993)

parent d452a166
...@@ -4,7 +4,7 @@ import torch ...@@ -4,7 +4,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torchvision import datasets, transforms from torchvision import datasets, transforms
from nni.compression.torch import L1FilterPruner from nni.compression.torch import ActivationMeanRankFilterPruner
from models.cifar10.vgg import VGG from models.cifar10.vgg import VGG
...@@ -96,7 +96,7 @@ def main(): ...@@ -96,7 +96,7 @@ def main():
# Prune model and test accuracy without fine tuning. # Prune model and test accuracy without fine tuning.
print('=' * 10 + 'Test on the pruned model before fine tune' + '=' * 10) print('=' * 10 + 'Test on the pruned model before fine tune' + '=' * 10)
pruner = L1FilterPruner(model, configure_list) pruner = ActivationMeanRankFilterPruner(model, configure_list)
model = pruner.compress() model = pruner.compress()
if args.parallel: if args.parallel:
if torch.cuda.device_count() > 1: if torch.cuda.device_count() > 1:
......
...@@ -32,7 +32,7 @@ class ActivationRankFilterPruner(Pruner): ...@@ -32,7 +32,7 @@ class ActivationRankFilterPruner(Pruner):
""" """
super().__init__(model, config_list) super().__init__(model, config_list)
self.register_buffer("if_calculated", torch.tensor(False)) # pylint: disable=not-callable self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable
self.statistics_batch_num = statistics_batch_num self.statistics_batch_num = statistics_batch_num
self.collected_activation = {} self.collected_activation = {}
self.hooks = {} self.hooks = {}
...@@ -48,16 +48,23 @@ class ActivationRankFilterPruner(Pruner): ...@@ -48,16 +48,23 @@ class ActivationRankFilterPruner(Pruner):
""" """
Compress the model, register a hook for collecting activations. Compress the model, register a hook for collecting activations.
""" """
if self.modules_wrapper is not None:
# already compressed
return self.bound_model
else:
self.modules_wrapper = []
modules_to_compress = self.detect_modules_to_compress() modules_to_compress = self.detect_modules_to_compress()
for layer, config in modules_to_compress: for layer, config in modules_to_compress:
self._instrument_layer(layer, config) wrapper = self._wrap_modules(layer, config)
self.modules_wrapper.append(wrapper)
self.collected_activation[layer.name] = [] self.collected_activation[layer.name] = []
def _hook(module_, input_, output, name=layer.name): def _hook(module_, input_, output, name=layer.name):
if len(self.collected_activation[name]) < self.statistics_batch_num: if len(self.collected_activation[name]) < self.statistics_batch_num:
self.collected_activation[name].append(self.activation(output.detach().cpu())) self.collected_activation[name].append(self.activation(output.detach().cpu()))
layer.module.register_forward_hook(_hook) wrapper.module.register_forward_hook(_hook)
self._wrap_model()
return self.bound_model return self.bound_model
def get_mask(self, base_mask, activations, num_prune): def get_mask(self, base_mask, activations, num_prune):
...@@ -103,7 +110,7 @@ class ActivationRankFilterPruner(Pruner): ...@@ -103,7 +110,7 @@ class ActivationRankFilterPruner(Pruner):
mask = self.get_mask(mask, self.collected_activation[layer.name], num_prune) mask = self.get_mask(mask, self.collected_activation[layer.name], num_prune)
finally: finally:
if len(self.collected_activation[layer.name]) == self.statistics_batch_num: if len(self.collected_activation[layer.name]) == self.statistics_batch_num:
if_calculated.copy_(torch.tensor(True)) # pylint: disable=not-callable if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
return mask return mask
......
...@@ -89,7 +89,7 @@ class Compressor: ...@@ -89,7 +89,7 @@ class Compressor:
""" """
if self.modules_wrapper is not None: if self.modules_wrapper is not None:
# already compressed # already compressed
return return self.bound_model
else: else:
self.modules_wrapper = [] self.modules_wrapper = []
......
...@@ -27,9 +27,9 @@ class LevelPruner(Pruner): ...@@ -27,9 +27,9 @@ class LevelPruner(Pruner):
""" """
super().__init__(model, config_list) super().__init__(model, config_list)
self.mask_calculated_ops = set() self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable
def calc_mask(self, layer, config): def calc_mask(self, layer, config, **kwargs):
""" """
Calculate the mask of given layer Calculate the mask of given layer
Parameters Parameters
...@@ -45,8 +45,9 @@ class LevelPruner(Pruner): ...@@ -45,8 +45,9 @@ class LevelPruner(Pruner):
""" """
weight = layer.module.weight.data weight = layer.module.weight.data
op_name = layer.name if_calculated = kwargs["if_calculated"]
if op_name not in self.mask_calculated_ops:
if not if_calculated:
w_abs = weight.abs() w_abs = weight.abs()
k = int(weight.numel() * config['sparsity']) k = int(weight.numel() * config['sparsity'])
if k == 0: if k == 0:
...@@ -54,12 +55,10 @@ class LevelPruner(Pruner): ...@@ -54,12 +55,10 @@ class LevelPruner(Pruner):
threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max() threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
mask_weight = torch.gt(w_abs, threshold).type_as(weight) mask_weight = torch.gt(w_abs, threshold).type_as(weight)
mask = {'weight': mask_weight} mask = {'weight': mask_weight}
self.mask_dict.update({op_name: mask}) if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
self.mask_calculated_ops.add(op_name) return mask
else: else:
assert op_name in self.mask_dict, "op_name not in the mask_dict" return None
mask = self.mask_dict[op_name]
return mask
class AGP_Pruner(Pruner): class AGP_Pruner(Pruner):
...@@ -197,7 +196,7 @@ class SlimPruner(Pruner): ...@@ -197,7 +196,7 @@ class SlimPruner(Pruner):
all_bn_weights = torch.cat(weight_list) all_bn_weights = torch.cat(weight_list)
k = int(all_bn_weights.shape[0] * config['sparsity']) k = int(all_bn_weights.shape[0] * config['sparsity'])
self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max() self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()
self.register_buffer("if_calculated", torch.tensor(False)) # pylint: disable=not-callable self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable
def calc_mask(self, layer, config, **kwargs): def calc_mask(self, layer, config, **kwargs):
""" """
...@@ -232,7 +231,7 @@ class SlimPruner(Pruner): ...@@ -232,7 +231,7 @@ class SlimPruner(Pruner):
mask_weight = torch.gt(w_abs, self.global_threshold).type_as(weight) mask_weight = torch.gt(w_abs, self.global_threshold).type_as(weight)
mask_bias = mask_weight.clone() mask_bias = mask_weight.clone()
mask = {'weight': mask_weight.detach(), 'bias': mask_bias.detach()} mask = {'weight': mask_weight.detach(), 'bias': mask_bias.detach()}
if_calculated.copy_(torch.tensor(True)) # pylint: disable=not-callable if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
return mask return mask
class LotteryTicketPruner(Pruner): class LotteryTicketPruner(Pruner):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment