Unverified Commit e0b692c9 authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

refactor of modelspeedup example (#2161)

parent d05488e0
......@@ -127,7 +127,7 @@ class NaiveModel(torch.nn.Module):
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.bn2(self.conv2(x)))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4 * 4 * 50)
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
......
import os
import argparse
import time
import torch
......@@ -9,145 +10,89 @@ from nni.compression.speedup.torch import ModelSpeedup
from nni.compression.torch import apply_compression_results
torch.manual_seed(0)
use_mask = False
use_mask = True
use_speedup = True
compare_results = True
def apoz_speedup(masks_file, model_checkpoint):
device = torch.device('cuda')
model = VGG(depth=16)
model.to(device)
model.eval()
dummy_input = torch.randn(64, 3, 32, 32)
if use_mask:
apply_compression_results(model, masks_file)
dummy_input = dummy_input.to(device)
start = time.time()
for _ in range(32):
out = model(dummy_input)
#print(out.size(), out)
print('mask elapsed time: ', time.time() - start)
return
else:
#print("model before: ", model)
m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file)
m_speedup.speedup_model()
#print("model after: ", model)
dummy_input = dummy_input.to(device)
start = time.time()
for _ in range(32):
out = model(dummy_input)
#print(out.size(), out)
print('speedup elapsed time: ', time.time() - start)
return
config = {
'apoz': {
'model_name': 'vgg16',
'device': 'cuda',
'input_shape': [64, 3, 32, 32],
'masks_file': './checkpoints/mask_vgg16_cifar10_apoz.pth'
},
'l1filter': {
'model_name': 'vgg16',
'device': 'cuda',
'input_shape': [64, 3, 32, 32],
'masks_file': './checkpoints/mask_vgg16_cifar10_l1.pth'
},
'fpgm': {
'model_name': 'naive',
'device': 'cpu',
'input_shape': [64, 1, 28, 28],
'masks_file': './checkpoints/mask_naive_mnist_fpgm.pth'
},
'slim': {
'model_name': 'vgg19',
'device': 'cuda',
'input_shape': [64, 3, 32, 32],
'masks_file': './checkpoints/mask_vgg19_cifar10_slim.pth' #'mask_vgg19_cifar10.pth'
}
}
def l1filter_speedup(masks_file, model_checkpoint):
device = torch.device('cuda')
model = VGG(depth=16)
def model_inference(config):
masks_file = config['masks_file']
device = torch.device(config['device'])
if config['model_name'] == 'vgg16':
model = VGG(depth=16)
elif config['model_name'] == 'vgg19':
model = VGG(depth=19)
elif config['model_name'] == 'naive':
from model_prune_torch import NaiveModel
model = NaiveModel()
model.to(device)
model.eval()
dummy_input = torch.randn(64, 3, 32, 32)
dummy_input = torch.randn(config['input_shape']).to(device)
use_mask_out = use_speedup_out = None
# must run use_mask before use_speedup because use_speedup modify the model
if use_mask:
apply_compression_results(model, masks_file)
dummy_input = dummy_input.to(device)
apply_compression_results(model, masks_file, 'cpu' if config['device'] == 'cpu' else None)
start = time.time()
for _ in range(32):
out = model(dummy_input)
#print(out.size(), out)
print('mask elapsed time: ', time.time() - start)
return
else:
#print("model before: ", model)
m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file)
use_mask_out = model(dummy_input)
print('elapsed time when use mask: ', time.time() - start)
if use_speedup:
m_speedup = ModelSpeedup(model, dummy_input, masks_file,
'cpu' if config['device'] == 'cpu' else None)
m_speedup.speedup_model()
#print("model after: ", model)
dummy_input = dummy_input.to(device)
start = time.time()
for _ in range(32):
out = model(dummy_input)
#print(out.size(), out)
print('speedup elapsed time: ', time.time() - start)
return
def fpgm_speedup(masks_file, model_checkpoint):
from fpgm_torch_mnist import Mnist
device = torch.device('cpu')
model = Mnist()
model.to(device)
model.print_conv_filter_sparsity()
dummy_input = torch.randn(64, 1, 28, 28)
if use_mask:
apply_compression_results(model, masks_file)
dummy_input = dummy_input.to(device)
start = time.time()
for _ in range(40):
out = model(dummy_input)
print('mask elapsed time: ', time.time() - start)
#print(out.size(), out)
return
else:
m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file)
m_speedup.speedup_model()
dummy_input = dummy_input.to(device)
start = time.time()
for _ in range(40):
out = model(dummy_input)
print('speedup elapsed time: ', time.time() - start)
#print(out.size(), out)
return
def slim_speedup(masks_file, model_checkpoint):
device = torch.device('cuda')
model = VGG(depth=19)
model.to(device)
model.eval()
dummy_input = torch.randn(64, 3, 32, 32)
if use_mask:
apply_compression_results(model, masks_file)
dummy_input = dummy_input.to(device)
start = time.time()
for _ in range(32):
out = model(dummy_input)
#print(out.size(), out)
print('mask elapsed time: ', time.time() - start)
return
else:
#print("model before: ", model)
m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file)
m_speedup.speedup_model()
#print("model after: ", model)
dummy_input = dummy_input.to(device)
start = time.time()
for _ in range(32):
out = model(dummy_input)
#print(out.size(), out)
print('speedup elapsed time: ', time.time() - start)
return
use_speedup_out = model(dummy_input)
print('elapsed time when use speedup: ', time.time() - start)
if compare_results:
if torch.allclose(use_mask_out, use_speedup_out, atol=1e-07):
print('the outputs from use_mask and use_speedup are the same')
else:
raise RuntimeError('the outputs from use_mask and use_speedup are different')
if __name__ == '__main__':
parser = argparse.ArgumentParser("speedup")
parser.add_argument("--example_name", type=str, default="slim", help="the name of pruning example")
parser.add_argument("--masks_file", type=str, default=None, help="the path of the masks file")
parser.add_argument("--model_checkpoint", type=str, default=None, help="the path of checkpointed model")
args = parser.parse_args()
if args.example_name == 'slim':
if args.masks_file is None:
args.masks_file = 'mask_vgg19_cifar10.pth'
slim_speedup(args.masks_file, args.model_checkpoint)
elif args.example_name == 'fpgm':
if args.masks_file is None:
args.masks_file = 'mask.pth'
fpgm_speedup(args.masks_file, args.model_checkpoint)
elif args.example_name == 'l1filter':
if args.masks_file is None:
args.masks_file = 'mask_vgg16_cifar10.pth'
l1filter_speedup(args.masks_file, args.model_checkpoint)
elif args.example_name == 'apoz':
if args.masks_file is None:
args.masks_file = 'mask_vgg16_cifar10.pth'
apoz_speedup(args.masks_file, args.model_checkpoint)
if args.example_name != 'all':
if args.masks_file is not None:
config[args.example_name]['masks_file'] = args.masks_file
if not os.path.exists(config[args.example_name]['masks_file']):
msg = '{} does not exist! You should specify masks_file correctly, ' \
'or use default one which is generated by model_prune_torch.py'
raise RuntimeError(msg.format(config[args.example_name]['masks_file']))
model_inference(config[args.example_name])
else:
raise ValueError('unsupported example_name: {}'.format(args.example_name))
model_inference(config['fpgm'])
model_inference(config['slim'])
model_inference(config['l1filter'])
model_inference(config['apoz'])
......@@ -70,7 +70,7 @@ class ModelSpeedup:
This class is to speedup the model with provided weight mask
"""
def __init__(self, model, dummy_input, masks_file):
def __init__(self, model, dummy_input, masks_file, map_location=None):
"""
Parameters
----------
......@@ -80,10 +80,12 @@ class ModelSpeedup:
The dummy input for ```jit.trace```, users should put it on right device before pass in
masks_file : str
The path of user provided mask file
map_location : str
the device on which masks are placed, same to map_location in ```torch.load```
"""
self.bound_model = model
self.dummy_input = dummy_input
self.masks = torch.load(masks_file)
self.masks = torch.load(masks_file, map_location)
self.is_training = model.training
# to obtain forward graph, model should be in ```eval``` mode
if self.is_training:
......
......@@ -3,13 +3,14 @@
import logging
import torch
from .compressor import Pruner
logger = logging.getLogger('torch apply compression')
def apply_compression_results(model, masks_file):
def apply_compression_results(model, masks_file, map_location=None):
"""
Apply the masks from ```masks_file``` to the model
Note: this API is for inference, because it simply multiplies weights with
corresponding masks when this API is called.
Parameters
----------
......@@ -17,54 +18,12 @@ def apply_compression_results(model, masks_file):
The model to be compressed
masks_file : str
The path of the mask file
map_location : str
the device on which masks are placed, same to map_location in ```torch.load```
"""
apply_comp = ApplyCompression(model, masks_file)
apply_comp.compress()
class ApplyCompression(Pruner):
"""
This class is not to generate masks, but applying existing masks
"""
def __init__(self, model, masks_file):
"""
Parameters
----------
model : torch.nn.module
Model to be masked
masks_file : str
The path of user provided mask file
"""
self.bound_model = model
self.masks = torch.load(masks_file)
for module_name in self.masks:
print('module_name: ', module_name)
config_list = self._build_config()
super().__init__(model, config_list)
def _build_config(self):
op_names = []
for module_name in self.masks:
op_names.append(module_name)
return [{'sparsity': 1, 'op_types': ['default', 'BatchNorm2d'], 'op_names': op_names}]
def calc_mask(self, layer, config, **kwargs):
"""
Directly return the corresponding mask
Parameters
----------
layer : LayerInfo
The layer to be pruned
config : dict
Pruning configurations for this weight
kwargs : dict
Auxiliary information
Returns
-------
dict
Mask of the layer
"""
assert layer.name in self.masks
return self.masks[layer.name]
masks = torch.load(masks_file, map_location)
for name, module in model.named_modules():
if name in masks:
module.weight.data = module.weight.data.mul_(masks[name]['weight'])
if hasattr(module, 'bias') and module.bias is not None and 'bias' in masks[name]:
module.bias.data = module.bias.data.mul_(masks[name]['bias'])
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment