import os
import argparse
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from models.cifar10.vgg import VGG
from nni.compression.speedup.torch import ModelSpeedup
from nni.compression.torch import apply_compression_results

# Fix the RNG seed so the random dummy inputs are reproducible across runs.
torch.manual_seed(0)

# Module-level switches: run the masked model, run the speedup-compiled
# model, and compare the two outputs for consistency.
use_mask, use_speedup, compare_results = True, True, True
# Per-example settings: which model to build, where to run it, the dummy
# input shape (batch, channels, height, width), and the pruning-masks
# checkpoint produced by model_prune_torch.py.
config = {
    'apoz': {
        'model_name': 'vgg16',
        'device': 'cuda',
        'input_shape': [64, 3, 32, 32],
        'masks_file': './checkpoints/mask_vgg16_cifar10_apoz.pth'
    },
    'l1filter': {
        'model_name': 'vgg16',
        'device': 'cuda',
        'input_shape': [64, 3, 32, 32],
        'masks_file': './checkpoints/mask_vgg16_cifar10_l1.pth'
    },
    'fpgm': {
        'model_name': 'naive',
        'device': 'cpu',
        'input_shape': [64, 1, 28, 28],
        'masks_file': './checkpoints/mask_naive_mnist_fpgm.pth'
    },
    'slim': {
        'model_name': 'vgg19',
        'device': 'cuda',
        'input_shape': [64, 3, 32, 32],
        'masks_file': './checkpoints/mask_vgg19_cifar10_slim.pth'
    }
}
def model_inference(config):
    """Time inference with pruning masks applied vs. the speedup-compiled model.

    Builds the model named in ``config``, then — depending on the module-level
    flags ``use_mask`` / ``use_speedup`` — applies the compression masks
    directly, compiles the model with ``ModelSpeedup``, times 32 forward
    passes for each variant, and optionally verifies both variants produce
    the same outputs.

    Args:
        config (dict): one entry of the module-level ``config`` table, with
            keys ``model_name``, ``device``, ``input_shape`` and ``masks_file``.

    Raises:
        ValueError: if ``config['model_name']`` is not one of the supported names.
        RuntimeError: if output comparison is requested but either path was
            skipped, or if the two outputs differ.
    """
    masks_file = config['masks_file']
    device = torch.device(config['device'])
    if config['model_name'] == 'vgg16':
        model = VGG(depth=16)
    elif config['model_name'] == 'vgg19':
        model = VGG(depth=19)
    elif config['model_name'] == 'naive':
        from model_prune_torch import NaiveModel
        model = NaiveModel()
    else:
        # Fail fast: without this, `model` would be unbound below and the
        # code would die with a confusing UnboundLocalError.
        raise ValueError('unsupported model_name: {}'.format(config['model_name']))
    model.to(device)
    model.eval()

    dummy_input = torch.randn(config['input_shape']).to(device)
    use_mask_out = use_speedup_out = None
    # must run use_mask before use_speedup because use_speedup modifies the model
    if use_mask:
        apply_compression_results(model, masks_file, 'cpu' if config['device'] == 'cpu' else None)
        start = time.time()
        for _ in range(32):
            use_mask_out = model(dummy_input)
        print('elapsed time when use mask: ', time.time() - start)
    if use_speedup:
        m_speedup = ModelSpeedup(model, dummy_input, masks_file,
                                 'cpu' if config['device'] == 'cpu' else None)
        m_speedup.speedup_model()
        start = time.time()
        for _ in range(32):
            use_speedup_out = model(dummy_input)
        print('elapsed time when use speedup: ', time.time() - start)
    if compare_results:
        # Guard against comparing a path that never ran; torch.allclose(None, ...)
        # would otherwise fail with an opaque TypeError.
        if use_mask_out is None or use_speedup_out is None:
            raise RuntimeError('compare_results requires both use_mask and use_speedup to be enabled')
        if torch.allclose(use_mask_out, use_speedup_out, atol=1e-07):
            print('the outputs from use_mask and use_speedup are the same')
        else:
            raise RuntimeError('the outputs from use_mask and use_speedup are different')


if __name__ == '__main__':
    parser = argparse.ArgumentParser("speedup")
    # Constrain the value so a typo yields a friendly argparse error instead
    # of a raw KeyError when indexing `config` below.
    parser.add_argument("--example_name", type=str, default="slim",
                        choices=sorted(config) + ['all'],
                        help="the name of pruning example")
    parser.add_argument("--masks_file", type=str, default=None,
                        help="the path of the masks file")
    args = parser.parse_args()

    if args.example_name != 'all':
        example_config = config[args.example_name]
        if args.masks_file is not None:
            example_config['masks_file'] = args.masks_file
        if not os.path.exists(example_config['masks_file']):
            msg = '{} does not exist! You should specify masks_file correctly, ' \
                  'or use default one which is generated by model_prune_torch.py'
            raise RuntimeError(msg.format(example_config['masks_file']))
        model_inference(example_config)
    else:
        # Run every configured example in sequence.
        model_inference(config['fpgm'])
        model_inference(config['slim'])
        model_inference(config['l1filter'])
        model_inference(config['apoz'])