amc_search.py 5.49 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import sys
import argparse
import time

import torch
import torch.nn as nn

from nni.compression.torch import AMCPruner
from data import get_split_dataset
from utils import AverageMeter, accuracy

sys.path.append('../models')

def parse_args():
    """Build and parse the command-line options for the AMC search script."""
    arg_parser = argparse.ArgumentParser(description='AMC search script')
    add = arg_parser.add_argument

    # model / dataset / pruning-target options
    add('--model_type', default='mobilenet', type=str, choices=['mobilenet', 'mobilenetv2'], help='model to prune')
    add('--dataset', default='cifar10', type=str, choices=['cifar10', 'imagenet'], help='dataset to use (cifar/imagenet)')
    add('--batch_size', default=50, type=int, help='number of data batch size')
    add('--data_root', default='./cifar10', type=str, help='dataset path')
    add('--flops_ratio', default=0.5, type=float, help='target flops ratio to preserve of the model')
    add('--lbound', default=0.2, type=float, help='minimum sparsity')
    add('--rbound', default=1., type=float, help='maximum sparsity')
    add('--ckpt_path', default=None, type=str, help='manual path of checkpoint')

    # search / export options
    add('--train_episode', default=800, type=int, help='number of training episode')
    add('--n_gpu', default=1, type=int, help='number of gpu to use')
    add('--n_worker', default=16, type=int, help='number of data loader worker')
    add('--job', default='train_export', type=str, choices=['train_export', 'export_only'],
        help='search best pruning policy and export or just export model with searched policy')
    add('--export_path', default=None, type=str, help='path for exporting models')
    add('--searched_model_path', default=None, type=str, help='path for searched best wrapped model')

    return arg_parser.parse_args()


def get_model_and_checkpoint(model, dataset, checkpoint_path, n_gpu=1):
    """Instantiate the requested network, optionally load a checkpoint,
    and move it to GPU(s) when available.

    Raises NotImplementedError for any (model, dataset) combination other
    than {mobilenet, mobilenetv2} x {imagenet, cifar10}.
    """
    # Class count is fully determined by the dataset.
    n_class_of = {'imagenet': 1000, 'cifar10': 10}
    if dataset not in n_class_of:
        raise NotImplementedError
    n_class = n_class_of[dataset]

    if model == 'mobilenet':
        from mobilenet import MobileNet
        net = MobileNet(n_class=n_class)
    elif model == 'mobilenetv2':
        from mobilenet_v2 import MobileNetV2
        net = MobileNetV2(n_class=n_class)
    else:
        raise NotImplementedError

    if checkpoint_path:
        print('loading {}...'.format(checkpoint_path))
        sd = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        if 'state_dict' in sd:  # a checkpoint but not a state_dict
            sd = sd['state_dict']
        # Strip any DataParallel 'module.' prefixes before loading.
        sd = {k.replace('module.', ''): v for k, v in sd.items()}
        net.load_state_dict(sd)

    if torch.cuda.is_available() and n_gpu > 0:
        net = net.cuda()
        if n_gpu > 1:
            net = torch.nn.DataParallel(net, range(n_gpu))

    return net

def init_data(args):
    """Build the train/val loaders; the val split is carved from the train set.

    Hold-out size: 5k images for CIFAR-style datasets, 3k for ImageNet.
    """
    holdout = 5000 if 'cifar' in args.dataset else 3000
    loaders = get_split_dataset(
        args.dataset, args.batch_size,
        args.n_worker, holdout,
        data_root=args.data_root,
        shuffle=False
    )  # shuffle=False keeps the split sampling reproducible across runs
    return loaders[0], loaders[1]

def validate(val_loader, model, verbose=False):
    """Run one evaluation pass of ``model`` over ``val_loader``.

    Args:
        val_loader: iterable yielding (inputs, targets) batches.
        model: network to evaluate; left in eval mode on return.
        verbose: if True, print average loss / top-1 / top-5 / wall time.

    Returns:
        float: average top-5 accuracy over the validation set.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # Derive the device from the model's own parameters instead of relying on
    # the module-level ``device`` global (which is only defined when this file
    # runs as __main__, so the original raised NameError when imported) and
    # instead of the unconditional ``.cuda()`` that broke CPU-only hosts.
    try:
        eval_device = next(model.parameters()).device
    except StopIteration:  # parameter-less model; fall back to CPU
        eval_device = torch.device('cpu')
    criterion = nn.CrossEntropyLoss().to(eval_device)

    # switch to evaluate mode
    model.eval()
    end = time.time()

    t1 = time.time()
    with torch.no_grad():
        # ``torch.autograd.Variable`` is deprecated; tensors work directly.
        for inputs, targets in val_loader:
            inputs = inputs.to(eval_device)
            targets = targets.to(eval_device)

            # compute output
            output = model(inputs)
            loss = criterion(output, targets)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.data, targets, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
    t2 = time.time()
    if verbose:
        print('* Test loss: %.3f    top1: %.3f    top5: %.3f    time: %.3f' %
              (losses.avg, top1.avg, top5.avg, t2 - t1))
    return top5.avg


if __name__ == "__main__":
    args = parse_args()

    # Use the GPU only when CUDA is present AND the user asked for >= 1 GPU.
    # NOTE(review): this ``device`` is a module-level global and is read by
    # validate(); it only exists when the file runs as a script.
    device = torch.device('cuda') if torch.cuda.is_available() and args.n_gpu > 0 else torch.device('cpu')

    # Build the network (optionally restoring a checkpoint) and the validation
    # loader that the pruner uses as its evaluation function input.
    model = get_model_and_checkpoint(args.model_type, args.dataset, checkpoint_path=args.ckpt_path, n_gpu=args.n_gpu)
    _, val_loader = init_data(args)

    # Make every Conv2d and Linear layer a pruning candidate.
    config_list = [{
        'op_types': ['Conv2d', 'Linear']
    }]
    # AMCPruner searches per-layer sparsities with RL, scoring candidate
    # policies via validate(val_loader, model); job/export_path control
    # whether it trains the search or just exports a searched policy.
    pruner = AMCPruner(
        model, config_list, validate, val_loader, model_type=args.model_type, dataset=args.dataset,
        train_episode=args.train_episode, job=args.job, export_path=args.export_path,
        searched_model_path=args.searched_model_path,
        flops_ratio=args.flops_ratio, lbound=args.lbound, rbound=args.rbound)
    pruner.compress()