Commit fd551c86 authored by Tang Lang's avatar Tang Lang Committed by QuanluZhang
Browse files

fix builtin pruners bug (#1612)

* fix builtin pruners bug
parent d6b61e2f
...@@ -2,7 +2,7 @@ import logging ...@@ -2,7 +2,7 @@ import logging
import torch import torch
from .compressor import Pruner from .compressor import Pruner
__all__ = [ 'LevelPruner', 'AGP_Pruner', 'SensitivityPruner' ] __all__ = ['LevelPruner', 'AGP_Pruner', 'SensitivityPruner']
logger = logging.getLogger('torch pruner') logger = logging.getLogger('torch pruner')
...@@ -10,6 +10,7 @@ logger = logging.getLogger('torch pruner') ...@@ -10,6 +10,7 @@ logger = logging.getLogger('torch pruner')
class LevelPruner(Pruner): class LevelPruner(Pruner):
"""Prune to an exact pruning level specification """Prune to an exact pruning level specification
""" """
def __init__(self, config_list): def __init__(self, config_list):
""" """
config_list: supported keys: config_list: supported keys:
...@@ -21,9 +22,9 @@ class LevelPruner(Pruner): ...@@ -21,9 +22,9 @@ class LevelPruner(Pruner):
w_abs = weight.abs() w_abs = weight.abs()
k = int(weight.numel() * config['sparsity']) k = int(weight.numel() * config['sparsity'])
if k == 0: if k == 0:
return torch.ones(weight.shape) return torch.ones(weight.shape).type_as(weight)
threshold = torch.topk(w_abs.view(-1), k, largest = False).values.max() threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
return torch.gt(w_abs, threshold).type(weight.type()) return torch.gt(w_abs, threshold).type_as(weight)
class AGP_Pruner(Pruner): class AGP_Pruner(Pruner):
...@@ -35,12 +36,13 @@ class AGP_Pruner(Pruner): ...@@ -35,12 +36,13 @@ class AGP_Pruner(Pruner):
Learning of Phones and other Consumer Devices, Learning of Phones and other Consumer Devices,
https://arxiv.org/pdf/1710.01878.pdf https://arxiv.org/pdf/1710.01878.pdf
""" """
def __init__(self, config_list): def __init__(self, config_list):
""" """
config_list: supported keys: config_list: supported keys:
- initial_sparsity - initial_sparsity
- final_sparsity: you should make sure initial_sparsity <= final_sparsity - final_sparsity: you should make sure initial_sparsity <= final_sparsity
- start_epoch: start epoch numer begin update mask - start_epoch: start epoch number begin update mask
- end_epoch: end epoch number stop update mask, you should make sure start_epoch <= end_epoch - end_epoch: end epoch number stop update mask, you should make sure start_epoch <= end_epoch
- frequency: if you want update every 2 epoch, you can set it 2 - frequency: if you want update every 2 epoch, you can set it 2
""" """
...@@ -49,15 +51,15 @@ class AGP_Pruner(Pruner): ...@@ -49,15 +51,15 @@ class AGP_Pruner(Pruner):
self.now_epoch = 1 self.now_epoch = 1
def calc_mask(self, weight, config, op_name, **kwargs): def calc_mask(self, weight, config, op_name, **kwargs):
mask = self.mask_list.get(op_name, torch.ones(weight.shape)) mask = self.mask_list.get(op_name, torch.ones(weight.shape).type_as(weight))
target_sparsity = self.compute_target_sparsity(config) target_sparsity = self.compute_target_sparsity(config)
k = int(weight.numel() * target_sparsity) k = int(weight.numel() * target_sparsity)
if k == 0 or target_sparsity >= 1 or target_sparsity <= 0: if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
return mask return mask
# if we want to generate new mask, we should update weigth first # if we want to generate new mask, we should update weigth first
w_abs = weight.abs()*mask w_abs = weight.abs() * mask
threshold = torch.topk(w_abs.view(-1), k, largest = False).values.max() threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
new_mask = torch.gt(w_abs, threshold).type(weight.type()) new_mask = torch.gt(w_abs, threshold).type_as(weight)
self.mask_list[op_name] = new_mask self.mask_list[op_name] = new_mask
return new_mask return new_mask
...@@ -74,11 +76,11 @@ class AGP_Pruner(Pruner): ...@@ -74,11 +76,11 @@ class AGP_Pruner(Pruner):
if end_epoch <= self.now_epoch: if end_epoch <= self.now_epoch:
return final_sparsity return final_sparsity
span = ((end_epoch - start_epoch-1)//freq)*freq span = ((end_epoch - start_epoch - 1) // freq) * freq
assert span > 0 assert span > 0
target_sparsity = (final_sparsity + target_sparsity = (final_sparsity +
(initial_sparsity - final_sparsity)* (initial_sparsity - final_sparsity) *
(1.0 - ((self.now_epoch - start_epoch)/span))**3) (1.0 - ((self.now_epoch - start_epoch) / span)) ** 3)
return target_sparsity return target_sparsity
def update_epoch(self, epoch): def update_epoch(self, epoch):
...@@ -93,6 +95,7 @@ class SensitivityPruner(Pruner): ...@@ -93,6 +95,7 @@ class SensitivityPruner(Pruner):
I.e.: "The pruning threshold is chosen as a quality parameter multiplied I.e.: "The pruning threshold is chosen as a quality parameter multiplied
by the standard deviation of a layers weights." by the standard deviation of a layers weights."
""" """
def __init__(self, config_list): def __init__(self, config_list):
""" """
config_list: supported keys: config_list: supported keys:
...@@ -101,18 +104,17 @@ class SensitivityPruner(Pruner): ...@@ -101,18 +104,17 @@ class SensitivityPruner(Pruner):
super().__init__(config_list) super().__init__(config_list)
self.mask_list = {} self.mask_list = {}
def calc_mask(self, weight, config, op_name, **kwargs): def calc_mask(self, weight, config, op_name, **kwargs):
mask = self.mask_list.get(op_name, torch.ones(weight.shape)) mask = self.mask_list.get(op_name, torch.ones(weight.shape).type_as(weight))
# if we want to generate new mask, we should update weigth first # if we want to generate new mask, we should update weight first
weight = weight*mask weight = weight * mask
target_sparsity = config['sparsity'] * torch.std(weight).item() target_sparsity = config['sparsity'] * torch.std(weight).item()
k = int(weight.numel() * target_sparsity) k = int(weight.numel() * target_sparsity)
if k == 0: if k == 0:
return mask return mask
w_abs = weight.abs() w_abs = weight.abs()
threshold = torch.topk(w_abs.view(-1), k, largest = False).values.max() threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
new_mask = torch.gt(w_abs, threshold).type(weight.type()) new_mask = torch.gt(w_abs, threshold).type_as(weight)
self.mask_list[op_name] = new_mask self.mask_list[op_name] = new_mask
return new_mask return new_mask
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment