"test/vscode:/vscode.git/clone" did not exist on "00923759d33b1eec7229988b566421e3f10acbfb"
model_prune_torch.py 9.81 KB
Newer Older
chicm-ms's avatar
chicm-ms committed
1
2
3
4
5
6
7
8
9
10
11
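"""
Example of pruning PyTorch models with NNI's compression pruners.

The script pretrains (or loads) a small MNIST CNN or a CIFAR-10 VGG model,
applies the pruner selected via --pruner_name, fine-tunes the pruned model,
and exports the best weights and pruning masks to the checkpoints directory.
"""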
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from models.cifar10.vgg import VGG
import nni
from nni.algorithms.compression.pytorch.pruning import (
    LevelPruner,
    SlimPruner,
    FPGMPruner,
    L1FilterPruner,
    L2FilterPruner,
    AGPPruner,
    ActivationMeanRankFilterPruner,
    ActivationAPoZRankFilterPruner
)

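# prune_config maps each supported --pruner_name to the dataset, model and NNI
# pruner class used for that experiment, plus the config_list handed to the pruner.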
prune_config = {
    'level': {
        'dataset_name': 'mnist',
        'model_name': 'naive',
        'pruner_class': LevelPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['default'],
        }]
    },
    'agp': {
        'dataset_name': 'mnist',
        'model_name': 'naive',
        'pruner_class': AGPPruner,
        'config_list': [{
            'initial_sparsity': 0.,
            'final_sparsity': 0.8,
            'start_epoch': 0,
            'end_epoch': 10,
            'frequency': 1,
            'op_types': ['Conv2d']
        }]
    },
    'slim': {
        'dataset_name': 'cifar10',
        'model_name': 'vgg19',
        'pruner_class': SlimPruner,
        'config_list': [{
            'sparsity': 0.7,
            'op_types': ['BatchNorm2d']
        }]
    },
    'fpgm': {
        'dataset_name': 'mnist',
        'model_name': 'naive',
        'pruner_class': FPGMPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['Conv2d']
        }]
    },
    'l1filter': {
        'dataset_name': 'cifar10',
        'model_name': 'vgg16',
        'pruner_class': L1FilterPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['Conv2d'],
            'op_names': ['feature.0', 'feature.24', 'feature.27', 'feature.30', 'feature.34', 'feature.37']
        }]
    },
    'mean_activation': {
        'dataset_name': 'cifar10',
        'model_name': 'vgg16',
        'pruner_class': ActivationMeanRankFilterPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['Conv2d'],
            'op_names': ['feature.0', 'feature.24', 'feature.27', 'feature.30', 'feature.34', 'feature.37']
        }]
    },
    'apoz': {
        'dataset_name': 'cifar10',
        'model_name': 'vgg16',
        'pruner_class': ActivationAPoZRankFilterPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['Conv2d'],
            'op_names': ['feature.0', 'feature.24', 'feature.27', 'feature.30', 'feature.34', 'feature.37']
        }]
    }
}


def get_data_loaders(dataset_name='mnist', batch_size=128):
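    """Return (train_loader, test_loader) for MNIST or CIFAR-10 with per-dataset normalization."""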
    assert dataset_name in ['cifar10', 'mnist']

    if dataset_name == 'cifar10':
        ds_class = datasets.CIFAR10
        MEAN, STD = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
    else:
        ds_class = datasets.MNIST
        MEAN, STD = (0.1307,), (0.3081,)

    train_loader = DataLoader(
        ds_class(
            './data', train=True, download=True,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize(MEAN, STD)])
        ),
        batch_size=batch_size, shuffle=True
    )
    test_loader = DataLoader(
        ds_class(
            './data', train=False, download=True,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize(MEAN, STD)])
        ),
        batch_size=batch_size, shuffle=False
    )

    return train_loader, test_loader


class NaiveModel(torch.nn.Module):
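    """Small LeNet-style CNN (two conv + batch-norm blocks, two fully connected layers) for MNIST."""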
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.bn1 = nn.BatchNorm2d(self.conv1.out_channels)
        self.bn2 = nn.BatchNorm2d(self.conv2.out_channels)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


def create_model(model_name='naive'):
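    """Instantiate the model named by model_name: the naive MNIST CNN, VGG16 or VGG19."""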
    assert model_name in ['naive', 'vgg16', 'vgg19']

    if model_name == 'naive':
        return NaiveModel()
    elif model_name == 'vgg16':
        return VGG(16)
    else:
        return VGG(19)


def create_pruner(model, pruner_name, optimizer=None, dependency_aware=False, dummy_input=None):
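    """Build the pruner selected by pruner_name from prune_config.

    When dependency_aware is True, dummy_input is used to trace the model so that
    layers with channel dependencies are pruned consistently; only some pruners
    support this mode.
    """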
    pruner_class = prune_config[pruner_name]['pruner_class']
    config_list = prune_config[pruner_name]['config_list']
    kw_args = {}
    if dependency_aware:
        print('Enabling the dependency_aware mode')
        # note that not all pruners support the dependency_aware mode
        kw_args['dependency_aware'] = True
        kw_args['dummy_input'] = dummy_input
    pruner = pruner_class(model, config_list, optimizer, **kw_args)
    return pruner

def train(model, device, train_loader, optimizer):
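    """Train the model for one epoch, printing the loss every 100 batches."""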
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('{:2.0f}%  Loss {}'.format(
                100 * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
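    """Evaluate the model on the test set and return the top-1 accuracy in percent."""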
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output,
                                         target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    acc = 100 * correct / len(test_loader.dataset)

    print('Loss: {}  Accuracy: {}%\n'.format(
        test_loss, acc))
    return acc


def main(args):
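    """Pretrain (or resume) the model, prune it with the selected pruner, fine-tune, and export the best result."""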
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    os.makedirs(args.checkpoints_dir, exist_ok=True)

    model_name = prune_config[args.pruner_name]['model_name']
    dataset_name = prune_config[args.pruner_name]['dataset_name']
    train_loader, test_loader = get_data_loaders(dataset_name, args.batch_size)
    dummy_input, _ = next(iter(train_loader))
    dummy_input = dummy_input.to(device)
    model = create_model(model_name).to(device)
    if args.resume_from is not None and os.path.exists(args.resume_from):
        print('loading checkpoint {} ...'.format(args.resume_from))
        model.load_state_dict(torch.load(args.resume_from))
        test(model, device, test_loader)
    else:
        optimizer = torch.optim.SGD(
            model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
        if args.multi_gpu and torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)

        print('start training')
        pretrain_model_path = os.path.join(
            args.checkpoints_dir, 'pretrain_{}_{}_{}.pth'.format(model_name, dataset_name, args.pruner_name))
        for epoch in range(args.pretrain_epochs):
            train(model, device, train_loader, optimizer)
            test(model, device, test_loader)
        torch.save(model.state_dict(), pretrain_model_path)

    print('start model pruning...')

    model_path = os.path.join(args.checkpoints_dir, 'pruned_{}_{}_{}.pth'.format(
        model_name, dataset_name, args.pruner_name))
    mask_path = os.path.join(args.checkpoints_dir, 'mask_{}_{}_{}.pth'.format(
        model_name, dataset_name, args.pruner_name))

    # pruner needs to be initialized from a model not wrapped by DataParallel
    if isinstance(model, nn.DataParallel):
        model = model.module

    optimizer_finetune = torch.optim.SGD(
        model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
    best_top1 = 0

    pruner = create_pruner(model, args.pruner_name,
                           optimizer_finetune, args.dependency_aware, dummy_input)
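    # compress() wraps the selected layers and calculates the pruning masks;
    # the returned model is then fine-tuned with the masks applied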
    model = pruner.compress()

    if args.multi_gpu and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    for epoch in range(args.prune_epochs):
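        # update_epoch lets schedule-based pruners such as AGP update their sparsity for this epoch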
        pruner.update_epoch(epoch)
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer_finetune)
        top1 = test(model, device, test_loader)
        if top1 > best_top1:
            best_top1 = top1
            # Export the best model, 'model_path' stores state_dict of the pruned model,
            # mask_path stores mask_dict of the pruned model
            pruner.export_model(model_path=model_path, mask_path=mask_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--pruner_name", type=str,
                        default="level", help="pruner name")
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--pretrain_epochs", type=int,
                        default=10, help="training epochs before model pruning")
    parser.add_argument("--prune_epochs", type=int, default=10,
                        help="training epochs for model pruning")
    parser.add_argument("--checkpoints_dir", type=str,
                        default="./checkpoints", help="checkpoints directory")
    parser.add_argument("--resume_from", type=str,
                        default=None, help="pretrained model weights")
    parser.add_argument("--multi_gpu", action="store_true",
                        help="Use multiple GPUs for training")
    parser.add_argument("--dependency_aware", action="store_true", default=False,
                        help="If enable the dependency_aware mode for the pruner")
    args = parser.parse_args()
    main(args)