import argparse
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

from models.cifar10.vgg import VGG
from nni.compression.pytorch import apply_compression_results, ModelSpeedup

# Fix the RNG seed so the random dummy inputs below are reproducible,
# which makes the mask-vs-speedup output comparison deterministic.
torch.manual_seed(0)

# Which inference paths to exercise; compare_results requires both of the
# other two to be enabled so both outputs exist.
use_mask = True
use_speedup = True
compare_results = True
# Per-example configuration: which model architecture to build, the dummy
# input shape fed to it ([batch, channels, height, width]), and the path of
# the pruning-masks checkpoint produced by model_prune_torch.py.
config = {
    'apoz': {
        'model_name': 'vgg16',
        'input_shape': [64, 3, 32, 32],
        'masks_file': './checkpoints/mask_vgg16_cifar10_apoz.pth'
    },
    'l1filter': {
        'model_name': 'vgg16',
        'input_shape': [64, 3, 32, 32],
        'masks_file': './checkpoints/mask_vgg16_cifar10_l1filter.pth'
    },
    'fpgm': {
        'model_name': 'naive',
        'input_shape': [64, 1, 28, 28],
        'masks_file': './checkpoints/mask_naive_mnist_fpgm.pth'
    },
    'slim': {
        'model_name': 'vgg19',
        'input_shape': [64, 3, 32, 32],
        'masks_file': './checkpoints/mask_vgg19_cifar10_slim.pth'
    }
}
QuanluZhang's avatar
QuanluZhang committed
38

39
40
def model_inference(config):
    """Time inference with pruning masks applied vs. a sped-up model.

    Builds the model named by ``config``, optionally (per the module-level
    flags ``use_mask`` / ``use_speedup``) applies the pruning masks and/or
    runs ``ModelSpeedup`` to physically shrink the model, then times 32
    forward passes for each enabled path and, when ``compare_results`` is
    set, checks both paths produce the same output.

    Args:
        config (dict): one entry of the module-level ``config`` table, with
            keys ``model_name`` ('vgg16', 'vgg19' or 'naive'),
            ``input_shape`` (list passed to ``torch.randn``) and
            ``masks_file`` (path to the masks checkpoint).

    Raises:
        RuntimeError: if ``compare_results`` is set and the masked and
            sped-up outputs differ beyond ``atol=1e-07``.
    """
    masks_file = config['masks_file']
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')

    if config['model_name'] == 'vgg16':
        model = VGG(depth=16)
    elif config['model_name'] == 'vgg19':
        model = VGG(depth=19)
    elif config['model_name'] == 'naive':
        # Local import: model_prune_torch lives beside this script and is
        # only needed for the MNIST 'naive' example.
        from model_prune_torch import NaiveModel
        model = NaiveModel()
    model.to(device)
    model.eval()

    dummy_input = torch.randn(config['input_shape']).to(device)
    use_mask_out = use_speedup_out = None
    # The mask path must run first: speedup_model() rewrites `model` in
    # place, so the masked timing would be invalid afterwards.
    if use_mask:
        apply_compression_results(model, masks_file, device)
        start = time.time()
        for _ in range(32):
            use_mask_out = model(dummy_input)
        print('elapsed time when use mask: ', time.time() - start)
    if use_speedup:
        m_speedup = ModelSpeedup(model, dummy_input, masks_file, device)
        m_speedup.speedup_model()
        start = time.time()
        for _ in range(32):
            use_speedup_out = model(dummy_input)
        print('elapsed time when use speedup: ', time.time() - start)
    if compare_results:
        if torch.allclose(use_mask_out, use_speedup_out, atol=1e-07):
            print('the outputs from use_mask and use_speedup are the same')
        else:
            raise RuntimeError('the outputs from use_mask and use_speedup are different')
QuanluZhang's avatar
QuanluZhang committed
76
77
78
79
80
81

if __name__ == '__main__':
    parser = argparse.ArgumentParser("speedup")
    parser.add_argument("--example_name", type=str, default="slim", help="the name of pruning example")
    parser.add_argument("--masks_file", type=str, default=None, help="the path of the masks file")
    args = parser.parse_args()

    if args.example_name != 'all':
        # An explicit --masks_file overrides the default path in `config`.
        if args.masks_file is not None:
            config[args.example_name]['masks_file'] = args.masks_file
        # Fail early with a helpful message if the masks checkpoint is missing.
        if not os.path.exists(config[args.example_name]['masks_file']):
            msg = '{} does not exist! You should specify masks_file correctly, ' \
                  'or use default one which is generated by model_prune_torch.py'
            raise RuntimeError(msg.format(config[args.example_name]['masks_file']))
        model_inference(config[args.example_name])
    else:
        # 'all' runs every pruning example with its default masks file.
        model_inference(config['fpgm'])
        model_inference(config['slim'])
        model_inference(config['l1filter'])
        model_inference(config['apoz'])