import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F

from coarsening import coarsen
from coordinate import get_coordinates, z2polar
from grid_graph import grid_graph

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import dgl
from dgl.nn.pytorch.conv import ChebConv, GMMConv
from dgl.nn.pytorch.glob import MaxPooling

argparser = argparse.ArgumentParser("MNIST")
argparser.add_argument(
    "--gpu", type=int, default=-1, help="gpu id, use cpu if set to -1"
)
argparser.add_argument(
    "--model", type=str, default="chebnet", help="model to use, chebnet/monet"
)
argparser.add_argument("--batch-size", type=int, default=100, help="batch size")
args = argparser.parse_args()

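# Each MNIST image is treated as a signal on a fixed graph: an
# 8-nearest-neighbor graph built over the 28x28 pixel grid.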
grid_side = 28
number_edges = 8
metric = "euclidean"

A = grid_graph(grid_side, number_edges, metric)

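# Coarsen the graph `coarsening_levels` times (see coarsening.py). `L` holds
# the adjacency matrix at each level, and `perm` is the node permutation that
# maps the original pixels (plus padded "fake" nodes) onto the coarsened
# ordering.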
coarsening_levels = 4
L, perm = coarsen(A, coarsening_levels)
g_arr = [dgl.from_scipy(csr) for csr in L]

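# Attach 2D coordinates to the nodes at every level and convert each edge
# vector to polar pseudo-coordinates (stored in g.edata["u"] by z2polar),
# which MoNet's GMMConv consumes.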
coordinate_arr = get_coordinates(g_arr, grid_side, coarsening_levels, perm)
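# Match the coordinate dtype to that of the scipy adjacency matrix.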
str_to_torch_dtype = {
    "float16": torch.half,
    "float32": torch.float32,
    "float64": torch.float64,
}
coordinate_arr = [
    coord.to(dtype=str_to_torch_dtype[str(A.dtype)]) for coord in coordinate_arr
]
for g, coord in zip(g_arr, coordinate_arr):
    g.ndata["xy"] = coord
    g.apply_edges(z2polar)

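# Collate function: pad each flattened image with zeros for the fake nodes
# introduced by coarsening, reorder the pixels by `perm`, and batch one copy
# of the graph hierarchy per sample.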
def batcher(batch):
    g_batch = [[] for _ in range(coarsening_levels + 1)]
    x_batch = []
    y_batch = []
    for x, y in batch:
        x = torch.cat([x.view(-1), x.new_zeros(len(perm) - grid_side**2)], 0)
        x = x[perm]
        x_batch.append(x)
        y_batch.append(y)
        for i in range(coarsening_levels + 1):
            g_batch[i].append(g_arr[i])

    x_batch = torch.cat(x_batch).unsqueeze(-1)
    y_batch = torch.LongTensor(y_batch)
    g_batch = [dgl.batch(g) for g in g_batch]
    return g_batch, x_batch, y_batch


trainset = datasets.MNIST(
    root=".", train=True, download=True, transform=transforms.ToTensor()
)
testset = datasets.MNIST(
    root=".", train=False, download=True, transform=transforms.ToTensor()
)

train_loader = DataLoader(
    trainset,
    batch_size=args.batch_size,
    shuffle=True,
    collate_fn=batcher,
    num_workers=6,
)
test_loader = DataLoader(
    testset,
    batch_size=args.batch_size,
    shuffle=False,
    collate_fn=batcher,
    num_workers=6,
)


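# MoNet (Monti et al., CVPR 2017): mixture-model convolutions parameterized
# by the polar edge pseudo-coordinates, with 1D max pooling between
# coarsening levels and a global max readout.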
class MoNet(nn.Module):
    def __init__(self, n_kernels, in_feats, hiddens, out_feats):
        super(MoNet, self).__init__()
        self.pool = nn.MaxPool1d(2)
        self.layers = nn.ModuleList()
        self.readout = MaxPooling()

        # Input layer
        self.layers.append(GMMConv(in_feats, hiddens[0], 2, n_kernels))

        # Hidden layers
        for i in range(1, len(hiddens)):
            self.layers.append(
                GMMConv(hiddens[i - 1], hiddens[i], 2, n_kernels)
            )

        self.cls = nn.Sequential(
            nn.Linear(hiddens[-1], out_feats), nn.LogSoftmax(dim=-1)
        )

    def forward(self, g_arr, feat):
        for g, layer in zip(g_arr, self.layers):
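            # Convolve on the current graph, then halve the number of nodes
            # with 1D max pooling to move the signal onto the next coarser
            # graph in the hierarchy.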
            u = g.edata["u"]
            feat = (
                self.pool(layer(g, feat, u).transpose(-1, -2).unsqueeze(0))
                .squeeze(0)
                .transpose(-1, -2)
            )
        return self.cls(self.readout(g_arr[-1], feat))


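# ChebNet (Defferrard et al., NeurIPS 2016): Chebyshev spectral graph
# convolutions of order k, with the same pooling and readout scheme as MoNet.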
class ChebNet(nn.Module):
    def __init__(self, k, in_feats, hiddens, out_feats):
        super(ChebNet, self).__init__()
        self.pool = nn.MaxPool1d(2)
        self.layers = nn.ModuleList()
        self.readout = MaxPooling()

        # Input layer
        self.layers.append(ChebConv(in_feats, hiddens[0], k))

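        # Hidden layers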
        for i in range(1, len(hiddens)):
            self.layers.append(ChebConv(hiddens[i - 1], hiddens[i], k))

        self.cls = nn.Sequential(
            nn.Linear(hiddens[-1], out_feats), nn.LogSoftmax(dim=-1)
        )

    def forward(self, g_arr, feat):
        for g, layer in zip(g_arr, self.layers):
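            # lambda_max is fixed to 2 for every graph in the batch (the
            # upper bound for a normalized Laplacian), avoiding an eigenvalue
            # computation per batch.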
            feat = (
                self.pool(
                    layer(g, feat, [2] * g.batch_size)
                    .transpose(-1, -2)
                    .unsqueeze(0)
                )
                .squeeze(0)
                .transpose(-1, -2)
            )
        return self.cls(self.readout(g_arr[-1], feat))


if args.gpu == -1:
    device = torch.device("cpu")
else:
    device = torch.device(args.gpu)

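# ChebNet with k=2 Chebyshev filters or MoNet with 10 Gaussian kernels; both
# use four conv layers (32-64-128-256 channels) and 10 output classes.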
if args.model == "chebnet":
    model = ChebNet(2, 1, [32, 64, 128, 256], 10)
else:
    model = MoNet(10, 1, [32, 64, 128, 256], 10)

model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
log_interval = 50

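# Train for 10 epochs with NLL loss on the log-softmax outputs, reporting
# running loss and accuracy every `log_interval` batches, then evaluate on
# the test set after each epoch.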
for epoch in range(10):
    print("epoch {} starts".format(epoch))
    model.train()
    hit, tot = 0, 0
    loss_accum = 0
    for i, (g, x, y) in enumerate(train_loader):
        x = x.to(device)
        y = y.to(device)
        g = [g_i.to(device) for g_i in g]
        out = model(g, x)
        hit += (out.max(-1)[1] == y).sum().item()
        tot += len(y)
        loss = F.nll_loss(out, y)
        loss_accum += loss.item()

        if (i + 1) % log_interval == 0:
            print(
                "loss: {}, acc: {}".format(loss_accum / log_interval, hit / tot)
            )
            hit, tot = 0, 0
            loss_accum = 0

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    model.eval()
    hit, tot = 0, 0
    with torch.no_grad():
        for g, x, y in test_loader:
            x = x.to(device)
            y = y.to(device)
            g = [g_i.to(device) for g_i in g]
            out = model(g, x)
            hit += (out.max(-1)[1] == y).sum().item()
            tot += len(y)

    print("test acc: ", hit / tot)