lottery_torch_mnist_fc.py

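"""Lottery Ticket Hypothesis example: iteratively prune a fully connected
MNIST classifier (784-300-100-10) with NNI's LotteryTicketPruner.

The script trains the network, runs 5 prune/retrain iterations targeting 96%
overall sparsity, and exports the pruned weights together with their masks.
"""
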
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from nni.compression.torch import LotteryTicketPruner

class fc1(nn.Module):

    def __init__(self, num_classes=10):
        super(fc1, self).__init__()
        self.classifier = nn.Sequential(
            nn.Linear(28*28, 300),
            nn.ReLU(inplace=True),
            nn.Linear(300, 100),
            nn.ReLU(inplace=True),
            nn.Linear(100, num_classes),
        )

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

def train(model, train_loader, optimizer, criterion):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.train()
    for imgs, targets in train_loader:
        optimizer.zero_grad()
        imgs, targets = imgs.to(device), targets.to(device)
        output = model(imgs)
        train_loss = criterion(output, targets)
        train_loss.backward()
        optimizer.step()
    return train_loss.item()  # loss of the last training batch

def test(model, test_loader, criterion):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch loss (the model outputs raw logits)
            pred = output.argmax(dim=1, keepdim=True)  # index of the max logit
            correct += pred.eq(target.view_as(pred)).sum().item()
        test_loss /= len(test_loader.dataset)
        accuracy = 100. * correct / len(test_loader.dataset)
    return accuracy


if __name__ == '__main__':
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    traindataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    testdataset = datasets.MNIST('./data', train=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(traindataset, batch_size=60, shuffle=True, num_workers=0, drop_last=False)
    test_loader = torch.utils.data.DataLoader(testdataset, batch_size=60, shuffle=False, num_workers=0, drop_last=False)  # keep all samples; accuracy divides by len(test_loader.dataset)

    model = fc1().to("cuda" if torch.cuda.is_available() else "cpu")
    optimizer = torch.optim.Adam(model.parameters(), lr=1.2e-3)
    criterion = nn.CrossEntropyLoss()

    configure_list = [{
        'prune_iterations': 5,   # number of prune/retrain rounds
        'sparsity': 0.96,        # overall sparsity targeted after the final iteration
        'op_types': ['default']  # prune all layer types supported by default
    }]
    pruner = LotteryTicketPruner(model, configure_list, optimizer)
    pruner.compress()
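    # Rough outline of the loop below, following the Lottery Ticket Hypothesis
    # procedure (Frankle & Carbin, 2019): each call to prune_iteration_start()
    # prunes more of the smallest-magnitude weights and, with the pruner's default
    # settings, rewinds the surviving weights to their initial values, so every
    # round retrains a sparser subnetwork from the original initialization.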

    for i in pruner.get_prune_iterations():
        pruner.prune_iteration_start()
        loss = 0
        accuracy = 0
        for epoch in range(10):
            loss = train(model, train_loader, optimizer, criterion)
            accuracy = test(model, test_loader, criterion)
            print('current epoch: {0}, loss: {1}, accuracy: {2}'.format(epoch, loss, accuracy))
        print('prune iteration: {0}, loss: {1}, accuracy: {2}'.format(i, loss, accuracy))
    pruner.export_model('model.pth', 'mask.pth')
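
    # A minimal sanity check (a sketch, assuming 'model.pth' holds the masked
    # state dict that export_model saves with torch.save): measure how many
    # weight entries were actually zeroed out.
    exported = torch.load('model.pth')
    n_total = sum(t.numel() for name, t in exported.items() if 'weight' in name)
    n_zero = sum((t == 0).sum().item() for name, t in exported.items() if 'weight' in name)
    print('sparsity of exported weights: {0:.2%}'.format(n_zero / n_total))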