Unverified Commit 47c7ea14 authored by Ningxin Zheng's avatar Ningxin Zheng Committed by GitHub
Browse files

Add several speedup examples (#3880)

parent 5fe24500
...@@ -7,7 +7,7 @@ import torch.nn.functional as F ...@@ -7,7 +7,7 @@ import torch.nn.functional as F
from torchvision import datasets, transforms from torchvision import datasets, transforms
import sys import sys
sys.path.append('../models') sys.path.append('../../models')
from cifar10.vgg import VGG from cifar10.vgg import VGG
from mnist.lenet import LeNet from mnist.lenet import LeNet
......
import torch
from torchvision.models import mobilenet_v2
from nni.compression.pytorch import ModelSpeedup
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner
model = mobilenet_v2(pretrained=True)
dummy_input = torch.rand(8, 3, 416, 416)
cfg_list = [{'op_types':['Conv2d'], 'sparsity':0.5}]
pruner = L1FilterPruner(model, cfg_list)
pruner.compress()
pruner.export_model('./model', './mask')
# need call _unwrap_model if you want run the speedup on the same model
pruner._unwrap_model()
# Speedup the nanodet
ms = ModelSpeedup(model, dummy_input, './mask')
ms.speedup_model()
model(dummy_input)
\ No newline at end of file
import torch
from nanodet.model.arch import build_model
from nanodet.util import cfg, load_config
from nni.compression.pytorch import ModelSpeedup
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner
"""
NanoDet model can be installed from https://github.com/RangiLyu/nanodet.git
"""
cfg_path = r"nanodet/config/nanodet-RepVGG-A0_416.yml"
load_config(cfg, cfg_path)
model = build_model(cfg.model)
dummy_input = torch.rand(8, 3, 416, 416)
op_names = []
# these three conv layers are followed by reshape-like functions
# that cannot be replaced, so we skip these three conv layers,
# you can also get such layers by `not_safe_to_prune` function
excludes = ['head.gfl_cls.0', 'head.gfl_cls.1', 'head.gfl_cls.2']
for name, module in model.named_modules():
if isinstance(module, torch.nn.Conv2d):
if name not in excludes:
op_names.append(name)
cfg_list = [{'op_types':['Conv2d'], 'sparsity':0.5, 'op_names':op_names}]
pruner = L1FilterPruner(model, cfg_list)
pruner.compress()
pruner.export_model('./model', './mask')
# need call _unwrap_model if you want run the speedup on the same model
pruner._unwrap_model()
# Speedup the nanodet
ms = ModelSpeedup(model, dummy_input, './mask')
ms.speedup_model()
model(dummy_input)
\ No newline at end of file
import torch
from pytorchyolo import models
from nni.compression.pytorch import ModelSpeedup
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner, LevelPruner
from nni.compression.pytorch.utils import not_safe_to_prune
# The Yolo can be downloaded at https://github.com/eriklindernoren/PyTorch-YOLOv3.git
prefix = '/home/user/PyTorch-YOLOv3' # replace this path with yours
# Load the YOLO model
model = models.load_model(
"%s/config/yolov3.cfg" % prefix,
"%s/yolov3.weights" % prefix)
model.eval()
dummy_input = torch.rand(8, 3, 320, 320)
model(dummy_input)
# Generate the config list for pruner
# Filter the layers that may not be able to prune
not_safe = not_safe_to_prune(model, dummy_input)
cfg_list = []
for name, module in model.named_modules():
if name in not_safe:
continue
if isinstance(module, torch.nn.Conv2d):
cfg_list.append({'op_types':['Conv2d'], 'sparsity':0.6, 'op_names':[name]})
# Prune the model
pruner = L1FilterPruner(model, cfg_list)
pruner.compress()
pruner.export_model('./model', './mask')
pruner._unwrap_model()
# Speedup the model
ms = ModelSpeedup(model, dummy_input, './mask')
ms.speedup_model()
model(dummy_input)
...@@ -70,7 +70,14 @@ def randomize_tensor(tensor, start=1, end=100): ...@@ -70,7 +70,14 @@ def randomize_tensor(tensor, start=1, end=100):
def not_safe_to_prune(model, dummy_input): def not_safe_to_prune(model, dummy_input):
""" """
Get the layers that are safe to prune(will not bring the shape conflict). Get the layers that are not safe to prune(may bring the shape conflict).
For example, if the output tensor of a conv layer is directly followed by
a shape-dependent function(such as reshape/view), then this conv layer
may be not safe to be pruned. Pruning may change the output shape of
this conv layer and result in shape problems. This function find all the
layers that directly followed by the shape-dependent functions(view, reshape, etc).
If you run the inference after the speedup and run into a shape related error,
please exclude the layers returned by this function and try again.
Parameters Parameters
---------- ----------
......
...@@ -361,6 +361,10 @@ class SpeedupTestCase(TestCase): ...@@ -361,6 +361,10 @@ class SpeedupTestCase(TestCase):
self.speedup_integration(model_list) self.speedup_integration(model_list)
def speedup_integration(self, model_list, speedup_cfg=None): def speedup_integration(self, model_list, speedup_cfg=None):
# Note: hack trick, may be updated in the future
if 'win' in sys.platform or 'Win'in sys.platform:
print('Skip test_speedup_integration on windows due to memory limit!')
return
Gen_cfg_funcs = [generate_random_sparsity, generate_random_sparsity_v2] Gen_cfg_funcs = [generate_random_sparsity, generate_random_sparsity_v2]
# for model_name in ['vgg16', 'resnet18', 'mobilenet_v2', 'squeezenet1_1', 'densenet121', # for model_name in ['vgg16', 'resnet18', 'mobilenet_v2', 'squeezenet1_1', 'densenet121',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment