"test/training_service/config/integration_tests.yml" did not exist on "f82ef623c1b813e7676849f885c12bcdc98d2a8e"
Unverified Commit e9f3cddf authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

AutoML for model compression (#2573)

parent 3757cf27
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch.nn as nn
import math
def conv_bn(inp, oup, stride):
    """Standard 3x3 convolution block: Conv2d -> BatchNorm2d -> ReLU.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: convolution stride.
    """
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU(inplace=True),
    )
def conv_dw(inp, oup, stride):
    """Depthwise-separable convolution block.

    A 3x3 depthwise conv (groups == inp) followed by a 1x1 pointwise conv,
    each with batch norm and ReLU.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the depthwise convolution.
    """
    depthwise = nn.Conv2d(inp, inp, kernel_size=3, stride=stride, padding=1,
                          groups=inp, bias=False)
    pointwise = nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0,
                          bias=False)
    return nn.Sequential(
        depthwise,
        nn.BatchNorm2d(inp),
        nn.ReLU(inplace=True),
        pointwise,
        nn.BatchNorm2d(oup),
        nn.ReLU(inplace=True),
    )
class MobileNet(nn.Module):
    """MobileNet v1 backbone built from depthwise-separable conv blocks.

    Args:
        n_class: number of output classes for the final linear classifier.
        profile: 'normal' for the original channel configuration, or
            '0.5flops' for the AMC-compressed (~50% FLOPs) configuration.

    Raises:
        NotImplementedError: if ``profile`` is not one of the known presets.
    """

    def __init__(self, n_class, profile='normal'):
        super().__init__()

        # Each cfg entry is either `channels` (stride 1) or `(channels, stride)`.
        if profile == 'normal':
            # original
            in_planes = 32
            cfg = [64, (128, 2), 128, (256, 2), 256, (512, 2), 512, 512, 512, 512, 512, (1024, 2), 1024]
        elif profile == '0.5flops':
            # 0.5 AMC
            in_planes = 24
            cfg = [48, (96, 2), 80, (192, 2), 200, (328, 2), 352, 368, 360, 328, 400, (736, 2), 752]
        else:
            raise NotImplementedError

        self.conv1 = conv_bn(3, in_planes, stride=2)
        self.features = self._make_layers(in_planes, cfg, conv_dw)
        self.classifier = nn.Sequential(
            nn.Linear(cfg[-1], n_class),
        )
        self._initialize_weights()

    def forward(self, x):
        x = self.conv1(x)
        x = self.features(x)
        x = x.mean(3).mean(2)  # global average pooling over H and W
        x = self.classifier(x)
        return x

    def _make_layers(self, in_planes, cfg, layer):
        """Stack ``layer(in_planes, out_planes, stride)`` blocks per ``cfg``."""
        layers = []
        for x in cfg:
            out_planes = x if isinstance(x, int) else x[0]
            stride = 1 if isinstance(x, int) else x[1]
            layers.append(layer(in_planes, out_planes, stride))
            in_planes = out_planes
        return nn.Sequential(*layers)

    def _initialize_weights(self):
        """He-style init for convs, unit scale / zero shift for BN, small normal for linear."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He (Kaiming) initialization using fan-out.
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                # NOTE: the original code computed an unused fan-in here
                # (`n = m.weight.size(1)`); it has been removed as dead code.
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
......@@ -5,11 +5,14 @@ import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import math
from unittest import TestCase, main
from nni.compression.torch import LevelPruner, SlimPruner, FPGMPruner, L1FilterPruner, \
L2FilterPruner, AGPPruner, ActivationMeanRankFilterPruner, ActivationAPoZRankFilterPruner, \
TaylorFOWeightFilterPruner, NetAdaptPruner, SimulatedAnnealingPruner, ADMMPruner, AutoCompressPruner
TaylorFOWeightFilterPruner, NetAdaptPruner, SimulatedAnnealingPruner, ADMMPruner, \
AutoCompressPruner, AMCPruner
from models.pytorch_models.mobilenet import MobileNet
def validate_sparsity(wrapper, sparsity, bias=False):
masks = [wrapper.weight_mask]
......@@ -154,6 +157,12 @@ prune_config = {
'evaluator': lambda model: 0.9,
'dummy_input': torch.randn([64, 1, 28, 28]),
'validators': []
},
'amc': {
'pruner_class': AMCPruner,
'config_list':[{
'op_types': ['Conv2d', 'Linear']
}]
}
}
......@@ -244,6 +253,13 @@ def test_agp(pruning_algorithm):
# set abs_tol = 0.2, considering the sparsity error for channel pruning when number of channels is small.
assert math.isclose(actual_sparsity, target_sparsity, abs_tol=0.2)
class SimpleDataset:
    """Synthetic map-style dataset: random 3x32x32 tensors, constant label 1.0."""

    def __len__(self):
        # Fixed size so DataLoader iteration terminates quickly in tests.
        return 1000

    def __getitem__(self, index):
        return torch.randn(3, 32, 32), 1.
class PrunerTestCase(TestCase):
def test_pruners(self):
    """Run the shared pruner test suite with bias masking enabled."""
    pruners_test(bias=True)
......@@ -259,5 +275,15 @@ class PrunerTestCase(TestCase):
prune_config['agp']['config_list'][0]['op_types'] = ['default']
test_agp(pruning_algorithm)
def testAMC(self):
    """Smoke-test AMCPruner end to end on MobileNet with a synthetic loader."""
    model = MobileNet(n_class=10)

    # Stub evaluator: AMC only needs a numeric accuracy signal, so return
    # a fixed score regardless of the model or data.
    def fake_validate(val_loader, model):
        return 80.

    loader = torch.utils.data.DataLoader(
        SimpleDataset(), batch_size=16, shuffle=False, drop_last=True)
    pruner = AMCPruner(
        model, prune_config['amc']['config_list'], fake_validate, loader,
        train_episode=1)
    pruner.compress()
# Allow running this test module directly (unittest's main() discovers and
# executes the TestCase classes defined above).
if __name__ == '__main__':
    main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment