Unverified Commit e349b440 authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Add TensorFlow slim pruner (#3614)


Co-authored-by: default avatarliuzhe <zhe.liu@microsoft.com>
parent d1450b4d
...@@ -3,15 +3,15 @@ import tensorflow as tf ...@@ -3,15 +3,15 @@ import tensorflow as tf
from nni.compression.tensorflow import Pruner from nni.compression.tensorflow import Pruner
__all__ = [ __all__ = [
'OneshotPruner',
'LevelPruner', 'LevelPruner',
'SlimPruner',
] ]
class OneshotPruner(Pruner): class OneshotPruner(Pruner):
def __init__(self, model, config_list, pruning_algorithm='level', **algo_kwargs): def __init__(self, model, config_list, masker_class, **algo_kwargs):
super().__init__(model, config_list) super().__init__(model, config_list)
self.set_wrappers_attribute('calculated', False) self.set_wrappers_attribute('calculated', False)
self.masker = MASKER_DICT[pruning_algorithm](model, self, **algo_kwargs) self.masker = masker_class(model, self, **algo_kwargs)
def validate_config(self, model, config_list): def validate_config(self, model, config_list):
pass # TODO pass # TODO
...@@ -28,7 +28,12 @@ class OneshotPruner(Pruner): ...@@ -28,7 +28,12 @@ class OneshotPruner(Pruner):
class LevelPruner(OneshotPruner): class LevelPruner(OneshotPruner):
def __init__(self, model, config_list): def __init__(self, model, config_list):
super().__init__(model, config_list, pruning_algorithm='level') super().__init__(model, config_list, LevelPrunerMasker)
class SlimPruner(OneshotPruner):
def __init__(self, model, config_list):
super().__init__(model, config_list, SlimPrunerMasker)
class WeightMasker: class WeightMasker:
...@@ -57,7 +62,7 @@ class LevelPrunerMasker(WeightMasker): ...@@ -57,7 +62,7 @@ class LevelPrunerMasker(WeightMasker):
w_abs = tf.math.abs(weight) w_abs = tf.math.abs(weight)
k = tf.size(weight) - num_prune k = tf.size(weight) - num_prune
topk = tf.math.top_k(tf.reshape(w_abs, [-1]), k)[0] topk = tf.math.top_k(tf.reshape(w_abs, [-1]), k).values
if tf.size(topk) == 0: if tf.size(topk) == 0:
mask = tf.zeros_like(weight) mask = tf.zeros_like(weight)
else: else:
...@@ -65,7 +70,41 @@ class LevelPrunerMasker(WeightMasker): ...@@ -65,7 +70,41 @@ class LevelPrunerMasker(WeightMasker):
masks[weight_variable.name] = tf.cast(mask, weight.dtype) masks[weight_variable.name] = tf.cast(mask, weight.dtype)
return masks return masks
class SlimPrunerMasker(WeightMasker):
def __init__(self, model, pruner, **kwargs):
super().__init__(model, pruner)
weight_list = []
for wrapper in pruner.wrappers:
weights = [w for w in wrapper.layer.weights if '/gamma:' in w.name]
assert len(weights) == 1, f'Bad weights: {[w.name for w in wrapper.layer.weights]}'
weight_list.append(tf.math.abs(weights[0].read_value()))
all_bn_weights = tf.concat(weight_list, 0)
k = int(all_bn_weights.shape[0] * pruner.wrappers[0].config['sparsity'])
top_k = -tf.math.top_k(-tf.reshape(all_bn_weights, [-1]), k).values
self.global_threshold = top_k.numpy()[-1]
def calc_masks(self, sparsity, wrapper, wrapper_idx=None):
assert isinstance(wrapper.layer, tf.keras.layers.BatchNormalization), \
'SlimPruner only supports 2D batch normalization layer pruning'
MASKER_DICT = { weight = None
'level': LevelPrunerMasker, weight_name = None
} bias_name = None
for variable in wrapper.layer.weights:
if '/gamma:' in variable.name:
weight = variable.read_value()
weight_name = variable.name
elif '/beta:' in variable.name:
bias_name = variable.name
assert weight is not None
if wrapper.masks.get(weight_name) is not None:
weight *= wrapper.masks[weight_name]
mask = tf.cast(tf.math.abs(weight) > self.global_threshold, weight.dtype)
masks = {weight_name: mask}
if bias_name:
masks[bias_name] = mask
return masks
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment