"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "356677917541d02136e829affc37db538ce33e02"
Unverified Commit bd7edf36 authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

fix speedup issue (#2447)

parent f8627a2f
...@@ -163,9 +163,11 @@ class ModelSpeedup: ...@@ -163,9 +163,11 @@ class ModelSpeedup:
first, do mask/shape inference, first, do mask/shape inference,
second, replace modules second, replace modules
""" """
training = self.bound_model.training
_logger.info("start to speed up the model") _logger.info("start to speed up the model")
_logger.info("infer module masks...") _logger.info("infer module masks...")
self.infer_modules_masks() self.infer_modules_masks()
_logger.info("replace compressed modules...") _logger.info("replace compressed modules...")
self.replace_compressed_modules() self.replace_compressed_modules()
self.bound_model.train(training)
_logger.info("speedup done") _logger.info("speedup done")
...@@ -10,9 +10,11 @@ from torchvision.models.vgg import vgg16 ...@@ -10,9 +10,11 @@ from torchvision.models.vgg import vgg16
from torchvision.models.resnet import resnet18 from torchvision.models.resnet import resnet18
from unittest import TestCase, main from unittest import TestCase, main
from nni.compression.torch import L1FilterPruner from nni.compression.torch import L1FilterPruner, apply_compression_results
from nni.compression.speedup.torch import ModelSpeedup from nni.compression.speedup.torch import ModelSpeedup
torch.manual_seed(0)
class BackboneModel1(nn.Module): class BackboneModel1(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
...@@ -58,7 +60,10 @@ class BigModel(torch.nn.Module): ...@@ -58,7 +60,10 @@ class BigModel(torch.nn.Module):
x = self.fc3(x) x = self.fc3(x)
return x return x
dummy_input = torch.randn(2, 1, 28, 28)
SPARSITY = 0.5 SPARSITY = 0.5
MODEL_FILE, MASK_FILE = './11_model.pth', './l1_mask.pth'
def prune_model_l1(model): def prune_model_l1(model):
config_list = [{ config_list = [{
'sparsity': SPARSITY, 'sparsity': SPARSITY,
...@@ -66,14 +71,14 @@ def prune_model_l1(model): ...@@ -66,14 +71,14 @@ def prune_model_l1(model):
}] }]
pruner = L1FilterPruner(model, config_list) pruner = L1FilterPruner(model, config_list)
pruner.compress() pruner.compress()
pruner.export_model(model_path='./11_model.pth', mask_path='./l1_mask.pth') pruner.export_model(model_path=MODEL_FILE, mask_path=MASK_FILE)
class SpeedupTestCase(TestCase): class SpeedupTestCase(TestCase):
def test_speedup_vgg16(self): def test_speedup_vgg16(self):
prune_model_l1(vgg16()) prune_model_l1(vgg16())
model = vgg16() model = vgg16()
model.train() model.train()
ms = ModelSpeedup(model, torch.randn(2, 3, 32, 32), './l1_mask.pth') ms = ModelSpeedup(model, torch.randn(2, 3, 32, 32), MASK_FILE)
ms.speedup_model() ms.speedup_model()
orig_model = vgg16() orig_model = vgg16()
...@@ -88,20 +93,33 @@ class SpeedupTestCase(TestCase): ...@@ -88,20 +93,33 @@ class SpeedupTestCase(TestCase):
def test_speedup_bigmodel(self): def test_speedup_bigmodel(self):
prune_model_l1(BigModel()) prune_model_l1(BigModel())
model = BigModel() model = BigModel()
apply_compression_results(model, MASK_FILE, 'cpu')
model.eval()
mask_out = model(dummy_input)
model.train() model.train()
ms = ModelSpeedup(model, torch.randn(2, 1, 28, 28), './l1_mask.pth') ms = ModelSpeedup(model, dummy_input, MASK_FILE)
ms.speedup_model() ms.speedup_model()
assert model.training
model.eval()
speedup_out = model(dummy_input)
if not torch.allclose(mask_out, speedup_out, atol=1e-07):
print('input:', dummy_input.size(), torch.abs(dummy_input).sum((2,3)))
print('mask_out:', mask_out)
print('speedup_out:', speedup_out)
raise RuntimeError('model speedup inference result is incorrect!')
orig_model = BigModel() orig_model = BigModel()
assert model.training
assert model.backbone2.conv1.out_channels == int(orig_model.backbone2.conv1.out_channels * SPARSITY) assert model.backbone2.conv1.out_channels == int(orig_model.backbone2.conv1.out_channels * SPARSITY)
assert model.backbone2.conv2.in_channels == int(orig_model.backbone2.conv2.in_channels * SPARSITY) assert model.backbone2.conv2.in_channels == int(orig_model.backbone2.conv2.in_channels * SPARSITY)
assert model.backbone2.conv2.out_channels == int(orig_model.backbone2.conv2.out_channels * SPARSITY) assert model.backbone2.conv2.out_channels == int(orig_model.backbone2.conv2.out_channels * SPARSITY)
assert model.backbone2.fc1.in_features == int(orig_model.backbone2.fc1.in_features * SPARSITY) assert model.backbone2.fc1.in_features == int(orig_model.backbone2.fc1.in_features * SPARSITY)
def tearDown(self): def tearDown(self):
os.remove('./11_model.pth') os.remove(MODEL_FILE)
os.remove('./l1_mask.pth') os.remove(MASK_FILE)
if __name__ == '__main__': if __name__ == '__main__':
main() main()
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment