Commit fd551c86 authored by Tang Lang's avatar Tang Lang Committed by QuanluZhang
Browse files

fix builtin pruners bug (#1612)

* fix builtin pruners bug
parent d6b61e2f
......@@ -2,7 +2,7 @@ import logging
import torch
from .compressor import Pruner
__all__ = [ 'LevelPruner', 'AGP_Pruner', 'SensitivityPruner' ]
__all__ = ['LevelPruner', 'AGP_Pruner', 'SensitivityPruner']
logger = logging.getLogger('torch pruner')
......@@ -10,6 +10,7 @@ logger = logging.getLogger('torch pruner')
class LevelPruner(Pruner):
"""Prune to an exact pruning level specification
"""
def __init__(self, config_list):
"""
config_list: supported keys:
......@@ -21,9 +22,9 @@ class LevelPruner(Pruner):
w_abs = weight.abs()
k = int(weight.numel() * config['sparsity'])
if k == 0:
return torch.ones(weight.shape)
threshold = torch.topk(w_abs.view(-1), k, largest = False).values.max()
return torch.gt(w_abs, threshold).type(weight.type())
return torch.ones(weight.shape).type_as(weight)
threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
return torch.gt(w_abs, threshold).type_as(weight)
class AGP_Pruner(Pruner):
......@@ -35,12 +36,13 @@ class AGP_Pruner(Pruner):
Learning of Phones and other Consumer Devices,
https://arxiv.org/pdf/1710.01878.pdf
"""
def __init__(self, config_list):
"""
config_list: supported keys:
- initial_sparsity
- final_sparsity: you should make sure initial_sparsity <= final_sparsity
- start_epoch: start epoch numer begin update mask
- start_epoch: start epoch number begin update mask
- end_epoch: end epoch number stop update mask, you should make sure start_epoch <= end_epoch
- frequency: if you want update every 2 epoch, you can set it 2
"""
......@@ -49,15 +51,15 @@ class AGP_Pruner(Pruner):
self.now_epoch = 1
def calc_mask(self, weight, config, op_name, **kwargs):
mask = self.mask_list.get(op_name, torch.ones(weight.shape))
mask = self.mask_list.get(op_name, torch.ones(weight.shape).type_as(weight))
target_sparsity = self.compute_target_sparsity(config)
k = int(weight.numel() * target_sparsity)
if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
return mask
# if we want to generate new mask, we should update weigth first
w_abs = weight.abs()*mask
threshold = torch.topk(w_abs.view(-1), k, largest = False).values.max()
new_mask = torch.gt(w_abs, threshold).type(weight.type())
w_abs = weight.abs() * mask
threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
new_mask = torch.gt(w_abs, threshold).type_as(weight)
self.mask_list[op_name] = new_mask
return new_mask
......@@ -74,18 +76,18 @@ class AGP_Pruner(Pruner):
if end_epoch <= self.now_epoch:
return final_sparsity
span = ((end_epoch - start_epoch-1)//freq)*freq
span = ((end_epoch - start_epoch - 1) // freq) * freq
assert span > 0
target_sparsity = (final_sparsity +
(initial_sparsity - final_sparsity)*
(1.0 - ((self.now_epoch - start_epoch)/span))**3)
target_sparsity = (final_sparsity +
(initial_sparsity - final_sparsity) *
(1.0 - ((self.now_epoch - start_epoch) / span)) ** 3)
return target_sparsity
def update_epoch(self, epoch):
if epoch > 0:
self.now_epoch = epoch
class SensitivityPruner(Pruner):
"""Use algorithm from "Learning both Weights and Connections for Efficient Neural Networks"
https://arxiv.org/pdf/1506.02626v3.pdf
......@@ -93,6 +95,7 @@ class SensitivityPruner(Pruner):
I.e.: "The pruning threshold is chosen as a quality parameter multiplied
by the standard deviation of a layers weights."
"""
def __init__(self, config_list):
"""
config_list: supported keys:
......@@ -100,19 +103,18 @@ class SensitivityPruner(Pruner):
"""
super().__init__(config_list)
self.mask_list = {}
def calc_mask(self, weight, config, op_name, **kwargs):
mask = self.mask_list.get(op_name, torch.ones(weight.shape))
# if we want to generate new mask, we should update weigth first
weight = weight*mask
mask = self.mask_list.get(op_name, torch.ones(weight.shape).type_as(weight))
# if we want to generate new mask, we should update weight first
weight = weight * mask
target_sparsity = config['sparsity'] * torch.std(weight).item()
k = int(weight.numel() * target_sparsity)
if k == 0:
return mask
w_abs = weight.abs()
threshold = torch.topk(w_abs.view(-1), k, largest = False).values.max()
new_mask = torch.gt(w_abs, threshold).type(weight.type())
threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
new_mask = torch.gt(w_abs, threshold).type_as(weight)
self.mask_list[op_name] = new_mask
return new_mask
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment