Unverified Commit 55b557f1 authored by chicm-ms, committed by GitHub

Fpgm algo implementation unit test (#1746)

* unit test for fpgm pruner
parent 6a5864cd
@@ -178,17 +178,13 @@ class FPGMPruner(Pruner):
         assert len(weight.shape) >= 3
         assert weight.shape[0] * weight.shape[1] > 2
-        dist_list, idx_list = [], []
+        dist_list = []
         for in_i in range(weight.shape[0]):
             for out_i in range(weight.shape[1]):
                 dist_sum = self._get_distance_sum(weight, in_i, out_i)
-                dist_list.append(dist_sum)
-                idx_list.append([in_i, out_i])
-        dist_tensor = tf.convert_to_tensor(dist_list)
-        idx_tensor = tf.constant(idx_list)
-        _, idx = tf.math.top_k(dist_tensor, k=n)
-        return tf.gather(idx_tensor, idx)
+                dist_list.append((dist_sum, (in_i, out_i)))
+        min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:n]
+        return [x[1] for x in min_gm_kernels]
     def _get_distance_sum(self, weight, in_idx, out_idx):
         w = tf.reshape(weight, (-1, weight.shape[-2], weight.shape[-1]))
...
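For context, the hunk above replaces the TensorFlow top-k based selection (which picked the kernels with the largest distance sums) with a plain Python sort that keeps the n kernels with the smallest total distance, i.e. those closest to the geometric median. Below is a minimal standalone sketch of that selection rule; it is illustrative only, uses NumPy instead of the class's _get_distance_sum helper, and the function name min_gm_kernel_indices is hypothetical.

import numpy as np

def min_gm_kernel_indices(weight, n):
    # weight is assumed to have shape (in_channels, out_channels, k, k)
    flat = weight.reshape(weight.shape[0], weight.shape[1], -1)
    dist_list = []
    for in_i in range(weight.shape[0]):
        for out_i in range(weight.shape[1]):
            # total Euclidean distance from this kernel to every kernel
            dist_sum = np.linalg.norm(flat - flat[in_i, out_i], axis=-1).sum()
            dist_list.append((dist_sum, (in_i, out_i)))
    # keep the n kernels closest to all others, i.e. nearest the geometric median
    min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:n]
    return [x[1] for x in min_gm_kernels]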
 from unittest import TestCase, main
+import numpy as np
 import tensorflow as tf
 import torch
 import torch.nn.functional as F
@@ -7,11 +8,11 @@ import nni.compression.torch as torch_compressor
 if tf.__version__ >= '2.0':
     import nni.compression.tensorflow as tf_compressor
-def get_tf_mnist_model():
+def get_tf_model():
     model = tf.keras.models.Sequential([
-        tf.keras.layers.Conv2D(filters=32, kernel_size=7, input_shape=[28, 28, 1], activation='relu', padding="SAME"),
+        tf.keras.layers.Conv2D(filters=5, kernel_size=7, input_shape=[28, 28, 1], activation='relu', padding="SAME"),
         tf.keras.layers.MaxPooling2D(pool_size=2),
-        tf.keras.layers.Conv2D(filters=64, kernel_size=3, activation='relu', padding="SAME"),
+        tf.keras.layers.Conv2D(filters=10, kernel_size=3, activation='relu', padding="SAME"),
         tf.keras.layers.MaxPooling2D(pool_size=2),
         tf.keras.layers.Flatten(),
         tf.keras.layers.Dense(units=128, activation='relu'),
@@ -23,43 +24,51 @@ def get_tf_mnist_model():
                   metrics=["accuracy"])
     return model
-class TorchMnist(torch.nn.Module):
+class TorchModel(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
-        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
-        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
-        self.fc2 = torch.nn.Linear(500, 10)
+        self.conv1 = torch.nn.Conv2d(1, 5, 5, 1)
+        self.conv2 = torch.nn.Conv2d(5, 10, 5, 1)
+        self.fc1 = torch.nn.Linear(4 * 4 * 10, 100)
+        self.fc2 = torch.nn.Linear(100, 10)
     def forward(self, x):
         x = F.relu(self.conv1(x))
         x = F.max_pool2d(x, 2, 2)
         x = F.relu(self.conv2(x))
         x = F.max_pool2d(x, 2, 2)
-        x = x.view(-1, 4 * 4 * 50)
+        x = x.view(-1, 4 * 4 * 10)
         x = F.relu(self.fc1(x))
         x = self.fc2(x)
         return F.log_softmax(x, dim=1)
 def tf2(func):
-    def test_tf2_func(self):
+    def test_tf2_func(*args):
         if tf.__version__ >= '2.0':
-            func()
+            func(*args)
     return test_tf2_func
+k1 = [[1]*3]*3
+k2 = [[2]*3]*3
+k3 = [[3]*3]*3
+k4 = [[4]*3]*3
+k5 = [[5]*3]*3
+w = [[k1, k2, k3, k4, k5]] * 10
 class CompressorTestCase(TestCase):
-    def test_torch_pruner(self):
-        model = TorchMnist()
+    def test_torch_level_pruner(self):
+        model = TorchModel()
         configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
         torch_compressor.LevelPruner(model, configure_list).compress()
-    def test_torch_fpgm_pruner(self):
-        model = TorchMnist()
-        configure_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
-        torch_compressor.FPGMPruner(model, configure_list).compress()
+    @tf2
+    def test_tf_level_pruner(self):
+        configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
+        tf_compressor.LevelPruner(get_tf_model(), configure_list).compress()
-    def test_torch_quantizer(self):
-        model = TorchMnist()
+    def test_torch_naive_quantizer(self):
+        model = TorchModel()
         configure_list = [{
             'quant_types': ['weight'],
             'quant_bits': {
@@ -70,18 +79,59 @@ class CompressorTestCase(TestCase):
         torch_compressor.NaiveQuantizer(model, configure_list).compress()
     @tf2
-    def test_tf_pruner(self):
-        configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
-        tf_compressor.LevelPruner(get_tf_mnist_model(), configure_list).compress()
+    def test_tf_naive_quantizer(self):
+        tf_compressor.NaiveQuantizer(get_tf_model(), [{'op_types': ['default']}]).compress()
-    @tf2
-    def test_tf_quantizer(self):
-        tf_compressor.NaiveQuantizer(get_tf_mnist_model(), [{'op_types': ['default']}]).compress()
+    def test_torch_fpgm_pruner(self):
+        """
+        With filters (kernels) defined as above (k1 - k5), it is obvious that k3 is the Geometric Median,
+        which minimizes the total geometric distance by the definition of Geometric Median in this paper:
+        Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration,
+        https://arxiv.org/pdf/1811.00250.pdf
+        So if sparsity is 0.2, the expected masks should mask out all of k3; this can be verified through:
+        `all(torch.sum(masks, (0, 2, 3)).numpy() == np.array([90., 90., 0., 90., 90.]))`
+        If sparsity is 0.6, the expected masks should mask out all of k2, k3 and k4; this can be verified through:
+        `all(torch.sum(masks, (0, 2, 3)).numpy() == np.array([90., 0., 0., 0., 90.]))`
+        """
+        model = TorchModel()
+        config_list = [{'sparsity': 0.2, 'op_types': ['Conv2d']}, {'sparsity': 0.6, 'op_types': ['Conv2d']}]
+        pruner = torch_compressor.FPGMPruner(model, config_list)
+        model.conv2.weight.data = torch.tensor(w).float()
+        layer = torch_compressor.compressor.LayerInfo('conv2', model.conv2)
+        masks = pruner.calc_mask(layer, config_list[0])
+        assert all(torch.sum(masks, (0, 2, 3)).numpy() == np.array([90., 90., 0., 90., 90.]))
+        pruner.update_epoch(1)
+        model.conv2.weight.data = torch.tensor(w).float()
+        masks = pruner.calc_mask(layer, config_list[1])
+        assert all(torch.sum(masks, (0, 2, 3)).numpy() == np.array([90., 0., 0., 0., 90.]))
     @tf2
     def test_tf_fpgm_pruner(self):
-        configure_list = [{'sparsity': 0.5, 'op_types': ['Conv2D']}]
-        tf_compressor.FPGMPruner(get_tf_mnist_model(), configure_list).compress()
+        model = get_tf_model()
+        config_list = [{'sparsity': 0.2, 'op_types': ['Conv2D']}, {'sparsity': 0.6, 'op_types': ['Conv2D']}]
+        pruner = tf_compressor.FPGMPruner(model, config_list)
+        weights = model.layers[2].weights
+        weights[0] = np.array(w).astype(np.float32).transpose([2, 3, 0, 1]).transpose([0, 1, 3, 2])
+        model.layers[2].set_weights([weights[0], weights[1].numpy()])
+        layer = tf_compressor.compressor.LayerInfo(model.layers[2])
+        masks = pruner.calc_mask(layer, config_list[0]).numpy()
+        masks = masks.transpose([2, 3, 0, 1]).transpose([1, 0, 2, 3])
+        assert all(masks.sum((0, 2, 3)) == np.array([90., 90., 0., 90., 90.]))
+        pruner.update_epoch(1)
+        model.layers[2].set_weights([weights[0], weights[1].numpy()])
+        masks = pruner.calc_mask(layer, config_list[1]).numpy()
+        masks = masks.transpose([2, 3, 0, 1]).transpose([1, 0, 2, 3])
+        assert all(masks.sum((0, 2, 3)) == np.array([90., 0., 0., 0., 90.]))
 if __name__ == '__main__':
...
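As a quick, purely illustrative check of the claim in the new test's docstring (not part of the change itself): for the constant 3x3 kernels k1-k5 defined above, k3 has the smallest total Euclidean distance to the other kernels, so it is the geometric-median kernel and the first one masked at sparsity 0.2.

import numpy as np

# the same constant 3x3 kernels used in the test (values 1..5)
kernels = np.array([[[v] * 3] * 3 for v in (1, 2, 3, 4, 5)], dtype=float)

# total Euclidean distance from each kernel to all of the others
totals = [np.linalg.norm(kernels - k, axis=(1, 2)).sum() for k in kernels]
# totals == [30.0, 21.0, 18.0, 21.0, 30.0]; index 2 (k3) is the minimum
assert int(np.argmin(totals)) == 2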