"""slim_torch_cifar10.py: network-slimming (SlimPruner) example for VGG-19 on CIFAR-10."""

import math
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.compression.torch import SlimPruner
from models.cifar10.vgg import VGG

def updateBN(model, sparsity_lambda=0.0001):
    """Add the L1 sparsity sub-gradient to every BatchNorm2d scale factor.

    For each ``nn.BatchNorm2d`` module in ``model`` the term
    ``sparsity_lambda * sign(weight)`` is added in place to ``weight.grad``.
    Call it after ``loss.backward()`` and before ``optimizer.step()``.

    Args:
        model: Module whose sub-modules are scanned via ``model.modules()``.
        sparsity_lambda: Coefficient of the L1 penalty on the scale factors.
            Defaults to 0.0001.

    Returns:
        None. Gradients are modified in place.
    """
    for m in model.modules():
        if not isinstance(m, nn.BatchNorm2d):
            continue
        weight = m.weight
        # affine=False layers have no weight, and frozen layers (or layers
        # that have not been through backward yet) have no gradient to adjust.
        if weight is None or weight.grad is None:
            continue
        weight.grad.data.add_(sparsity_lambda * torch.sign(weight.data))  # L1


def train(model, device, train_loader, optimizer, sparse_bn=False):
    """Run one pass over ``train_loader``, updating ``model`` with ``optimizer``.

    Each batch is moved to ``device``, scored with cross-entropy and
    back-propagated. When ``sparse_bn`` is true, ``updateBN`` adjusts the
    BatchNorm gradients before the optimizer step. Progress and the current
    loss are printed every 100 batches.
    """
    model.train()
    for step, batch in enumerate(train_loader):
        inputs, labels = batch
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        batch_loss = F.cross_entropy(model(inputs), labels)
        batch_loss.backward()

        # L1 regularization on BN layer
        if sparse_bn:
            updateBN(model)
        optimizer.step()

        if step % 100:
            continue
        progress = 100 * step / len(train_loader)
        print('{:2.0f}%  Loss {}'.format(progress, batch_loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    acc = 100 * correct / len(test_loader.dataset)

    print('Loss: {}  Accuracy: {}%)\n'.format(
        test_loss, acc))
    return acc


def main():
    """Run the network-slimming example: optional base training, pruning, fine-tuning.

    Command-line flags:
        --epochs N   Number of epochs for base-model training (default 160).
        --retrain    Train the base VGG-19 (with L1 on BN scales) and save it to
                     'vgg19_cifar10.pth' before pruning. Without this flag that
                     checkpoint must already exist.
        --parallel   Wrap the pruned model in nn.DataParallel when more than
                     one GPU is visible.

    Files written: 'vgg19_cifar10.pth' (only with --retrain),
    'pruned_vgg19_cifar10.pth' and 'mask_vgg19_cifar10.pth'.
    """
    parser = argparse.ArgumentParser("multiple gpu with pruning")
    parser.add_argument("--epochs", type=int, default=160)
    parser.add_argument("--retrain", default=False, action="store_true")
    parser.add_argument("--parallel", default=False, action="store_true")

    args = parser.parse_args()

    torch.manual_seed(0)
    # Fall back to CPU so the example still runs on hosts without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data.cifar10', train=True, download=True,
                         transform=transforms.Compose([
                             transforms.Pad(4),
                             transforms.RandomCrop(32),
                             transforms.RandomHorizontalFlip(),
                             transforms.ToTensor(),
                             normalize,
                         ])),
        batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=200, shuffle=False)

    model = VGG(depth=19)
    model.to(device)

    # Train the base VGG-19 model
    if args.retrain:
        print('=' * 10 + 'Train the unpruned base model' + '=' * 10)
        epochs = args.epochs
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
        # Decay the learning rate at 50% and 75% of training. Integer
        # milestones keep the schedule active when `epochs` is not a multiple
        # of 4 (a float such as 112.5 never equals an epoch index). Epoch 0 is
        # excluded so very short runs do not start with a decayed rate.
        lr_milestones = {int(epochs * 0.5), int(epochs * 0.75)} - {0}
        for epoch in range(epochs):
            if epoch in lr_milestones:
                for param_group in optimizer.param_groups:
                    param_group['lr'] *= 0.1
            print("epoch {}".format(epoch))
            train(model, device, train_loader, optimizer, True)
            test(model, device, test_loader)
        torch.save(model.state_dict(), 'vgg19_cifar10.pth')

    # Test base model accuracy
    print('=' * 10 + 'Test the original model' + '=' * 10)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    model.load_state_dict(torch.load('vgg19_cifar10.pth', map_location=device))
    test(model, device, test_loader)
    # top1 = 93.60%

    # Pruning Configuration, in paper 'Learning efficient convolutional networks through network slimming',
    configure_list = [{
        'sparsity': 0.7,
        'op_types': ['BatchNorm2d'],
    }]

    # Prune model and test accuracy without fine tuning.
    print('=' * 10 + 'Test the pruned model before fine tune' + '=' * 10)
    pruner = SlimPruner(model, configure_list)
    model = pruner.compress()
    if args.parallel:
        if torch.cuda.device_count() > 1:
            print("use {} gpus for pruning".format(torch.cuda.device_count()))
            # Pass device_ids=[...] to nn.DataParallel to restrict the GPUs used.
            model = nn.DataParallel(model)
        else:
            print("only detect 1 gpu, fall back")
    model.to(device)

    # Fine tune the pruned model for 40 epochs and test accuracy
    print('=' * 10 + 'Fine tuning' + '=' * 10)
    optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
    best_top1 = 0
    for epoch in range(40):
        pruner.update_epoch(epoch)
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer_finetune)
        top1 = test(model, device, test_loader)
        if top1 > best_top1:
            best_top1 = top1
            # Export the best model, 'model_path' stores state_dict of the pruned model,
            # mask_path stores mask_dict of the pruned model
            pruner.export_model(model_path='pruned_vgg19_cifar10.pth', mask_path='mask_vgg19_cifar10.pth')

    # Test the exported model
    print('=' * 10 + 'Test the export pruned model after fine tune' + '=' * 10)
    new_model = VGG(depth=19)
    new_model.to(device)
    new_model.load_state_dict(torch.load('pruned_vgg19_cifar10.pth', map_location=device))
    test(new_model, device, test_loader)
    # top1 = 93.74%
143


# Run the example only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()