# test_compressor.py
import functools
from unittest import TestCase, main

import numpy as np
import tensorflow as tf
import torch
import torch.nn.functional as F

import nni.compression.torch as torch_compressor

# The TensorFlow compressor is only imported under TensorFlow 2.0 or newer.
# Compare the numeric major version: comparing the version strings directly
# is lexicographic and would rank e.g. '10.0' below '2.0'.
if int(tf.__version__.split('.')[0]) >= 2:
    import nni.compression.tensorflow as tf_compressor

def get_tf_model():
    """Build and compile the small Keras CNN used by the TensorFlow tests.

    The network takes 28x28 single-channel inputs and stacks two
    Conv2D + MaxPooling2D stages (5 filters with a 7x7 kernel, then 10
    filters with a 3x3 kernel), followed by a 128-unit dense layer with
    dropout and a 10-unit softmax output.

    Returns:
        The compiled ``tf.keras.models.Sequential`` model, using SGD with a
        learning rate of 1e-3, sparse categorical cross-entropy loss and
        accuracy as the metric.
    """
    # NOTE: the layer order matters to callers; test_tf_fpgm_pruner reads the
    # second convolution through ``model.layers[2]``.
    model = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(filters=5, kernel_size=7, input_shape=[28, 28, 1], activation='relu', padding="SAME"),
        tf.keras.layers.MaxPooling2D(pool_size=2),
        tf.keras.layers.Conv2D(filters=10, kernel_size=3, activation='relu', padding="SAME"),
        tf.keras.layers.MaxPooling2D(pool_size=2),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(units=128, activation='relu'),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(units=10, activation='softmax'),
    ])
    model.compile(
        loss="sparse_categorical_crossentropy",
        # ``learning_rate`` is the supported keyword; ``lr`` is a deprecated alias.
        optimizer=tf.keras.optimizers.SGD(learning_rate=1e-3),
        metrics=["accuracy"],
    )
    return model
27

Tang Lang's avatar
Tang Lang committed
28

29
class TorchModel(torch.nn.Module):
    """Small convolutional network used as the PyTorch model under test.

    Two ``Conv2d`` + ``BatchNorm2d`` stages (1 -> 5 -> 10 channels, 5x5
    kernels, stride 1), each followed by ReLU and 2x2 max pooling, feed two
    fully connected layers (160 -> 100 -> 10).

    The hard-coded flatten size ``4 * 4 * 10`` corresponds to 28x28
    single-channel inputs: 28 -> 24 -> 12 -> 8 -> 4 along each spatial axis.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 5, 5, 1)
        self.bn1 = torch.nn.BatchNorm2d(5)
        self.conv2 = torch.nn.Conv2d(5, 10, 5, 1)
        self.bn2 = torch.nn.BatchNorm2d(10)
        self.fc1 = torch.nn.Linear(4 * 4 * 10, 100)
        self.fc2 = torch.nn.Linear(100, 10)

    def forward(self, x):
        """Return per-class log-probabilities for the input batch ``x``.

        Args:
            x: Input tensor; the flatten step below expects the feature map
                to contain ``4 * 4 * 10`` values per sample.

        Returns:
            A tensor with 10 columns holding ``log_softmax`` over dim 1.
        """
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        # Flatten each sample's feature map before the dense layers.
        x = x.view(-1, 4 * 4 * 10)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

def tf2(func):
    """Decorator that runs ``func`` only under TensorFlow 2.0 or newer.

    On older TensorFlow versions the wrapper does nothing and returns
    ``None``, so a decorated test passes without executing its body.

    Args:
        func: The callable (typically a test method) to guard.

    Returns:
        A wrapper that carries ``func``'s name and docstring and forwards all
        positional and keyword arguments to ``func`` when the installed
        TensorFlow major version is at least 2.
    """
    @functools.wraps(func)
    def test_tf2_func(*args, **kwargs):
        # Compare the numeric major version: comparing version strings is
        # lexicographic and would rank e.g. '10.0' below '2.0'.
        if int(tf.__version__.split('.')[0]) >= 2:
            return func(*args, **kwargs)
        return None

    return test_tf2_func
56

chicm-ms's avatar
chicm-ms committed
57
58
# Fixture weights for the FPGM filter pruner tests: ten filters of shape
# (5, 3, 3), where the filter at index k is filled with the constant k + 1.
w = np.array([[[[fill] * 3] * 3] * 5 for fill in range(1, 11)])


class CompressorTestCase(TestCase):
    """Tests for the NNI pruners and quantizers on PyTorch and TensorFlow.

    The level-pruner and naive-quantizer tests only check that ``compress()``
    runs. The FPGM, L1-filter and Slim pruner tests install hand-crafted
    weights and compare the masks returned by ``calc_mask`` with the expected
    pattern. TensorFlow tests are wrapped in ``tf2`` so they only execute
    under TensorFlow 2.
    """

    def test_torch_level_pruner(self):
        """LevelPruner compresses the PyTorch model without raising."""
        model = TorchModel()
        configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
        torch_compressor.LevelPruner(model, configure_list).compress()

    @tf2
    def test_tf_level_pruner(self):
        """LevelPruner compresses the Keras model without raising."""
        configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
        tf_compressor.LevelPruner(get_tf_model(), configure_list).compress()

    def test_torch_naive_quantizer(self):
        """NaiveQuantizer compresses the PyTorch model without raising."""
        model = TorchModel()
        configure_list = [{
            'quant_types': ['weight'],
            'quant_bits': {
                'weight': 8,
            },
            'op_types': ['Conv2d', 'Linear']
        }]
        torch_compressor.NaiveQuantizer(model, configure_list).compress()

    @tf2
    def test_tf_naive_quantizer(self):
        """NaiveQuantizer compresses the Keras model without raising."""
        tf_compressor.NaiveQuantizer(get_tf_model(), [{'op_types': ['default']}]).compress()

    def test_torch_fpgm_pruner(self):
        """FPGMPruner masks the filters closest to the geometric median.

        With the filter weights defined by the module-level ``w``, w[4] and
        w[5] are the geometric median, i.e. they minimize the total geometric
        distance as defined in:
        Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration,
        https://arxiv.org/pdf/1811.00250.pdf

        So with sparsity 0.2 the mask should zero out w[4] and w[5], giving
        per-filter mask sums of [45, 45, 45, 45, 0, 0, 45, 45, 45, 45].

        With sparsity 0.6 the mask should zero out w[2] - w[7], giving
        per-filter mask sums of [45, 45, 0, 0, 0, 0, 0, 0, 45, 45].
        """
        model = TorchModel()
        config_list = [{'sparsity': 0.2, 'op_types': ['Conv2d']}, {'sparsity': 0.6, 'op_types': ['Conv2d']}]
        pruner = torch_compressor.FPGMPruner(model, config_list)

        model.conv2.weight.data = torch.tensor(w).float()
        layer = torch_compressor.compressor.LayerInfo('conv2', model.conv2)
        masks = pruner.calc_mask(layer, config_list[0])
        # Each filter holds 5 * 3 * 3 = 45 weights, so a kept filter sums to 45.
        np.testing.assert_array_equal(
            torch.sum(masks, (1, 2, 3)).numpy(),
            np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))

        pruner.update_epoch(1)
        model.conv2.weight.data = torch.tensor(w).float()
        masks = pruner.calc_mask(layer, config_list[1])
        np.testing.assert_array_equal(
            torch.sum(masks, (1, 2, 3)).numpy(),
            np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.]))

    @tf2
    def test_tf_fpgm_pruner(self):
        """TensorFlow counterpart of ``test_torch_fpgm_pruner``.

        The same fixture weights are loaded into the second convolution of
        the Keras model and the per-filter mask sums are compared with the
        same expected patterns for sparsity 0.2 and 0.6.
        """
        model = get_tf_model()
        config_list = [{'sparsity': 0.2, 'op_types': ['Conv2D']}, {'sparsity': 0.6, 'op_types': ['Conv2D']}]

        pruner = tf_compressor.FPGMPruner(model, config_list)
        weights = model.layers[2].weights
        # Reorder the (10, 5, 3, 3) fixture into (3, 3, 5, 10) so that the
        # ten filters lie on the last axis of the kernel.
        weights[0] = np.array(w).astype(np.float32).transpose([2, 3, 1, 0])
        model.layers[2].set_weights([weights[0], weights[1].numpy()])

        layer = tf_compressor.compressor.LayerInfo(model.layers[2])
        masks = pruner.calc_mask(layer, config_list[0]).numpy()
        # Collapse everything but the filter axis, then put filters first.
        masks = masks.reshape((-1, masks.shape[-1])).transpose([1, 0])
        np.testing.assert_array_equal(
            masks.sum(axis=1),
            np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))

        pruner.update_epoch(1)
        model.layers[2].set_weights([weights[0], weights[1].numpy()])
        masks = pruner.calc_mask(layer, config_list[1]).numpy()
        masks = masks.reshape((-1, masks.shape[-1])).transpose([1, 0])
        np.testing.assert_array_equal(
            masks.sum(axis=1),
            np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.]))

    def test_torch_l1filter_pruner(self):
        """L1FilterPruner masks the filters with the smallest L1 norm.

        Filters with the minimum sum of the weights' L1 norm are pruned in this paper:
        PRUNING FILTERS FOR EFFICIENT CONVNETS,
        https://arxiv.org/abs/1608.08710

        So with sparsity 0.2 the mask should zero out filter 0, giving
        per-filter mask sums of [0, 27, 27, 27, 27].

        With sparsity 0.6 the mask should zero out filters 0, 1 and 2, giving
        per-filter mask sums of [0, 0, 0, 27, 27].
        """
        # Five 3x3x3 filters filled with 0, 1, 2, 3 and 4 respectively
        # (this local fixture shadows the module-level ``w``).
        w = np.array([np.zeros((3, 3, 3)), np.ones((3, 3, 3)), np.ones((3, 3, 3)) * 2,
                      np.ones((3, 3, 3)) * 3, np.ones((3, 3, 3)) * 4])
        model = TorchModel()
        config_list = [{'sparsity': 0.2, 'op_names': ['conv1']}, {'sparsity': 0.6, 'op_names': ['conv2']}]
        pruner = torch_compressor.L1FilterPruner(model, config_list)

        model.conv1.weight.data = torch.tensor(w).float()
        model.conv2.weight.data = torch.tensor(w).float()
        layer1 = torch_compressor.compressor.LayerInfo('conv1', model.conv1)
        mask1 = pruner.calc_mask(layer1, config_list[0])
        layer2 = torch_compressor.compressor.LayerInfo('conv2', model.conv2)
        mask2 = pruner.calc_mask(layer2, config_list[1])
        # Each filter holds 3 * 3 * 3 = 27 weights, so a kept filter sums to 27.
        np.testing.assert_array_equal(
            torch.sum(mask1, (1, 2, 3)).numpy(), np.array([0., 27., 27., 27., 27.]))
        np.testing.assert_array_equal(
            torch.sum(mask2, (1, 2, 3)).numpy(), np.array([0., 0., 0., 27., 27.]))

    def test_torch_slim_pruner(self):
        """SlimPruner masks the BN channels with the smallest scale factors.

        Scale factors with minimum l1 norm in the BN layers are pruned in this paper:
        Learning Efficient Convolutional Networks through Network Slimming,
        https://arxiv.org/pdf/1708.06519.pdf

        So with sparsity 0.2 both masks should zero out channel 0, i.e.
        equal [0, 1, 1, 1, 1].

        With sparsity 0.6 both masks should zero out channels 0, 1 and 2,
        i.e. equal [0, 0, 0, 1, 1].
        """
        w = np.array([0, 1, 2, 3, 4])
        model = TorchModel()
        config_list = [{'sparsity': 0.2, 'op_types': ['BatchNorm2d']}]
        model.bn1.weight.data = torch.tensor(w).float()
        # bn2 gets the negated factors; the expected mask is the same as for bn1.
        model.bn2.weight.data = torch.tensor(-w).float()
        pruner = torch_compressor.SlimPruner(model, config_list)

        layer1 = torch_compressor.compressor.LayerInfo('bn1', model.bn1)
        mask1 = pruner.calc_mask(layer1, config_list[0])
        layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2)
        mask2 = pruner.calc_mask(layer2, config_list[0])
        np.testing.assert_array_equal(mask1.numpy(), np.array([0., 1., 1., 1., 1.]))
        np.testing.assert_array_equal(mask2.numpy(), np.array([0., 1., 1., 1., 1.]))

        config_list = [{'sparsity': 0.6, 'op_types': ['BatchNorm2d']}]
        model.bn1.weight.data = torch.tensor(w).float()
        model.bn2.weight.data = torch.tensor(w).float()
        pruner = torch_compressor.SlimPruner(model, config_list)

        layer1 = torch_compressor.compressor.LayerInfo('bn1', model.bn1)
        mask1 = pruner.calc_mask(layer1, config_list[0])
        layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2)
        mask2 = pruner.calc_mask(layer2, config_list[0])
        np.testing.assert_array_equal(mask1.numpy(), np.array([0., 0., 0., 1., 1.]))
        np.testing.assert_array_equal(mask2.numpy(), np.array([0., 0., 0., 1., 1.]))

# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
    main()