Unverified Commit 262fabf1 authored by chicm-ms's avatar chicm-ms Committed by GitHub

Filter prune algo implementation (#1655)

* fpgm pruner pytorch and tensorflow 2.0 implementation
parent 187494aa
...@@ -12,6 +12,7 @@ We have provided two naive compression algorithms and three popular ones for use
|---|---|
| [Level Pruner](./Pruner.md#level-pruner) | Pruning the specified ratio on each weight based on absolute values of weights |
| [AGP Pruner](./Pruner.md#agp-pruner) | Automated gradual pruning (To prune, or not to prune: exploring the efficacy of pruning for model compression) [Reference Paper](https://arxiv.org/abs/1710.01878)|
| [FPGM Pruner](./Pruner.md#fpgm-pruner) | Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration [Reference Paper](https://arxiv.org/pdf/1811.00250.pdf)|
| [Naive Quantizer](./Quantizer.md#naive-quantizer) | Quantize weights to default 8 bits |
| [QAT Quantizer](./Quantizer.md#qat-quantizer) | Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference. [Reference Paper](http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf)|
| [DoReFa Quantizer](./Quantizer.md#dorefa-quantizer) | DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients. [Reference Paper](https://arxiv.org/abs/1606.06160)|
...
...@@ -92,3 +92,49 @@ You can view example for more information
***
## FPGM Pruner
FPGM Pruner is an implementation of the paper [Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration](https://arxiv.org/pdf/1811.00250.pdf).
>Previous works utilized “smaller-norm-less-important” criterion to prune filters with smaller norm values in a convolutional neural network. In this paper, we analyze this norm-based criterion and point out that its effectiveness depends on two requirements that are not always met: (1) the norm deviation of the filters should be large; (2) the minimum norm of the filters should be small. To solve this problem, we propose a novel filter pruning method, namely Filter Pruning via Geometric Median (FPGM), to compress the model regardless of those two requirements. Unlike previous methods, FPGM compresses CNN models by pruning filters with redundancy, rather than those with “relatively less” importance.
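The idea can be sketched independently of NNI: for every convolutional kernel, sum its Euclidean distances to all other kernels in the layer, then zero out the kernels whose distance sums are smallest, since those lie closest to the geometric median and are the most replaceable. A minimal NumPy sketch, assuming a PyTorch-style weight layout of `(out_channels, in_channels, kH, kW)`; the function name is illustrative, not NNI's API:

```python
import numpy as np

def fpgm_prune_mask(weight, sparsity):
    """Return a 0/1 mask over kernels for a conv weight of shape
    (out_channels, in_channels, kH, kW), zeroing the kernels closest
    to the geometric median of all kernels."""
    num_kernels = weight.shape[0] * weight.shape[1]
    num_prune = int(num_kernels * sparsity)
    flat = weight.reshape(num_kernels, -1)
    # pairwise Euclidean distances between all kernels
    dists = np.linalg.norm(flat[:, None, :] - flat[None, :, :], axis=-1)
    dist_sum = dists.sum(axis=1)
    mask = np.ones(num_kernels)
    # kernels with the smallest total distance are the most redundant
    mask[np.argsort(dist_sum)[:num_prune]] = 0.
    return mask.reshape(weight.shape[0], weight.shape[1], 1, 1)

# three identical kernels plus one outlier: the outlier survives pruning
w = np.ones((4, 1, 3, 3))
w[3] = 100.0
m = fpgm_prune_mask(w, 0.5)
```

With `sparsity=0.5`, two of the three identical kernels are pruned while the outlier (largest distance sum) is kept, which is exactly the opposite of a norm-based criterion that would keep kernels by magnitude alone.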
### Usage
First, import the pruner and apply masks to the model.
Tensorflow code
```python
from nni.compression.tensorflow import FPGMPruner
config_list = [{
'sparsity': 0.5,
'op_types': ['Conv2D']
}]
pruner = FPGMPruner(model, config_list)
pruner.compress()
```
PyTorch code
```python
from nni.compression.torch import FPGMPruner
config_list = [{
'sparsity': 0.5,
'op_types': ['Conv2d']
}]
pruner = FPGMPruner(model, config_list)
pruner.compress()
```
Note: FPGM Pruner prunes convolutional layers within deep neural networks, so the `op_types` field supports only convolutional layers.
Second, add the code below to update the epoch number at the beginning of each epoch.
Tensorflow code
```python
pruner.update_epoch(epoch, sess)
```
PyTorch code
```python
pruner.update_epoch(epoch)
```
You can view the example for more information
#### User configuration for FPGM Pruner
* **sparsity:** the percentage of convolutional filters to be pruned.
***
import tensorflow as tf
from tensorflow import keras
assert tf.__version__ >= "2.0"
import numpy as np
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from nni.compression.tensorflow import FPGMPruner

def get_data():
    (X_train_full, y_train_full), _ = keras.datasets.mnist.load_data()
    X_train, X_valid = X_train_full[:-5000], X_train_full[-5000:]
    y_train, y_valid = y_train_full[:-5000], y_train_full[-5000:]
    X_mean = X_train.mean(axis=0, keepdims=True)
    X_std = X_train.std(axis=0, keepdims=True) + 1e-7
    X_train = (X_train - X_mean) / X_std
    X_valid = (X_valid - X_mean) / X_std
    X_train = X_train[..., np.newaxis]
    X_valid = X_valid[..., np.newaxis]
    return X_train, X_valid, y_train, y_valid

def get_model():
    model = keras.models.Sequential([
        Conv2D(filters=32, kernel_size=7, input_shape=[28, 28, 1], activation='relu', padding="SAME"),
        MaxPooling2D(pool_size=2),
        Conv2D(filters=64, kernel_size=3, activation='relu', padding="SAME"),
        MaxPooling2D(pool_size=2),
        Flatten(),
        Dense(units=128, activation='relu'),
        Dropout(0.5),
        Dense(units=10, activation='softmax'),
    ])
    model.compile(loss="sparse_categorical_crossentropy",
                  optimizer=keras.optimizers.SGD(lr=1e-3),
                  metrics=["accuracy"])
    return model

def main():
    X_train, X_valid, y_train, y_valid = get_data()
    model = get_model()

    configure_list = [{
        'sparsity': 0.5,
        'op_types': ['Conv2D']
    }]
    pruner = FPGMPruner(model, configure_list)
    pruner.compress()

    update_epoch_callback = keras.callbacks.LambdaCallback(on_epoch_begin=lambda epoch, logs: pruner.update_epoch(epoch))
    model.fit(X_train, y_train, epochs=10, validation_data=(X_valid, y_valid), callbacks=[update_epoch_callback])

if __name__ == '__main__':
    main()
from nni.compression.torch import FPGMPruner
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms

class Mnist(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
        self.fc2 = torch.nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

    def _get_conv_weight_sparsity(self, conv_layer):
        num_zero_filters = (conv_layer.weight.data.sum((2, 3)) == 0).sum()
        num_filters = conv_layer.weight.data.size(0) * conv_layer.weight.data.size(1)
        return num_zero_filters, num_filters, float(num_zero_filters) / num_filters

    def print_conv_filter_sparsity(self):
        conv1_data = self._get_conv_weight_sparsity(self.conv1)
        conv2_data = self._get_conv_weight_sparsity(self.conv2)
        print('conv1: num zero filters: {}, num filters: {}, sparsity: {:.4f}'.format(conv1_data[0], conv1_data[1], conv1_data[2]))
        print('conv2: num zero filters: {}, num filters: {}, sparsity: {:.4f}'.format(conv2_data[0], conv2_data[1], conv2_data[2]))

def train(model, device, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        if batch_idx % 100 == 0:
            print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))
            model.print_conv_filter_sparsity()
        loss.backward()
        optimizer.step()

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('Loss: {} Accuracy: {}%\n'.format(
        test_loss, 100 * correct / len(test_loader.dataset)))

def main():
    torch.manual_seed(0)
    device = torch.device('cpu')
    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True, transform=trans),
        batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=False, transform=trans),
        batch_size=1000, shuffle=True)

    model = Mnist()
    model.print_conv_filter_sparsity()

    '''you can change this to LevelPruner to implement it
    pruner = LevelPruner(configure_list)
    '''
    configure_list = [{
        'sparsity': 0.5,
        'op_types': ['Conv2d']
    }]
    pruner = FPGMPruner(model, configure_list)
    pruner.compress()

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    for epoch in range(10):
        pruner.update_epoch(epoch)
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer)
        test(model, device, test_loader)

if __name__ == '__main__':
    main()
import logging
import numpy as np
import tensorflow as tf
from .compressor import Pruner

-__all__ = ['LevelPruner', 'AGP_Pruner']
+__all__ = ['LevelPruner', 'AGP_Pruner', 'FPGMPruner']

_logger = logging.getLogger(__name__)
...@@ -98,3 +99,104 @@ class AGP_Pruner(Pruner):
        sess.run(tf.assign(self.now_epoch, int(epoch)))
        for k in self.if_init_list:
            self.if_init_list[k] = True
class FPGMPruner(Pruner):
    """
    A filter pruner via geometric median.
    "Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration",
    https://arxiv.org/pdf/1811.00250.pdf
    """

    def __init__(self, model, config_list):
        """
        Parameters
        ----------
        model : keras model
            the model user wants to compress
        config_list: list
            support key for each list item:
                - sparsity: percentage of convolutional filters to be pruned.
        """
        super().__init__(model, config_list)
        self.mask_dict = {}
        self.assign_handler = []
        self.epoch_pruned_layers = set()

    def calc_mask(self, layer, config):
        """
        Supports Conv1D, Conv2D
        filter dimensions for Conv1D:
            LEN: filter length
            IN: number of input channels
            OUT: number of output channels
        filter dimensions for Conv2D:
            H: filter height
            W: filter width
            IN: number of input channels
            OUT: number of output channels

        Parameters
        ----------
        layer : LayerInfo
            calculate mask for `layer`'s weight
        config : dict
            the configuration for generating the mask
        """
        weight = layer.weight
        op_type = layer.type
        op_name = layer.name
        assert 0 <= config.get('sparsity') < 1
        assert op_type in ['Conv1D', 'Conv2D']
        assert op_type in config['op_types']

        if layer.name in self.epoch_pruned_layers:
            assert layer.name in self.mask_dict
            return self.mask_dict.get(layer.name)

        try:
            weight = tf.stop_gradient(tf.transpose(weight, [2, 3, 0, 1]))
            masks = np.ones(weight.shape)

            num_kernels = weight.shape[0] * weight.shape[1]
            num_prune = int(num_kernels * config.get('sparsity'))
            if num_kernels < 2 or num_prune < 1:
                return masks
            min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune)
            for idx in min_gm_idx:
                masks[tuple(idx)] = 0.
        finally:
            masks = np.transpose(masks, [2, 3, 0, 1])
            masks = tf.Variable(masks)
            self.mask_dict.update({op_name: masks})
            self.epoch_pruned_layers.add(layer.name)

        return masks

    def _get_min_gm_kernel_idx(self, weight, n):
        assert len(weight.shape) >= 3
        assert weight.shape[0] * weight.shape[1] > 2

        dist_list, idx_list = [], []
        for in_i in range(weight.shape[0]):
            for out_i in range(weight.shape[1]):
                dist_sum = self._get_distance_sum(weight, in_i, out_i)
                dist_list.append(dist_sum)
                idx_list.append([in_i, out_i])
        dist_tensor = tf.convert_to_tensor(dist_list)
        idx_tensor = tf.constant(idx_list)

        # negate the distances so that top_k selects the kernels with the
        # smallest distance sums, i.e. those closest to the geometric median
        _, idx = tf.math.top_k(-dist_tensor, k=n)
        return tf.gather(idx_tensor, idx)

    def _get_distance_sum(self, weight, in_idx, out_idx):
        w = tf.reshape(weight, (-1, weight.shape[-2], weight.shape[-1]))
        anchor_w = tf.tile(tf.expand_dims(weight[in_idx, out_idx], 0), [w.shape[0], 1, 1])
        x = w - anchor_w
        x = tf.math.reduce_sum((x * x), (-2, -1))
        x = tf.math.sqrt(x)
        return tf.math.reduce_sum(x)

    def update_epoch(self, epoch):
        self.epoch_pruned_layers = set()
import logging
import tensorflow as tf
from . import default_layers

tf.config.experimental_run_functions_eagerly(True)
_logger = logging.getLogger(__name__)

class LayerInfo:
-    def __init__(self, op, weight, weight_op):
-        self.op = op
-        self.name = op.name
-        self.type = op.type
-        self.weight = weight
-        self.weight_op = weight_op
+    def __init__(self, keras_layer):
+        self.keras_layer = keras_layer
+        self.name = keras_layer.name
+        self.type = default_layers.get_op_type(type(keras_layer))
+        self.weight_index = default_layers.get_weight_index(self.type)
+        if self.weight_index is not None:
+            self.weight = keras_layer.weights[self.weight_index]
+        self._call = None
class Compressor:
    """
...@@ -25,7 +27,7 @@ class Compressor:
    Parameters
    ----------
-    model : pytorch model
+    model : keras model
        the model user wants to compress
    config_list : list
        the configurations that users specify for compression
...@@ -34,6 +36,21 @@ class Compressor:
        self.config_list = config_list
        self.modules_to_compress = []

    def detect_modules_to_compress(self):
        """
        detect all modules that should be compressed, and save the result in `self.modules_to_compress`.
        The model will be instrumented and user should never edit it after calling this method.
        """
        if self.modules_to_compress is None:
            self.modules_to_compress = []
        for keras_layer in self.bound_model.layers:
            layer = LayerInfo(keras_layer)
            config = self.select_config(layer)
            if config is not None:
                self.modules_to_compress.append((layer, config))
        return self.modules_to_compress
    def compress(self):
        """
        Compress the model with algorithm implemented by subclass.
...@@ -41,19 +58,9 @@ class Compressor:
        The model will be instrumented and user should never edit it after calling this method.
        `self.modules_to_compress` records all the to-be-compressed layers
        """
-        for op in self.bound_model.get_operations():
-            weight_index = _detect_weight_index(op)
-            if weight_index is None:
-                _logger.warning('Failed to detect weight for layer %s', op.name)
-                return
-            weight_op = op.inputs[weight_index].op
-            weight = weight_op.inputs[0]
-            layer = LayerInfo(op, weight, weight_op)
-            config = self.select_config(layer)
-            if config is not None:
-                self._instrument_layer(layer, config)
-                self.modules_to_compress.append((layer, config))
+        modules_to_compress = self.detect_modules_to_compress()
+        for layer, config in modules_to_compress:
+            self._instrument_layer(layer, config)
        return self.bound_model
    def get_modules_to_compress(self):
...@@ -74,7 +81,7 @@ class Compressor:
        Parameters
        ----------
-        layer : LayerInfo
+        layer: LayerInfo
            one layer

        Returns
...@@ -84,11 +91,12 @@ class Compressor:
            not be compressed
        """
        ret = None
+        if layer.type is None:
+            return None
        for config in self.config_list:
-            op_types = config.get('op_types')
-            if op_types == 'default':
-                op_types = default_layers.op_weight_index.keys()
-            if op_types and layer.type not in op_types:
+            config = config.copy()
+            config['op_types'] = self._expand_config_op_types(config)
+            if layer.type not in config['op_types']:
                continue
            if config.get('op_names') and layer.name not in config['op_names']:
                continue
...@@ -97,7 +105,7 @@ class Compressor:
            return None
        return ret
-    def update_epoch(self, epoch, sess):
+    def update_epoch(self, epoch):
        """
        If user want to update model every epoch, user can override this method.
        This method should be called at the beginning of each epoch
...@@ -108,7 +116,7 @@ class Compressor:
            the current epoch number
        """

-    def step(self, sess):
+    def step(self):
        """
        If user want to update mask every step, user can override this method
        """
...@@ -127,6 +135,18 @@ class Compressor:
        """
        raise NotImplementedError()

    def _expand_config_op_types(self, config):
        if config is None:
            return []
        op_types = []
        for op_type in config.get('op_types', []):
            if op_type == 'default':
                op_types.extend(default_layers.default_layers)
            else:
                op_types.append(op_type)
        return op_types
class Pruner(Compressor):
    """
...@@ -160,10 +180,17 @@ class Pruner(Compressor):
        config : dict
            the configuration for generating the mask
        """
-        mask = self.calc_mask(layer, config)
-        new_weight = layer.weight * mask
-        tf.contrib.graph_editor.swap_outputs(layer.weight_op, new_weight.op)
+        layer._call = layer.keras_layer.call
+        def new_call(*inputs):
+            weights = [x.numpy() for x in layer.keras_layer.weights]
+            mask = self.calc_mask(layer, config)
+            weights[layer.weight_index] = weights[layer.weight_index] * mask
+            layer.keras_layer.set_weights(weights)
+            ret = layer._call(*inputs)
+            return ret
+        layer.keras_layer.call = new_call
class Quantizer(Compressor):
    """
...@@ -172,23 +199,3 @@ class Quantizer(Compressor):
    def quantize_weight(self, weight, config, op, op_type, op_name):
        raise NotImplementedError("Quantizer must overload quantize_weight()")
-
-    def _instrument_layer(self, layer, config):
-        weight_index = _detect_weight_index(layer)
-        if weight_index is None:
-            _logger.warning('Failed to detect weight for layer %s', layer.name)
-            return
-        weight_op = layer.op.inputs[weight_index].op
-        weight = weight_op.inputs[0]
-        new_weight = self.quantize_weight(weight, config, op=layer.op, op_type=layer.type, op_name=layer.name)
-        tf.contrib.graph_editor.swap_outputs(weight_op, new_weight.op)
-
-def _detect_weight_index(layer):
-    index = default_layers.op_weight_index.get(layer.type)
-    if index is not None:
-        return index
-    weight_indices = [i for i, op in enumerate(layer.inputs) if op.name.endswith('Variable/read')]
-    if len(weight_indices) == 1:
-        return weight_indices[0]
-    return None
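The new Keras `Pruner._instrument_layer` above works by monkey-patching the layer's `call`: every forward pass first multiplies the stored weights by the pruning mask, then delegates to the original forward. The pattern can be illustrated with a plain Python stand-in for a layer (no Keras required; `ToyLayer` and `instrument` are invented for illustration, not NNI's API):

```python
import numpy as np

class ToyLayer:
    """Minimal stand-in for a layer: forward pass is y = x @ weight."""
    def __init__(self, weight):
        self.weight = weight

    def call(self, x):
        return x @ self.weight

def instrument(layer, mask):
    original_call = layer.call              # keep the unpatched forward
    def new_call(x):
        layer.weight = layer.weight * mask  # apply the mask before the forward pass
        return original_call(x)
    layer.call = new_call                   # swap in the masked forward

layer = ToyLayer(np.ones((2, 2)))
mask = np.array([[1., 0.], [1., 0.]])       # prune the second output unit
instrument(layer, mask)
out = layer.call(np.array([[1., 1.]]))      # masked column contributes zero
```

Because masking happens inside the patched `call`, the pruned weights stay zero on every forward pass even if the optimizer updates them between steps, which mirrors how the commit keeps masks in effect during fine-tuning.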
-op_weight_index = {
-    'Conv2D': None,
-    'Conv3D': None,
-    'DepthwiseConv2dNative': None,
-    'Mul': None,
-    'MatMul': None,
-}
+from tensorflow import keras
+
+supported_layers = {
+    keras.layers.Conv1D: ('Conv1D', 0),
+    keras.layers.Conv2D: ('Conv2D', 0),
+    keras.layers.Conv2DTranspose: ('Conv2DTranspose', 0),
+    keras.layers.Conv3D: ('Conv3D', 0),
+    keras.layers.Conv3DTranspose: ('Conv3DTranspose', 0),
+    keras.layers.ConvLSTM2D: ('ConvLSTM2D', 0),
+    keras.layers.Dense: ('Dense', 0),
+    keras.layers.Embedding: ('Embedding', 0),
+    keras.layers.GRU: ('GRU', 0),
+    keras.layers.LSTM: ('LSTM', 0),
+}
+
+default_layers = [x[0] for x in supported_layers.values()]
+
+def get_op_type(layer_type):
+    if layer_type in supported_layers:
+        return supported_layers[layer_type][0]
+    else:
+        return None
+
+def get_weight_index(op_type):
+    for k in supported_layers:
+        if supported_layers[k][0] == op_type:
+            return supported_layers[k][1]
+    return None
...@@ -2,7 +2,7 @@ import logging
import torch
from .compressor import Pruner

-__all__ = ['LevelPruner', 'AGP_Pruner']
+__all__ = ['LevelPruner', 'AGP_Pruner', 'FPGMPruner']

logger = logging.getLogger('torch pruner')
...@@ -106,3 +106,125 @@ class AGP_Pruner(Pruner):
        self.now_epoch = epoch
        for k in self.if_init_list:
            self.if_init_list[k] = True
class FPGMPruner(Pruner):
    """
    A filter pruner via geometric median.
    "Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration",
    https://arxiv.org/pdf/1811.00250.pdf
    """

    def __init__(self, model, config_list):
        """
        Parameters
        ----------
        model : pytorch model
            the model user wants to compress
        config_list: list
            support key for each list item:
                - sparsity: percentage of convolutional filters to be pruned.
        """
        super().__init__(model, config_list)
        self.mask_dict = {}
        self.epoch_pruned_layers = set()

    def calc_mask(self, layer, config):
        """
        Supports Conv1d, Conv2d
        filter dimensions for Conv1d:
            OUT: number of output channels
            IN: number of input channels
            LEN: filter length
        filter dimensions for Conv2d:
            OUT: number of output channels
            IN: number of input channels
            H: filter height
            W: filter width

        Parameters
        ----------
        layer : LayerInfo
            calculate mask for `layer`'s weight
        config : dict
            the configuration for generating the mask
        """
        weight = layer.module.weight.data
        assert 0 <= config.get('sparsity') < 1
        assert layer.type in ['Conv1d', 'Conv2d']
        assert layer.type in config['op_types']

        if layer.name in self.epoch_pruned_layers:
            assert layer.name in self.mask_dict
            return self.mask_dict.get(layer.name)

        masks = torch.ones(weight.size()).type_as(weight)

        try:
            num_kernels = weight.size(0) * weight.size(1)
            num_prune = int(num_kernels * config.get('sparsity'))
            if num_kernels < 2 or num_prune < 1:
                return masks
            min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune)
            for idx in min_gm_idx:
                masks[idx] = 0.
        finally:
            self.mask_dict.update({layer.name: masks})
            self.epoch_pruned_layers.add(layer.name)

        return masks

    def _get_min_gm_kernel_idx(self, weight, n):
        assert len(weight.size()) in [3, 4]

        dist_list = []
        for out_i in range(weight.size(0)):
            for in_i in range(weight.size(1)):
                dist_sum = self._get_distance_sum(weight, out_i, in_i)
                dist_list.append((dist_sum, (out_i, in_i)))
        min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:n]
        return [x[1] for x in min_gm_kernels]

    def _get_distance_sum(self, weight, out_idx, in_idx):
        """
        Calculate the total distance between a specified filter (by out_idx and in_idx) and
        all other filters.
        Optimized version of the following naive implementation:

            def _get_distance_sum(self, weight, out_idx, in_idx):
                w = weight.view(-1, weight.size(-2), weight.size(-1))
                dist_sum = 0.
                for k in w:
                    dist_sum += torch.dist(k, weight[out_idx, in_idx], p=2)
                return dist_sum

        Parameters
        ----------
        weight: Tensor
            convolutional filter weight
        out_idx: int
            output channel index of the specified filter; this method calculates the total distance
            between this specified filter and all other filters.
        in_idx: int
            input channel index of the specified filter

        Returns
        -------
        float32
            The total distance
        """
        logger.debug('weight size: %s', weight.size())
        if len(weight.size()) == 4:  # Conv2d
            w = weight.view(-1, weight.size(-2), weight.size(-1))
            anchor_w = weight[out_idx, in_idx].unsqueeze(0).expand(w.size(0), w.size(1), w.size(2))
        elif len(weight.size()) == 3:  # Conv1d
            w = weight.view(-1, weight.size(-1))
            anchor_w = weight[out_idx, in_idx].unsqueeze(0).expand(w.size(0), w.size(1))
        else:
            raise RuntimeError('unsupported layer type')
        x = w - anchor_w
        x = (x * x).sum((-2, -1))
        x = torch.sqrt(x)
        return x.sum()

    def update_epoch(self, epoch):
        self.epoch_pruned_layers = set()
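The optimized `_get_distance_sum` above replaces a per-kernel loop of `torch.dist` calls with one broadcasted subtraction followed by a reduction. The equivalence of the two formulations can be checked with a small NumPy sketch (shapes follow the Conv2d case; function names are illustrative, not part of the commit):

```python
import numpy as np

def distance_sum_naive(weight, out_idx, in_idx):
    """Sum of Euclidean distances from kernel (out_idx, in_idx) to every kernel."""
    anchor = weight[out_idx, in_idx]
    total = 0.0
    for k in weight.reshape(-1, weight.shape[-2], weight.shape[-1]):
        total += np.linalg.norm(k - anchor)
    return total

def distance_sum_vectorized(weight, out_idx, in_idx):
    """Same quantity via one broadcasted subtraction, mirroring the pruner's version."""
    w = weight.reshape(-1, weight.shape[-2], weight.shape[-1])
    x = w - weight[out_idx, in_idx]              # broadcast the anchor over all kernels
    return np.sqrt((x * x).sum(axis=(-2, -1))).sum()

rng = np.random.default_rng(0)
w = rng.standard_normal((4, 3, 5, 5))            # (OUT, IN, H, W)
a = distance_sum_naive(w, 1, 2)
b = distance_sum_vectorized(w, 1, 2)
```

Both include the anchor's zero distance to itself, so the results match; the vectorized form avoids OUT×IN Python-level iterations per anchor.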
...@@ -91,6 +91,7 @@ class Compressor:
        """
        ret = None
        for config in self.config_list:
+            config = config.copy()
            config['op_types'] = self._expand_config_op_types(config)
            if layer.type not in config['op_types']:
                continue
...
...@@ -2,81 +2,26 @@ from unittest import TestCase, main
import tensorflow as tf
import torch
import torch.nn.functional as F
-import nni.compression.tensorflow as tf_compressor
import nni.compression.torch as torch_compressor

+if tf.__version__ >= '2.0':
+    import nni.compression.tensorflow as tf_compressor
+
+def get_tf_mnist_model():
+    model = tf.keras.models.Sequential([
+        tf.keras.layers.Conv2D(filters=32, kernel_size=7, input_shape=[28, 28, 1], activation='relu', padding="SAME"),
+        tf.keras.layers.MaxPooling2D(pool_size=2),
+        tf.keras.layers.Conv2D(filters=64, kernel_size=3, activation='relu', padding="SAME"),
+        tf.keras.layers.MaxPooling2D(pool_size=2),
+        tf.keras.layers.Flatten(),
+        tf.keras.layers.Dense(units=128, activation='relu'),
+        tf.keras.layers.Dropout(0.5),
+        tf.keras.layers.Dense(units=10, activation='softmax'),
+    ])
+    model.compile(loss="sparse_categorical_crossentropy",
+                  optimizer=tf.keras.optimizers.SGD(lr=1e-3),
+                  metrics=["accuracy"])
+    return model
-
-def weight_variable(shape):
-    return tf.Variable(tf.truncated_normal(shape, stddev=0.1))
-
-def bias_variable(shape):
-    return tf.Variable(tf.constant(0.1, shape=shape))
-
-def conv2d(x_input, w_matrix):
-    return tf.nn.conv2d(x_input, w_matrix, strides=[1, 1, 1, 1], padding='SAME')
-
-def max_pool(x_input, pool_size):
-    size = [1, pool_size, pool_size, 1]
-    return tf.nn.max_pool(x_input, ksize=size, strides=size, padding='SAME')
-
-class TfMnist:
-    def __init__(self):
-        images = tf.placeholder(tf.float32, [None, 784], name='input_x')
-        labels = tf.placeholder(tf.float32, [None, 10], name='input_y')
-        keep_prob = tf.placeholder(tf.float32, name='keep_prob')
-        self.images = images
-        self.labels = labels
-        self.keep_prob = keep_prob
-        self.train_step = None
-        self.accuracy = None
-        self.w1 = None
-        self.b1 = None
-        self.fcw1 = None
-        self.cross = None
-        with tf.name_scope('reshape'):
-            x_image = tf.reshape(images, [-1, 28, 28, 1])
-        with tf.name_scope('conv1'):
-            w_conv1 = weight_variable([5, 5, 1, 32])
-            self.w1 = w_conv1
-            b_conv1 = bias_variable([32])
-            self.b1 = b_conv1
-            h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1)
-        with tf.name_scope('pool1'):
-            h_pool1 = max_pool(h_conv1, 2)
-        with tf.name_scope('conv2'):
-            w_conv2 = weight_variable([5, 5, 32, 64])
-            b_conv2 = bias_variable([64])
-            h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2)
-        with tf.name_scope('pool2'):
-            h_pool2 = max_pool(h_conv2, 2)
-        with tf.name_scope('fc1'):
-            w_fc1 = weight_variable([7 * 7 * 64, 1024])
-            self.fcw1 = w_fc1
-            b_fc1 = bias_variable([1024])
-            h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
-            h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)
-        with tf.name_scope('dropout'):
-            h_fc1_drop = tf.nn.dropout(h_fc1, 0.5)
-        with tf.name_scope('fc2'):
-            w_fc2 = weight_variable([1024, 10])
-            b_fc2 = bias_variable([10])
-            y_conv = tf.matmul(h_fc1_drop, w_fc2) + b_fc2
-        with tf.name_scope('loss'):
-            cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=y_conv))
-            self.cross = cross_entropy
-        with tf.name_scope('adam_optimizer'):
-            self.train_step = tf.train.AdamOptimizer(0.0001).minimize(cross_entropy)
-        with tf.name_scope('accuracy'):
-            correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(labels, 1))
-            self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

class TorchMnist(torch.nn.Module):
    def __init__(self):
...@@ -96,22 +41,23 @@ class TorchMnist(torch.nn.Module):
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

+def tf2(func):
+    def test_tf2_func(self):
+        if tf.__version__ >= '2.0':
+            func()
+    return test_tf2_func
+
class CompressorTestCase(TestCase):
-    def test_tf_pruner(self):
-        model = TfMnist()
-        configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
-        tf_compressor.LevelPruner(tf.get_default_graph(), configure_list).compress()
-
-    def test_tf_quantizer(self):
-        model = TfMnist()
-        tf_compressor.NaiveQuantizer(tf.get_default_graph(), [{'op_types': ['default']}]).compress()
-
    def test_torch_pruner(self):
        model = TorchMnist()
        configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
        torch_compressor.LevelPruner(model, configure_list).compress()

+    def test_torch_fpgm_pruner(self):
+        model = TorchMnist()
+        configure_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
+        torch_compressor.FPGMPruner(model, configure_list).compress()
+
    def test_torch_quantizer(self):
        model = TorchMnist()
        configure_list = [{
...@@ -123,6 +69,20 @@ class CompressorTestCase(TestCase):
        }]
        torch_compressor.NaiveQuantizer(model, configure_list).compress()

+    @tf2
+    def test_tf_pruner(self):
+        configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
+        tf_compressor.LevelPruner(get_tf_mnist_model(), configure_list).compress()
+
+    @tf2
+    def test_tf_quantizer(self):
+        tf_compressor.NaiveQuantizer(get_tf_mnist_model(), [{'op_types': ['default']}]).compress()
+
+    @tf2
+    def test_tf_fpgm_pruner(self):
+        configure_list = [{'sparsity': 0.5, 'op_types': ['Conv2D']}]
+        tf_compressor.FPGMPruner(get_tf_mnist_model(), configure_list).compress()
+
if __name__ == '__main__':
    main()