import argparse
import copy
import torch
import torch.nn as nn
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from nni.compression.torch import LotteryTicketPruner

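# The 300-100 fully connected MNIST network (the "Lenet" FC architecture
# used in the lottery ticket paper)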
class fc1(nn.Module):

    def __init__(self, num_classes=10):
        super(fc1, self).__init__()
        self.classifier = nn.Sequential(
            nn.Linear(28*28, 300),
            nn.ReLU(inplace=True),
            nn.Linear(300, 100),
            nn.ReLU(inplace=True),
            nn.Linear(100, num_classes),
        )

    def forward(self, x):
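        # flatten (N, 1, 28, 28) MNIST batches to (N, 784) for the linear stack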
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

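# Runs one epoch of training; returns the last batch's loss as a progress signal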
def train(model, train_loader, optimizer, criterion):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.train()
    for imgs, targets in train_loader:
        optimizer.zero_grad()
        imgs, targets = imgs.to(device), targets.to(device)
        output = model(imgs)
        train_loss = criterion(output, targets)
        train_loss.backward()
        optimizer.step()
    return train_loss.item()

def test(model, test_loader, criterion):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # criterion returns the batch mean, so scale by the batch size to accumulate a sum
            test_loss += criterion(output, target).item() * target.size(0)
            pred = output.argmax(dim=1, keepdim=True)  # index of the max logit
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    return accuracy


if __name__ == '__main__':
    """
    THE LOTTERY TICKET HYPOTHESIS: FINDING SPARSE, TRAINABLE NEURAL NETWORKS (https://arxiv.org/pdf/1803.03635.pdf)

    The Lottery Ticket Hypothesis. A randomly-initialized, dense neural network contains a subnetwork that is
    initialized such that—when trained in isolation—it can match the test accuracy of the original network after
    training for at most the same number of iterations.

    Identifying winning tickets. We identify a winning ticket by training a network and pruning its
    smallest-magnitude weights. The remaining, unpruned connections constitute the architecture of the
    winning ticket. Unique to our work, each unpruned connection's value is then reset to its initialization
    from the original network before it was trained. This forms our central experiment:
        1. Randomly initialize a neural network f(x; θ0) (where θ0 ∼ Dθ).
        2. Train the network for j iterations, arriving at parameters θj.
        3. Prune p% of the parameters in θj, creating a mask m.
        4. Reset the remaining parameters to their values in θ0, creating the winning ticket f(x; m ⊙ θ0).
    As described, this pruning approach is one-shot: the network is trained once, p% of weights are
    pruned, and the surviving weights are reset. However, in this paper, we focus on iterative pruning,
    which repeatedly trains, prunes, and resets the network over n rounds; each round prunes p^(1/n)% of
    the weights that survive the previous round. Our results show that iterative pruning finds winning tickets
    that match the accuracy of the original network at smaller sizes than does one-shot pruning.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_epochs", type=int, default=10, help="training epochs")
    args = parser.parse_args()
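    # e.g. run as: python lottery_torch_mnist_fc.py --train_epochs 10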

    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    traindataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    testdataset = datasets.MNIST('./data', train=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(traindataset, batch_size=60, shuffle=True, num_workers=0, drop_last=False)
    test_loader = torch.utils.data.DataLoader(testdataset, batch_size=60, shuffle=False, num_workers=0, drop_last=False)

    model = fc1().to("cuda" if torch.cuda.is_available() else "cpu")
    optimizer = torch.optim.Adam(model.parameters(), lr=1.2e-3)
    criterion = nn.CrossEntropyLoss()
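    # CrossEntropyLoss applies log-softmax internally, so fc1 outputs raw logits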

    # Record the randomly initialized model weights
    orig_state = copy.deepcopy(model.state_dict())

    # Train the full model first to get the unpruned baseline accuracy
    for epoch in range(args.train_epochs):
        train(model, train_loader, optimizer, criterion)
    orig_accuracy = test(model, test_loader, criterion)
    print('unpruned model accuracy: {}'.format(orig_accuracy))

    # reset model weights and optimizer for pruning
    model.load_state_dict(orig_state)
    optimizer = torch.optim.Adam(model.parameters(), lr=1.2e-3)
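    # A fresh optimizer avoids carrying stale Adam moment estimates over from the baseline run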

    # Prune the model to find a winning ticket
    configure_list = [{
        'prune_iterations': 5,
        'sparsity': 0.96,
        'op_types': ['default']
    }]
    pruner = LotteryTicketPruner(model, configure_list, optimizer)
    pruner.compress()
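    # With NNI's LotteryTicketPruner, compress() wraps the selected layers with weight
    # masks; each prune_iteration_start() round below is meant to prune further and
    # rewind the surviving weights, following the paper's train-prune-reset loop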

    best_accuracy = 0.
    best_state_dict = None

    for i in pruner.get_prune_iterations():
        pruner.prune_iteration_start()
        loss = 0
        accuracy = 0
        for epoch in range(args.train_epochs):
            loss = train(model, train_loader, optimizer, criterion)
            accuracy = test(model, test_loader, criterion)
            print('current epoch: {0}, loss: {1}, accuracy: {2}'.format(epoch, loss, accuracy))
            if accuracy > best_accuracy:
                best_accuracy = accuracy
                # state dict of weights and masks
                best_state_dict = copy.deepcopy(model.state_dict())
        print('prune iteration: {0}, loss: {1}, accuracy: {2}'.format(i, loss, accuracy))

    if best_accuracy > orig_accuracy:
        # load weights and masks
        pruner.bound_model.load_state_dict(best_state_dict)
        # reset weights to the original untrained initialization, keeping the masks, to export the winning ticket
        pruner.load_model_state_dict(orig_state)
        pruner.export_model('model_winning_ticket.pth', 'mask_winning_ticket.pth')
        print('winning ticket has been saved: model_winning_ticket.pth, mask_winning_ticket.pth')
    else:
        print('winning ticket was not found in this run; you can try running it again.')