amc_search.py 5.21 KB
Newer Older
1
2
3
4
5
6
7
8
9
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import sys
import argparse
import time

import torch
import torch.nn as nn
chicm-ms's avatar
chicm-ms committed
10
from torchvision.models import resnet
liuzhe-lz's avatar
liuzhe-lz committed
11
from nni.algorithms.compression.pytorch.pruning import AMCPruner
12
13
14
15
16
17
18
from data import get_split_dataset
from utils import AverageMeter, accuracy

sys.path.append('../models')

def parse_args():
    """Build and parse the command-line options for the AMC search run."""
    p = argparse.ArgumentParser(description='AMC search script')

    # model / dataset selection
    p.add_argument('--model_type', default='mobilenet', type=str,
                   choices=['mobilenet', 'mobilenetv2', 'resnet18', 'resnet34', 'resnet50'],
                   help='model to prune')
    p.add_argument('--dataset', default='cifar10', type=str,
                   choices=['cifar10', 'imagenet'],
                   help='dataset to use (cifar/imagenet)')
    p.add_argument('--batch_size', default=50, type=int, help='number of data batch size')
    p.add_argument('--data_root', default='./data', type=str, help='dataset path')

    # pruning targets
    p.add_argument('--flops_ratio', default=0.5, type=float,
                   help='target flops ratio to preserve of the model')
    p.add_argument('--lbound', default=0.2, type=float, help='minimum sparsity')
    p.add_argument('--rbound', default=1., type=float, help='maximum sparsity')
    p.add_argument('--ckpt_path', default=None, type=str, help='manual path of checkpoint')

    # search / runtime configuration
    p.add_argument('--train_episode', default=800, type=int, help='number of training episode')
    p.add_argument('--n_gpu', default=1, type=int, help='number of gpu to use')
    p.add_argument('--n_worker', default=16, type=int, help='number of data loader worker')
    p.add_argument('--suffix', default=None, type=str, help='suffix of auto-generated log directory')

    return p.parse_args()


def get_model_and_checkpoint(model, dataset, checkpoint_path, n_gpu=1):
    """Instantiate the requested network and optionally restore a checkpoint.

    Args:
        model: one of 'mobilenet', 'mobilenetv2', or a torchvision resnet name.
        dataset: 'imagenet' or 'cifar10'; determines the classifier width.
        checkpoint_path: path to a saved checkpoint or state_dict, or falsy to skip.
        n_gpu: number of GPUs to use; >1 wraps the net in DataParallel.

    Returns:
        The constructed (and possibly GPU-resident) network.

    Raises:
        ValueError: for an unsupported dataset.
        NotImplementedError: for an unsupported model name.
    """
    class_counts = {'imagenet': 1000, 'cifar10': 10}
    if dataset not in class_counts:
        raise ValueError('unsupported dataset')
    n_class = class_counts[dataset]

    if model == 'mobilenet':
        from mobilenet import MobileNet
        net = MobileNet(n_class=n_class)
    elif model == 'mobilenetv2':
        from mobilenet_v2 import MobileNetV2
        net = MobileNetV2(n_class=n_class)
    elif model.startswith('resnet'):
        # start from the ImageNet-pretrained torchvision weights, then
        # swap the classifier head to match the target class count
        net = resnet.__dict__[model](pretrained=True)
        net.fc = nn.Linear(net.fc.in_features, n_class)
    else:
        raise NotImplementedError

    if checkpoint_path:
        print('loading {}...'.format(checkpoint_path))
        state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        if 'state_dict' in state:  # a checkpoint but not a state_dict
            state = state['state_dict']
        # strip DataParallel's 'module.' prefix so the keys match a bare model
        state = {key.replace('module.', ''): value for key, value in state.items()}
        net.load_state_dict(state)

    if torch.cuda.is_available() and n_gpu > 0:
        net = net.cuda()
        if n_gpu > 1:
            net = torch.nn.DataParallel(net, range(n_gpu))

    return net

def init_data(args):
    """Create train/validation loaders, holding out part of the train set.

    CIFAR holds out 5k samples for validation, ImageNet 3k. Shuffling is
    disabled so every run draws the same split (same sampling).
    """
    holdout = 5000 if 'cifar' in args.dataset else 3000
    train_loader, val_loader, _ = get_split_dataset(
        args.dataset, args.batch_size, args.n_worker, holdout,
        data_root=args.data_root, shuffle=False)
    return train_loader, val_loader

def validate(val_loader, model, verbose=False):
    """Evaluate `model` on `val_loader` and return the average top-5 accuracy.

    Passed to AMCPruner as the evaluation callback. Relies on the
    module-level `device` set in the `__main__` block — TODO(review):
    consider passing `device` explicitly instead of via a global.

    Args:
        val_loader: iterable yielding (images, targets) batches.
        model: network to evaluate; switched to eval mode here.
        verbose: when True, print a loss/top1/top5/time summary.

    Returns:
        float: average top-5 accuracy over the validation set.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # Fix: place the loss on the same device as the data instead of
    # unconditionally calling .cuda(), which crashed on CPU-only machines
    # even though the rest of this function already honors `device`.
    criterion = nn.CrossEntropyLoss().to(device)
    # switch to evaluate mode
    model.eval()
    end = time.time()

    t1 = time.time()
    with torch.no_grad():
        for images, target in val_loader:
            # deprecated torch.autograd.Variable wrappers removed; plain
            # tensors have carried autograd state since PyTorch 0.4
            images = images.to(device)
            target = target.to(device)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(prec1.item(), images.size(0))
            top5.update(prec5.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
    t2 = time.time()
    if verbose:
        print('* Test loss: %.3f    top1: %.3f    top5: %.3f    time: %.3f' %
              (losses.avg, top1.avg, top5.avg, t2 - t1))
    return top5.avg


if __name__ == "__main__":
    args = parse_args()

    # NOTE: `device` is read as a module-level global by validate().
    use_cuda = torch.cuda.is_available() and args.n_gpu > 0
    device = torch.device('cuda') if use_cuda else torch.device('cpu')

    model = get_model_and_checkpoint(
        args.model_type, args.dataset,
        checkpoint_path=args.ckpt_path, n_gpu=args.n_gpu)
    _, val_loader = init_data(args)

    # prune every convolution and linear layer
    config_list = [{'op_types': ['Conv2d', 'Linear']}]

    pruner = AMCPruner(
        model, config_list, validate, val_loader,
        model_type=args.model_type, dataset=args.dataset,
        train_episode=args.train_episode, flops_ratio=args.flops_ratio,
        lbound=args.lbound, rbound=args.rbound, suffix=args.suffix)
    pruner.compress()