Unverified Commit 5eb95c2d authored by chicm-ms, committed by GitHub

Fix pruners (#2153)

parent e9f54647
...
@@ -124,9 +124,9 @@ class L1FilterPruner(WeightRankFilterPruner):
         w_abs_structured = w_abs.view(filters, -1).sum(dim=1)
         threshold = torch.topk(w_abs_structured.view(-1), num_prune, largest=False)[0].max()
         mask_weight = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
-        mask_bias = torch.gt(w_abs_structured, threshold).type_as(weight)
-        return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias.detach()}
+        mask_bias = torch.gt(w_abs_structured, threshold).type_as(weight).detach() if base_mask['bias_mask'] is not None else None
+        return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias}

 class L2FilterPruner(WeightRankFilterPruner):
...
@@ -172,9 +172,9 @@ class L2FilterPruner(WeightRankFilterPruner):
         w_l2_norm = torch.sqrt((w ** 2).sum(dim=1))
         threshold = torch.topk(w_l2_norm.view(-1), num_prune, largest=False)[0].max()
         mask_weight = torch.gt(w_l2_norm, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
-        mask_bias = torch.gt(w_l2_norm, threshold).type_as(weight)
-        return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias.detach()}
+        mask_bias = torch.gt(w_l2_norm, threshold).type_as(weight).detach() if base_mask['bias_mask'] is not None else None
+        return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias}

 class FPGMPruner(WeightRankFilterPruner):
...
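
The change guards the bias mask. The old code always built a bias mask and called .detach() on it, even for layers created with bias=False; the new code keys off base_mask['bias_mask'] and returns None when the layer has no bias, so no mask is ever applied to a nonexistent bias (the no-bias path is covered by the new test_pruners_no_bias test below). A minimal sketch of the guarded pattern (the helper is hypothetical; the real change lives in the pruners' mask-calculation methods shown above):

    import torch

    def build_masks(metric, threshold, weight, base_mask):
        # Expand the per-filter keep/drop decision over the full weight tensor.
        mask_weight = torch.gt(metric, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
        # Only produce a bias mask when the layer actually has a bias.
        if base_mask['bias_mask'] is not None:
            mask_bias = torch.gt(metric, threshold).type_as(weight).detach()
        else:
            mask_bias = None
        return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias}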
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import math
from unittest import TestCase, main

import torch
import torch.nn as nn
import torch.nn.functional as F

from nni.compression.torch import LevelPruner, SlimPruner, FPGMPruner, L1FilterPruner, \
    L2FilterPruner, AGP_Pruner, ActivationMeanRankFilterPruner, ActivationAPoZRankFilterPruner
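
# Smoke tests: run each pruner on a tiny model, with and without bias
# parameters, and check that the produced masks reach the target sparsity.
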
def validate_sparsity(wrapper, sparsity, bias=False):
    masks = [wrapper.weight_mask]
    if bias and wrapper.bias_mask is not None:
        masks.append(wrapper.bias_mask)
    for m in masks:
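        # Pruners mask by thresholding, so the realized sparsity is only an
        # approximation of the target; compare with a loose absolute tolerance.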
        actual_sparsity = (m == 0).sum().item() / m.numel()
        msg = 'actual sparsity: {:.2f}, target sparsity: {:.2f}'.format(actual_sparsity, sparsity)
        assert math.isclose(actual_sparsity, sparsity, abs_tol=0.1), msg
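

# Each entry maps a pruner name to its class, the config_list handed to the
# pruner, and validators run against the compressed model.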
prune_config = {
    'level': {
        'pruner_class': LevelPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['default'],
        }],
        'validators': [
            lambda model: validate_sparsity(model.conv1, 0.5, False),
            lambda model: validate_sparsity(model.fc, 0.5, False)
        ]
    },
    'agp': {
        'pruner_class': AGP_Pruner,
        'config_list': [{
            'initial_sparsity': 0,
            'final_sparsity': 0.8,
            'start_epoch': 0,
            'end_epoch': 10,
            'frequency': 1,
            'op_types': ['default']
        }],
        'validators': []
    },
    'slim': {
        'pruner_class': SlimPruner,
        'config_list': [{
            'sparsity': 0.7,
            'op_types': ['BatchNorm2d']
        }],
        'validators': [
            lambda model: validate_sparsity(model.bn1, 0.7, model.bias)
        ]
    },
    'fpgm': {
        'pruner_class': FPGMPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['Conv2d']
        }],
        'validators': [
            lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
        ]
    },
    'l1': {
        'pruner_class': L1FilterPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['Conv2d'],
        }],
        'validators': [
            lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
        ]
    },
    'l2': {
        'pruner_class': L2FilterPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['Conv2d'],
        }],
        'validators': [
            lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
        ]
    },
    'mean_activation': {
        'pruner_class': ActivationMeanRankFilterPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['Conv2d'],
        }],
        'validators': [
            lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
        ]
    },
    'apoz': {
        'pruner_class': ActivationAPoZRankFilterPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['Conv2d'],
        }],
        'validators': [
            lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
        ]
    }
}
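

# Tiny model shared by all tests: conv -> bn -> global average pool -> fc,
# with bias toggled to exercise both the bias and bias-free code paths.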
class Model(nn.Module):
    def __init__(self, bias=True):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=3, padding=1, bias=bias)
        self.bn1 = nn.BatchNorm2d(8)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 2, bias=bias)
        self.bias = bias

    def forward(self, x):
        return self.fc(self.pool(self.bn1(self.conv1(x))).view(x.size(0), -1))
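

# For each pruner: run one optimizer step so weight and activation statistics
# exist, compress, fine-tune one more step with the masks applied, export the
# model/mask/ONNX artifacts, then run the pruner-specific validators.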
def pruners_test(pruner_names=['level', 'agp', 'slim', 'fpgm', 'l1', 'l2', 'mean_activation', 'apoz'], bias=True):
    for pruner_name in pruner_names:
        print('testing {}...'.format(pruner_name))
        model = Model(bias=bias)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        config_list = prune_config[pruner_name]['config_list']

        x = torch.randn(2, 1, 28, 28)
        y = torch.tensor([0, 1]).long()
        out = model(x)
        loss = F.cross_entropy(out, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pruner = prune_config[pruner_name]['pruner_class'](model, config_list, optimizer)
        pruner.compress()

        x = torch.randn(2, 1, 28, 28)
        y = torch.tensor([0, 1]).long()
        out = model(x)
        loss = F.cross_entropy(out, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pruner.export_model('./model_tmp.pth', './mask_tmp.pth', './onnx_tmp.pth', input_shape=(2, 1, 28, 28))
        for v in prune_config[pruner_name]['validators']:
            v(model)

    os.remove('./model_tmp.pth')
    os.remove('./mask_tmp.pth')
    os.remove('./onnx_tmp.pth')


class PrunerTestCase(TestCase):
    def test_pruners(self):
        pruners_test(bias=True)

    def test_pruners_no_bias(self):
        pruners_test(bias=False)


if __name__ == '__main__':
    main()
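
To exercise a single pruner outside the unittest harness, the same pieces can be reused directly. A minimal sketch using only the Model, prune_config, and validate_sparsity definitions above (the choice of L1FilterPruner is arbitrary):

    model = Model(bias=True)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    pruner = L1FilterPruner(model, prune_config['l1']['config_list'], optimizer)
    pruner.compress()
    # conv1 has 8 filters, so a 0.5 target should mask roughly 4 of them
    validate_sparsity(model.conv1, 0.5, model.bias)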