Commit 503a3579 authored by Tang Lang, committed by chicm-ms

add pruner unit test (#1771)

* add pruner unit test
* make pruners compatible with torch 0.4.1
parent 8ac61b77
@@ -34,6 +34,6 @@ We implemented one of the experiments in ['Learning Efficient Convolutional Netw
 | Model         | Error(paper/ours) | Parameters | Pruned |
 | ------------- | ----------------- | ---------- | ------ |
 | VGGNet        | 6.34/6.40         | 20.04M     |        |
-| Pruned-VGGNet | 6.20/6.39         | 2.03M      | 88.5%  |
+| Pruned-VGGNet | 6.20/6.26         | 2.03M      | 88.5%  |
 The experiments code can be found at [examples/model_compress](https://github.com/microsoft/nni/tree/master/examples/model_compress/)
@@ -169,7 +169,7 @@ def main():
     new_model.to(device)
     new_model.load_state_dict(torch.load('pruned_vgg19_cifar10.pth'))
     test(new_model, device, test_loader)
-    # top1 = 93.61%
+    # top1 = 93.74%
 if __name__ == '__main__':
...
@@ -47,7 +47,7 @@ class LevelPruner(Pruner):
             k = int(weight.numel() * config['sparsity'])
             if k == 0:
                 return torch.ones(weight.shape).type_as(weight)
-            threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
+            threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
             mask = torch.gt(w_abs, threshold).type_as(weight)
             self.mask_dict.update({op_name: mask})
             self.if_init_list.update({op_name: False})
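The switch from `.values` to `[0]` is the torch 0.4.1 compatibility fix from the commit message: older PyTorch releases return a plain `(values, indices)` tuple from `torch.topk` without the named `.values` field, while indexing with `[0]` works on both the tuple and the namedtuple returned by newer releases. A minimal standalone sketch of the thresholding idea (tensor values invented for illustration):

```python
import torch

# torch.topk returns (values, indices); indexing with [0] selects the values
# on both torch 0.4.1 (plain tuple) and newer releases (namedtuple with .values).
w_abs = torch.tensor([0.5, 0.1, 0.9, 0.3]).abs()
k = 2  # number of smallest-magnitude weights to prune

# threshold = largest magnitude among the k smallest entries
threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
mask = torch.gt(w_abs, threshold).float()
print(mask)  # tensor([1., 0., 1., 0.]) -- the two smallest weights are masked out
```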
@@ -108,7 +108,7 @@ class AGP_Pruner(Pruner):
                 return mask
             # if we want to generate a new mask, we should update the weight first
             w_abs = weight.abs() * mask
-            threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
+            threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
             new_mask = torch.gt(w_abs, threshold).type_as(weight)
             self.mask_dict.update({op_name: new_mask})
             self.if_init_list.update({op_name: False})
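Note that in the AGP path the old mask is folded into the magnitude scores before thresholding, so weights pruned in an earlier step score zero, rank smallest, and stay pruned as the target sparsity grows. A small standalone illustration (tensor values invented):

```python
import torch

weight = torch.tensor([0.4, 0.6, 0.2, 0.8])
old_mask = torch.tensor([1., 1., 0., 1.])  # 0.2 was pruned in an earlier AGP step

k = 2  # the current target sparsity removes two weights
w_abs = weight.abs() * old_mask  # the already-pruned weight scores 0 and ranks smallest
threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
new_mask = torch.gt(w_abs, threshold).float()
print(new_mask)  # tensor([0., 1., 0., 1.]) -- old pruning is preserved, 0.4 joins it
```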
@@ -336,7 +336,7 @@ class L1FilterPruner(Pruner):
             if k == 0:
                 return torch.ones(weight.shape).type_as(weight)
             w_abs_structured = w_abs.view(filters, -1).sum(dim=1)
-            threshold = torch.topk(w_abs_structured.view(-1), k, largest=False).values.max()
+            threshold = torch.topk(w_abs_structured.view(-1), k, largest=False)[0].max()
             mask = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
         finally:
             self.mask_dict.update({layer.name: mask})
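In the structured case the same threshold trick operates on one L1 norm per filter rather than per weight, and the per-filter decision is broadcast back to the full weight shape. A hypothetical standalone sketch (filter count and shapes invented for illustration):

```python
import torch

# 4 output filters, each 2x3x3; filter i is filled with the value i
weight = torch.stack([torch.full((2, 3, 3), float(i)) for i in range(4)])
filters = weight.shape[0]
k = 2  # prune the 2 filters with the smallest L1 norm

w_abs_structured = weight.abs().view(filters, -1).sum(dim=1)  # one L1 norm per filter
threshold = torch.topk(w_abs_structured, k, largest=False)[0].max()
# broadcast the per-filter decision back to the full weight shape
mask = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).float()
print(mask.sum((1, 2, 3)))  # tensor([ 0.,  0., 18., 18.]) -- filters 0 and 1 pruned
```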
@@ -370,10 +370,10 @@ class SlimPruner(Pruner):
         config = config_list[0]
         for (layer, config) in self.detect_modules_to_compress():
             assert layer.type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
-            weight_list.append(layer.module.weight.data.clone())
+            weight_list.append(layer.module.weight.data.abs().clone())
         all_bn_weights = torch.cat(weight_list)
         k = int(all_bn_weights.shape[0] * config['sparsity'])
-        self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False).values.max()
+        self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()
     def calc_mask(self, layer, config):
         """
...
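The added `.abs()` fixes the ranking for negative scale factors: BN scale weights can be negative, and Network Slimming ranks channels by the magnitude of the factor, not its signed value. Without `.abs()`, a channel with a large negative scale would rank smallest and be pruned first. The new `test_torch_slim_pruner` below covers exactly this case by assigning `-w` to `bn2`. A minimal sketch of the difference (scale values invented):

```python
import torch

bn_scales = torch.tensor([0.9, -0.8, 0.1, 0.05, -0.02])
k = 2  # prune the 2 channels with the smallest scale magnitude

# buggy: signed ranking treats the large-magnitude -0.8 as one of the smallest values
bad = torch.topk(bn_scales, k, largest=False)[0].max()
# fixed: rank by magnitude, as in the Network Slimming paper
good = torch.topk(bn_scales.abs(), k, largest=False)[0].max()

print(torch.gt(bn_scales, bad))         # keeps the small 0.1 and 0.05 channels, prunes -0.8
print(torch.gt(bn_scales.abs(), good))  # prunes 0.05 and -0.02 instead
```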
@@ -8,6 +8,7 @@ import nni.compression.torch as torch_compressor
 if tf.__version__ >= '2.0':
     import nni.compression.tensorflow as tf_compressor
 def get_tf_model():
     model = tf.keras.models.Sequential([
         tf.keras.layers.Conv2D(filters=5, kernel_size=7, input_shape=[28, 28, 1], activation='relu', padding="SAME"),
@@ -24,38 +25,45 @@ def get_tf_model():
                   metrics=["accuracy"])
     return model
 class TorchModel(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.conv1 = torch.nn.Conv2d(1, 5, 5, 1)
+        self.bn1 = torch.nn.BatchNorm2d(5)
         self.conv2 = torch.nn.Conv2d(5, 10, 5, 1)
+        self.bn2 = torch.nn.BatchNorm2d(10)
         self.fc1 = torch.nn.Linear(4 * 4 * 10, 100)
         self.fc2 = torch.nn.Linear(100, 10)
     def forward(self, x):
-        x = F.relu(self.conv1(x))
+        x = F.relu(self.bn1(self.conv1(x)))
         x = F.max_pool2d(x, 2, 2)
-        x = F.relu(self.conv2(x))
+        x = F.relu(self.bn2(self.conv2(x)))
         x = F.max_pool2d(x, 2, 2)
         x = x.view(-1, 4 * 4 * 10)
         x = F.relu(self.fc1(x))
         x = self.fc2(x)
         return F.log_softmax(x, dim=1)
 def tf2(func):
     def test_tf2_func(*args):
         if tf.__version__ >= '2.0':
             func(*args)
     return test_tf2_func
-k1 = [[1]*3]*3
-k2 = [[2]*3]*3
-k3 = [[3]*3]*3
-k4 = [[4]*3]*3
-k5 = [[5]*3]*3
+k1 = [[1] * 3] * 3
+k2 = [[2] * 3] * 3
+k3 = [[3] * 3] * 3
+k4 = [[4] * 3] * 3
+k5 = [[5] * 3] * 3
 w = [[k1, k2, k3, k4, k5]] * 10
 class CompressorTestCase(TestCase):
     def test_torch_level_pruner(self):
         model = TorchModel()
@@ -74,7 +82,7 @@ class CompressorTestCase(TestCase):
             'quant_bits': {
                 'weight': 8,
             },
-            'op_types':['Conv2d', 'Linear']
+            'op_types': ['Conv2d', 'Linear']
         }]
         torch_compressor.NaiveQuantizer(model, configure_list).compress()
@@ -133,6 +141,73 @@ class CompressorTestCase(TestCase):
         assert all(masks.sum((0, 2, 3)) == np.array([90., 0., 0., 0., 90.]))
+    def test_torch_l1filter_pruner(self):
+        """
+        Filters with the minimum sum of the weights' L1 norm are pruned in this paper:
+        PRUNING FILTERS FOR EFFICIENT CONVNETS,
+        https://arxiv.org/abs/1608.08710
+        So if sparsity is 0.2, the expected mask should mask out filter 0, which can be verified through:
+        `all(torch.sum(mask1, (1, 2, 3)).numpy() == np.array([0., 27., 27., 27., 27.]))`
+        If sparsity is 0.6, the expected mask should mask out filters 0, 1 and 2, which can be verified through:
+        `all(torch.sum(mask2, (1, 2, 3)).numpy() == np.array([0., 0., 0., 27., 27.]))`
+        """
+        w = np.array([np.zeros((3, 3, 3)), np.ones((3, 3, 3)), np.ones((3, 3, 3)) * 2,
+                      np.ones((3, 3, 3)) * 3, np.ones((3, 3, 3)) * 4])
+        model = TorchModel()
+        config_list = [{'sparsity': 0.2, 'op_names': ['conv1']}, {'sparsity': 0.6, 'op_names': ['conv2']}]
+        pruner = torch_compressor.L1FilterPruner(model, config_list)
+        model.conv1.weight.data = torch.tensor(w).float()
+        model.conv2.weight.data = torch.tensor(w).float()
+        layer1 = torch_compressor.compressor.LayerInfo('conv1', model.conv1)
+        mask1 = pruner.calc_mask(layer1, config_list[0])
+        layer2 = torch_compressor.compressor.LayerInfo('conv2', model.conv2)
+        mask2 = pruner.calc_mask(layer2, config_list[1])
+        assert all(torch.sum(mask1, (1, 2, 3)).numpy() == np.array([0., 27., 27., 27., 27.]))
+        assert all(torch.sum(mask2, (1, 2, 3)).numpy() == np.array([0., 0., 0., 27., 27.]))
+    def test_torch_slim_pruner(self):
+        """
+        Scale factors with the minimum L1 norm in the BN layers are pruned in this paper:
+        Learning Efficient Convolutional Networks through Network Slimming,
+        https://arxiv.org/pdf/1708.06519.pdf
+        So if sparsity is 0.2, the expected masks should mask out channel 0, which can be verified through:
+        `all(mask1.numpy() == np.array([0., 1., 1., 1., 1.]))`
+        `all(mask2.numpy() == np.array([0., 1., 1., 1., 1.]))`
+        If sparsity is 0.6, the expected masks should mask out channels 0, 1 and 2, which can be verified through:
+        `all(mask1.numpy() == np.array([0., 0., 0., 1., 1.]))`
+        `all(mask2.numpy() == np.array([0., 0., 0., 1., 1.]))`
+        """
+        w = np.array([0, 1, 2, 3, 4])
+        model = TorchModel()
+        config_list = [{'sparsity': 0.2, 'op_types': ['BatchNorm2d']}]
+        model.bn1.weight.data = torch.tensor(w).float()
+        model.bn2.weight.data = torch.tensor(-w).float()
+        pruner = torch_compressor.SlimPruner(model, config_list)
+        layer1 = torch_compressor.compressor.LayerInfo('bn1', model.bn1)
+        mask1 = pruner.calc_mask(layer1, config_list[0])
+        layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2)
+        mask2 = pruner.calc_mask(layer2, config_list[0])
+        assert all(mask1.numpy() == np.array([0., 1., 1., 1., 1.]))
+        assert all(mask2.numpy() == np.array([0., 1., 1., 1., 1.]))
+        config_list = [{'sparsity': 0.6, 'op_types': ['BatchNorm2d']}]
+        model.bn1.weight.data = torch.tensor(w).float()
+        model.bn2.weight.data = torch.tensor(w).float()
+        pruner = torch_compressor.SlimPruner(model, config_list)
+        layer1 = torch_compressor.compressor.LayerInfo('bn1', model.bn1)
+        mask1 = pruner.calc_mask(layer1, config_list[0])
+        layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2)
+        mask2 = pruner.calc_mask(layer2, config_list[0])
+        assert all(mask1.numpy() == np.array([0., 0., 0., 1., 1.]))
+        assert all(mask2.numpy() == np.array([0., 0., 0., 1., 1.]))
 if __name__ == '__main__':
     main()