Unverified commit e0b692c9, authored by QuanluZhang, committed by GitHub

refactor of modelspeedup example (#2161)

parent d05488e0
...
@@ -127,7 +127,7 @@ class NaiveModel(torch.nn.Module):
         x = F.max_pool2d(x, 2, 2)
         x = F.relu(self.bn2(self.conv2(x)))
         x = F.max_pool2d(x, 2, 2)
-        x = x.view(-1, 4 * 4 * 50)
+        x = x.view(x.size(0), -1)
         x = F.relu(self.fc1(x))
         x = self.fc2(x)
         return x
...
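Why this change matters for speedup: `view(-1, 4 * 4 * 50)` hard-codes the original channel count of conv2 into the flatten step, so it breaks once ModelSpeedup physically removes pruned channels, while `view(x.size(0), -1)` infers the feature size at runtime (speedup then resizes fc1 to match). A minimal sketch of the difference; the 50-to-40 pruning ratio here is an assumed example, not a value from the commit:

import torch

# activation after the conv/pool stages: [batch, channels, 4, 4]
x = torch.randn(64, 40, 4, 4)   # channels physically pruned from 50 to 40 (assumed)
flat = x.view(x.size(0), -1)    # [64, 640] -- adapts to whatever channel count remains
# x.view(-1, 4 * 4 * 50)        # would raise RuntimeError: 40960 elements don't fit rows of 800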
+import os
 import argparse
 import time
 import torch
...
@@ -9,145 +10,89 @@ from nni.compression.speedup.torch import ModelSpeedup
 from nni.compression.torch import apply_compression_results
 
 torch.manual_seed(0)
-use_mask = False
-
-def apoz_speedup(masks_file, model_checkpoint):
-    device = torch.device('cuda')
-    model = VGG(depth=16)
-    model.to(device)
-    model.eval()
-
-    dummy_input = torch.randn(64, 3, 32, 32)
-    if use_mask:
-        apply_compression_results(model, masks_file)
-        dummy_input = dummy_input.to(device)
-        start = time.time()
-        for _ in range(32):
-            out = model(dummy_input)
-            #print(out.size(), out)
-        print('mask elapsed time: ', time.time() - start)
-        return
-    else:
-        #print("model before: ", model)
-        m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file)
-        m_speedup.speedup_model()
-        #print("model after: ", model)
-        dummy_input = dummy_input.to(device)
-        start = time.time()
-        for _ in range(32):
-            out = model(dummy_input)
-            #print(out.size(), out)
-        print('speedup elapsed time: ', time.time() - start)
-        return
-
-def l1filter_speedup(masks_file, model_checkpoint):
-    device = torch.device('cuda')
-    model = VGG(depth=16)
-    model.to(device)
-    model.eval()
-
-    dummy_input = torch.randn(64, 3, 32, 32)
-    if use_mask:
-        apply_compression_results(model, masks_file)
-        dummy_input = dummy_input.to(device)
-        start = time.time()
-        for _ in range(32):
-            out = model(dummy_input)
-            #print(out.size(), out)
-        print('mask elapsed time: ', time.time() - start)
-        return
-    else:
-        #print("model before: ", model)
-        m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file)
-        m_speedup.speedup_model()
-        #print("model after: ", model)
-        dummy_input = dummy_input.to(device)
-        start = time.time()
-        for _ in range(32):
-            out = model(dummy_input)
-            #print(out.size(), out)
-        print('speedup elapsed time: ', time.time() - start)
-        return
-
-def fpgm_speedup(masks_file, model_checkpoint):
-    from fpgm_torch_mnist import Mnist
-    device = torch.device('cpu')
-    model = Mnist()
-    model.to(device)
-    model.print_conv_filter_sparsity()
-
-    dummy_input = torch.randn(64, 1, 28, 28)
-    if use_mask:
-        apply_compression_results(model, masks_file)
-        dummy_input = dummy_input.to(device)
-        start = time.time()
-        for _ in range(40):
-            out = model(dummy_input)
-        print('mask elapsed time: ', time.time() - start)
-        #print(out.size(), out)
-        return
-    else:
-        m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file)
-        m_speedup.speedup_model()
-        dummy_input = dummy_input.to(device)
-        start = time.time()
-        for _ in range(40):
-            out = model(dummy_input)
-        print('speedup elapsed time: ', time.time() - start)
-        #print(out.size(), out)
-        return
-
-def slim_speedup(masks_file, model_checkpoint):
-    device = torch.device('cuda')
-    model = VGG(depth=19)
+use_mask = True
+use_speedup = True
+compare_results = True
+
+config = {
+    'apoz': {
+        'model_name': 'vgg16',
+        'device': 'cuda',
+        'input_shape': [64, 3, 32, 32],
+        'masks_file': './checkpoints/mask_vgg16_cifar10_apoz.pth'
+    },
+    'l1filter': {
+        'model_name': 'vgg16',
+        'device': 'cuda',
+        'input_shape': [64, 3, 32, 32],
+        'masks_file': './checkpoints/mask_vgg16_cifar10_l1.pth'
+    },
+    'fpgm': {
+        'model_name': 'naive',
+        'device': 'cpu',
+        'input_shape': [64, 1, 28, 28],
+        'masks_file': './checkpoints/mask_naive_mnist_fpgm.pth'
+    },
+    'slim': {
+        'model_name': 'vgg19',
+        'device': 'cuda',
+        'input_shape': [64, 3, 32, 32],
+        'masks_file': './checkpoints/mask_vgg19_cifar10_slim.pth' #'mask_vgg19_cifar10.pth'
+    }
+}
+
+def model_inference(config):
+    masks_file = config['masks_file']
+    device = torch.device(config['device'])
+    if config['model_name'] == 'vgg16':
+        model = VGG(depth=16)
+    elif config['model_name'] == 'vgg19':
+        model = VGG(depth=19)
+    elif config['model_name'] == 'naive':
+        from model_prune_torch import NaiveModel
+        model = NaiveModel()
     model.to(device)
     model.eval()
 
-    dummy_input = torch.randn(64, 3, 32, 32)
+    dummy_input = torch.randn(config['input_shape']).to(device)
+    use_mask_out = use_speedup_out = None
+    # must run use_mask before use_speedup because use_speedup modifies the model
     if use_mask:
-        apply_compression_results(model, masks_file)
-        dummy_input = dummy_input.to(device)
+        apply_compression_results(model, masks_file, 'cpu' if config['device'] == 'cpu' else None)
         start = time.time()
         for _ in range(32):
-            out = model(dummy_input)
-            #print(out.size(), out)
-        print('mask elapsed time: ', time.time() - start)
-        return
-    else:
-        #print("model before: ", model)
-        m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file)
+            use_mask_out = model(dummy_input)
+        print('elapsed time when use mask: ', time.time() - start)
+    if use_speedup:
+        m_speedup = ModelSpeedup(model, dummy_input, masks_file,
+                                 'cpu' if config['device'] == 'cpu' else None)
         m_speedup.speedup_model()
-        #print("model after: ", model)
-        dummy_input = dummy_input.to(device)
         start = time.time()
         for _ in range(32):
-            out = model(dummy_input)
-            #print(out.size(), out)
-        print('speedup elapsed time: ', time.time() - start)
-        return
+            use_speedup_out = model(dummy_input)
+        print('elapsed time when use speedup: ', time.time() - start)
+    if compare_results:
+        if torch.allclose(use_mask_out, use_speedup_out, atol=1e-07):
+            print('the outputs from use_mask and use_speedup are the same')
+        else:
+            raise RuntimeError('the outputs from use_mask and use_speedup are different')
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser("speedup")
     parser.add_argument("--example_name", type=str, default="slim", help="the name of pruning example")
     parser.add_argument("--masks_file", type=str, default=None, help="the path of the masks file")
-    parser.add_argument("--model_checkpoint", type=str, default=None, help="the path of checkpointed model")
     args = parser.parse_args()
 
-    if args.example_name == 'slim':
-        if args.masks_file is None:
-            args.masks_file = 'mask_vgg19_cifar10.pth'
-        slim_speedup(args.masks_file, args.model_checkpoint)
-    elif args.example_name == 'fpgm':
-        if args.masks_file is None:
-            args.masks_file = 'mask.pth'
-        fpgm_speedup(args.masks_file, args.model_checkpoint)
-    elif args.example_name == 'l1filter':
-        if args.masks_file is None:
-            args.masks_file = 'mask_vgg16_cifar10.pth'
-        l1filter_speedup(args.masks_file, args.model_checkpoint)
-    elif args.example_name == 'apoz':
-        if args.masks_file is None:
-            args.masks_file = 'mask_vgg16_cifar10.pth'
-        apoz_speedup(args.masks_file, args.model_checkpoint)
+    if args.example_name != 'all':
+        if args.masks_file is not None:
+            config[args.example_name]['masks_file'] = args.masks_file
+        if not os.path.exists(config[args.example_name]['masks_file']):
+            msg = '{} does not exist! You should specify masks_file correctly, ' \
+                  'or use the default one generated by model_prune_torch.py'
+            raise RuntimeError(msg.format(config[args.example_name]['masks_file']))
+        model_inference(config[args.example_name])
     else:
-        raise ValueError('unsupported example_name: {}'.format(args.example_name))
+        model_inference(config['fpgm'])
+        model_inference(config['slim'])
+        model_inference(config['l1filter'])
+        model_inference(config['apoz'])
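Under the new layout, the example can also be driven programmatically instead of through argparse. A hedged sketch, assuming the script above is saved as model_speedup.py next to its model definitions and that the default mask files exist:

# equivalent to: python model_speedup.py --example_name fpgm
from model_speedup import config, model_inference

model_inference(config['fpgm'])

# equivalent to passing --masks_file on the command line
config['slim']['masks_file'] = './checkpoints/mask_vgg19_cifar10_slim.pth'
model_inference(config['slim'])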
...
@@ -70,7 +70,7 @@ class ModelSpeedup:
     This class is to speedup the model with provided weight mask
     """
 
-    def __init__(self, model, dummy_input, masks_file):
+    def __init__(self, model, dummy_input, masks_file, map_location=None):
         """
         Parameters
         ----------
@@ -80,10 +80,12 @@ class ModelSpeedup:
             The dummy input for ```jit.trace```, users should put it on right device before pass in
         masks_file : str
             The path of user provided mask file
+        map_location : str
+            the device on which masks are placed, same as map_location in ```torch.load```
         """
         self.bound_model = model
         self.dummy_input = dummy_input
-        self.masks = torch.load(masks_file)
+        self.masks = torch.load(masks_file, map_location)
         self.is_training = model.training
         # to obtain forward graph, model should be in ```eval``` mode
         if self.is_training:
...
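A hedged usage sketch of the new parameter: the toy model and the 'mask.pth' file are assumptions (the mask keys would have to match the model's module names), but the call signature follows the diff above. With map_location='cpu', masks saved during a GPU run can be loaded on a CPU-only machine:

import torch
import torch.nn as nn
from nni.compression.speedup.torch import ModelSpeedup

# toy stand-in for a pruned network (assumed; not from the commit)
model = nn.Sequential(nn.Conv2d(1, 50, 5), nn.ReLU()).eval()
dummy_input = torch.randn(64, 1, 28, 28)

m_speedup = ModelSpeedup(model, dummy_input, 'mask.pth', map_location='cpu')
m_speedup.speedup_model()   # replaces masked layers with physically smaller ones
out = model(dummy_input)    # inference on the compacted model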
...
@@ -3,13 +3,14 @@
 import logging
 import torch
-from .compressor import Pruner
 
 logger = logging.getLogger('torch apply compression')
 
-def apply_compression_results(model, masks_file):
+def apply_compression_results(model, masks_file, map_location=None):
     """
     Apply the masks from ```masks_file``` to the model
+    Note: this API is for inference, because it simply multiplies weights by their
+    corresponding masks when it is called.
 
     Parameters
     ----------
@@ -17,54 +18,12 @@ def apply_compression_results(model, masks_file):
         The model to be compressed
     masks_file : str
         The path of the mask file
-    """
-    apply_comp = ApplyCompression(model, masks_file)
-    apply_comp.compress()
-
-class ApplyCompression(Pruner):
-    """
-    This class is not to generate masks, but applying existing masks
-    """
-
-    def __init__(self, model, masks_file):
-        """
-        Parameters
-        ----------
-        model : torch.nn.module
-            Model to be masked
-        masks_file : str
-            The path of user provided mask file
-        """
-        self.bound_model = model
-        self.masks = torch.load(masks_file)
-        for module_name in self.masks:
-            print('module_name: ', module_name)
-        config_list = self._build_config()
-        super().__init__(model, config_list)
-
-    def _build_config(self):
-        op_names = []
-        for module_name in self.masks:
-            op_names.append(module_name)
-        return [{'sparsity': 1, 'op_types': ['default', 'BatchNorm2d'], 'op_names': op_names}]
-
-    def calc_mask(self, layer, config, **kwargs):
-        """
-        Directly return the corresponding mask
-
-        Parameters
-        ----------
-        layer : LayerInfo
-            The layer to be pruned
-        config : dict
-            Pruning configurations for this weight
-        kwargs : dict
-            Auxiliary information
-
-        Returns
-        -------
-        dict
-            Mask of the layer
-        """
-        assert layer.name in self.masks
-        return self.masks[layer.name]
+    map_location : str
+        the device on which masks are placed, same as map_location in ```torch.load```
+    """
+    masks = torch.load(masks_file, map_location)
+    for name, module in model.named_modules():
+        if name in masks:
+            module.weight.data = module.weight.data.mul_(masks[name]['weight'])
+            if hasattr(module, 'bias') and module.bias is not None and 'bias' in masks[name]:
+                module.bias.data = module.bias.data.mul_(masks[name]['bias'])
\ No newline at end of file
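A hedged sketch of the simplified API in use; the toy model and 'mask.pth' path are assumptions, and the mask keys must match the model's module names. Since the masks are applied multiplicatively, the model keeps its original layer shapes, so this path measures masked rather than structurally pruned inference:

import torch
import torch.nn as nn
from nni.compression.torch import apply_compression_results

# toy stand-in model (assumed; not from the commit)
model = nn.Sequential(nn.Conv2d(1, 50, 5), nn.ReLU()).eval()
apply_compression_results(model, 'mask.pth', map_location='cpu')
# weights at masked positions are now zeroed in place; layer shapes are unchanged
out = model(torch.randn(64, 1, 28, 28))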