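# multi_gpu.py
#
# Prune a small MNIST classifier with NNI's SlimPruner, then fine-tune it
# on multiple GPUs via nn.DataParallel when more than one GPU is available.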
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from nni.compression.torch import SlimPruner

class fc1(nn.Module):
    """Small convolutional classifier for 28x28 MNIST images."""

    def __init__(self, num_classes=10):
        super(fc1, self).__init__()
        # Single conv block; its BatchNorm2d layer is the pruning target.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU(inplace=True)

        # Fully connected head on the flattened 32 x 28 x 28 feature map.
        self.linear1 = nn.Linear(32 * 28 * 28, 300)
        self.relu2 = nn.ReLU(inplace=True)
        self.linear2 = nn.Linear(300, 100)
        self.relu3 = nn.ReLU(inplace=True)
        self.linear3 = nn.Linear(100, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)

        x = torch.flatten(x, 1)  # flatten all dims except batch
        x = self.linear1(x)
        x = self.relu2(x)
        x = self.linear2(x)
        x = self.relu3(x)
        x = self.linear3(x)
        return x

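# One epoch of standard supervised training; the pruner's masks are applied
# inside the wrapped layers on every forward pass.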
def train(model, train_loader, optimizer, criterion, device):
    model.train()
    for imgs, targets in train_loader:
        optimizer.zero_grad()
        imgs, targets = imgs.to(device), targets.to(device)
        output = model(imgs)
        train_loss = criterion(output, targets)
        train_loss.backward()
        optimizer.step()
    # Report the last batch's loss as a coarse progress signal.
    return train_loss.item()

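# Evaluate on the held-out test set; returns accuracy as a percentage.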
def test(model, test_loader, criterion, device):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # The model outputs raw logits, so use cross_entropy here
            # (F.nll_loss expects log-probabilities); sum per-sample losses.
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)  # index of the max logit
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    return accuracy


if __name__ == '__main__':
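    # Standard MNIST mean/std normalization constants.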
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    traindataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    testdataset = datasets.MNIST('./data', train=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(traindataset, batch_size=60, shuffle=True, num_workers=10, drop_last=False)
    # Keep the last partial test batch: with drop_last=True the accuracy
    # denominator len(test_loader.dataset) would count samples never evaluated.
    test_loader = torch.utils.data.DataLoader(testdataset, batch_size=60, shuffle=False, num_workers=10, drop_last=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = fc1()
    criterion = nn.CrossEntropyLoss()

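    # SlimPruner implements network slimming: it ranks the BatchNorm2d scaling
    # factors (gamma) and masks the channels with the smallest values, here
    # targeting 86% sparsity. It operates on BatchNorm2d layers.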
    configure_list = [{
        'sparsity': 0.86,
        'op_types': ['BatchNorm2d']
    }]
    pruner = SlimPruner(model, configure_list)
    pruner.compress()

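    # Wrap with DataParallel only after compress(), so every GPU replica
    # carries the pruner's masking wrappers.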
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1.2e-3)
    # Print parameter names; with DataParallel they gain a 'module.' prefix.
    for name, _ in model.named_parameters():
        print(name)
    for epoch in range(10):
        loss = train(model, train_loader, optimizer, criterion, device)
        accuracy = test(model, test_loader, criterion, device)
        print('current epoch: {0}, loss: {1}, accuracy: {2}'.format(epoch, loss, accuracy))
    # Export the fine-tuned weights together with the binary pruning masks.
    pruner.export_model('model.pth', 'mask.pth')