Unverified Commit bd7edf36 authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

fix speedup issue (#2447)

parent f8627a2f
......@@ -163,9 +163,11 @@ class ModelSpeedup:
first, do mask/shape inference,
second, replace modules
"""
training = self.bound_model.training
_logger.info("start to speed up the model")
_logger.info("infer module masks...")
self.infer_modules_masks()
_logger.info("replace compressed modules...")
self.replace_compressed_modules()
self.bound_model.train(training)
_logger.info("speedup done")
......@@ -10,9 +10,11 @@ from torchvision.models.vgg import vgg16
from torchvision.models.resnet import resnet18
from unittest import TestCase, main
from nni.compression.torch import L1FilterPruner
from nni.compression.torch import L1FilterPruner, apply_compression_results
from nni.compression.speedup.torch import ModelSpeedup
torch.manual_seed(0)
class BackboneModel1(nn.Module):
def __init__(self):
super().__init__()
......@@ -58,7 +60,10 @@ class BigModel(torch.nn.Module):
x = self.fc3(x)
return x
dummy_input = torch.randn(2, 1, 28, 28)
SPARSITY = 0.5
MODEL_FILE, MASK_FILE = './11_model.pth', './l1_mask.pth'
def prune_model_l1(model):
config_list = [{
'sparsity': SPARSITY,
......@@ -66,14 +71,14 @@ def prune_model_l1(model):
}]
pruner = L1FilterPruner(model, config_list)
pruner.compress()
pruner.export_model(model_path='./11_model.pth', mask_path='./l1_mask.pth')
pruner.export_model(model_path=MODEL_FILE, mask_path=MASK_FILE)
class SpeedupTestCase(TestCase):
def test_speedup_vgg16(self):
prune_model_l1(vgg16())
model = vgg16()
model.train()
ms = ModelSpeedup(model, torch.randn(2, 3, 32, 32), './l1_mask.pth')
ms = ModelSpeedup(model, torch.randn(2, 3, 32, 32), MASK_FILE)
ms.speedup_model()
orig_model = vgg16()
......@@ -88,20 +93,33 @@ class SpeedupTestCase(TestCase):
def test_speedup_bigmodel(self):
prune_model_l1(BigModel())
model = BigModel()
apply_compression_results(model, MASK_FILE, 'cpu')
model.eval()
mask_out = model(dummy_input)
model.train()
ms = ModelSpeedup(model, torch.randn(2, 1, 28, 28), './l1_mask.pth')
ms = ModelSpeedup(model, dummy_input, MASK_FILE)
ms.speedup_model()
assert model.training
model.eval()
speedup_out = model(dummy_input)
if not torch.allclose(mask_out, speedup_out, atol=1e-07):
print('input:', dummy_input.size(), torch.abs(dummy_input).sum((2,3)))
print('mask_out:', mask_out)
print('speedup_out:', speedup_out)
raise RuntimeError('model speedup inference result is incorrect!')
orig_model = BigModel()
assert model.training
assert model.backbone2.conv1.out_channels == int(orig_model.backbone2.conv1.out_channels * SPARSITY)
assert model.backbone2.conv2.in_channels == int(orig_model.backbone2.conv2.in_channels * SPARSITY)
assert model.backbone2.conv2.out_channels == int(orig_model.backbone2.conv2.out_channels * SPARSITY)
assert model.backbone2.fc1.in_features == int(orig_model.backbone2.fc1.in_features * SPARSITY)
def tearDown(self):
os.remove('./11_model.pth')
os.remove('./l1_mask.pth')
os.remove(MODEL_FILE)
os.remove(MASK_FILE)
if __name__ == '__main__':
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment