test_compressor_tf.py 3.39 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
3
4
5
6
7
8
9
10
11
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import unittest

import numpy as np
import tensorflow as tf


####
#
# This file tests pruners on 3 models:
#   A classic CNN model built by subclassing `Model`;
#   The same CNN model built with `Sequential`;
#   A naive model with only one linear layer.
#
# The CNN models are used to test layer detection and instrumentation.
#
# The naive model is used to test mask calculation.
# It has a single 10x10 linear layer without bias, and applies `reduce_sum` to its output.
# To make the pruning result predictable, the linear layer has fixed initial weights:
#     [ [ 0.0, 1.0, 2.0, ..., 9.0 ], [ 0.1, 1.1, 2.1, ..., 9.1 ], ... , [ 0.9, 1.9, 2.9, ..., 9.9 ] ]
#
####


# Input fed to the naive model's 10x10 linear layer: a batch (first dimension)
# holding one sample of ten 1.0 values, so the layer output is a column sum.
tensor1x10 = tf.ones([1, 10])


@unittest.skipIf(tf.__version__[0] != '2', 'Skip TF 1.x setup')
class TfCompressorTestCase(unittest.TestCase):
    """Tests for NNI's TensorFlow pruners.

    Relies on module-level helpers defined in this file (``pruners``,
    ``CnnModel``, ``build_sequential_model``, ``build_naive_model`` and
    ``tensor1x10``).
    """

    def test_layer_detection(self):
        # Conv and dense layers should be compressed, pool and flatten should not.
        # This also tests instrumenting functionality.
        self._test_layer_detection_on_model(CnnModel())
        self._test_layer_detection_on_model(build_sequential_model())

    def _test_layer_detection_on_model(self, model):
        """Compress ``model`` with the level pruner and check which layers got wrapped."""
        pruner = pruners['level'](model)
        pruner.compress()
        layer_types = sorted(type(wrapper.layer).__name__ for wrapper in pruner.wrappers)
        # Expected: the single Conv2D and both Dense layers; no MaxPool2D / Flatten.
        self.assertEqual(layer_types, ['Conv2D', 'Dense', 'Dense'])

    def test_level_pruner(self):
        # prune 90% of the 100 kernel weights; only the 10 largest survive:
        # 9.0 + 9.1 + ... + 9.9 = 94.5
        model = build_naive_model()
        pruners['level'](model).compress()
        x = model(tensor1x10)
        # Values such as 9.1 are not exactly representable in binary floating
        # point, so an exact `==` against 94.5 is fragile; compare approximately.
        self.assertAlmostEqual(float(x.numpy()), 94.5, places=3)


try:
    # Guarded imports and model definitions. The Keras classes and NNI's
    # TensorFlow pruning package may not be importable in every environment
    # (presumably TF 1.x, for which TfCompressorTestCase is skipped), and such
    # a failure must not prevent this module from being imported.
    from tensorflow.keras import Model, Sequential
    from tensorflow.keras.layers import (Conv2D, Dense, Flatten, MaxPool2D)

    from nni.algorithms.compression.tensorflow.pruning import LevelPruner

    # Pruner factories keyed by name: each takes a model and returns a pruner
    # configured with sparsity 0.9 on the 'default' op types.
    pruners = {
        'level': (lambda model: LevelPruner(model, [{'sparsity': 0.9, 'op_types': ['default']}])),
    }

    class CnnModel(Model):
        """Small CNN built by subclassing ``Model``.

        Layer order: Conv2D -> MaxPool2D -> Flatten -> Dense(10) -> Dense(5).
        """

        def __init__(self):
            super().__init__()
            self.conv = Conv2D(filters=10, kernel_size=3, activation='relu')
            self.pool = MaxPool2D(pool_size=2)
            self.flatten = Flatten()
            self.fc1 = Dense(units=10, activation='relu')
            self.fc2 = Dense(units=5, activation='softmax')

        def call(self, x):
            x = self.conv(x)
            x = self.pool(x)
            x = self.flatten(x)
            x = self.fc1(x)
            x = self.fc2(x)
            return x

    def build_sequential_model():
        """Build the same layer stack as ``CnnModel`` using ``Sequential``."""
        return Sequential([
            Conv2D(filters=10, kernel_size=3, activation='relu'),
            MaxPool2D(pool_size=2),
            Flatten(),
            Dense(units=10, activation='relu'),
            Dense(units=5, activation='softmax'),
        ])

    class NaiveModel(Model):
        """A single bias-free ``Dense(10)`` layer whose output is summed to a scalar."""

        def __init__(self):
            super().__init__()
            self.fc = Dense(units=10, use_bias=False)

        def call(self, x):
            return tf.math.reduce_sum(self.fc(x))

except Exception:
    # Deliberately best-effort. If anything above fails, the names it defines
    # (pruners, CnnModel, build_sequential_model, NaiveModel) remain undefined.
    # Any test that reaches them then fails with NameError, instead of the
    # whole module failing to import.
    pass


def build_naive_model():
    """Return a built ``NaiveModel`` whose 10x10 kernel holds fixed weights.

    Kernel entry ``[j][i]`` equals ``i + 0.1 * j``. Row 0 is therefore
    ``0.0 .. 9.0`` and row 9 is ``0.9 .. 9.9``, which makes the pruning result
    predictable.
    """
    naive = NaiveModel()
    naive.build(tensor1x10.shape)
    column_index = np.arange(10)[np.newaxis, :]
    row_offset = np.arange(10)[:, np.newaxis] * 0.1
    kernel = column_index + row_offset  # broadcasts to a (10, 10) float64 array
    naive.set_weights([kernel])
    return naive


if __name__ == '__main__':
    # Discover and run every TestCase in this module when executed as a script.
    unittest.main(module='__main__')