main.py 11 KB
Newer Older
1
2
3
""" The main file to train a MixHop model using a full graph """

import argparse
xnouhz's avatar
xnouhz committed
4
import copy
5
6
import random

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
7
8
9
import dgl
import dgl.function as fn

10
import numpy as np
11
12
import torch
import torch.nn as nn
13
14
import torch.optim as optim
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
15
from tqdm import trange
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48


class MixHopConv(nn.Module):
    r"""

    Description
    -----------
    MixHop Graph Convolutional layer from paper `MixHop: Higher-Order Graph Convolutional Architectures via Sparsified Neighborhood Mixing
     <https://arxiv.org/pdf/1905.00067.pdf>`__.

    .. math::
        H^{(i+1)} =\underset{j \in P}{\Bigg\Vert} \sigma\left(\widehat{A}^j H^{(i)} W_j^{(i)}\right),

    where :math:`\widehat{A} = D^{-1/2} A D^{-1/2}` is the symmetrically normalized adjacency
    matrix of ``graph``, :math:`D_{ii} = \sum_{j} A_{ij}` its diagonal degree matrix (degrees are
    clamped to at least 1), and :math:`W_j^{(i)}` the trainable weight matrix for power :math:`j`.
    This layer does not add self-loops; add them to ``graph`` beforehand if the paper's
    self-connections are wanted.

    The per-power outputs are concatenated in ascending order of power. Batch normalization
    (optional), the activation (optional) and dropout are then applied to the concatenation.

    Parameters
    ----------
    in_dim : int
        Input feature size. i.e, the number of dimensions of :math:`H^{(i)}`.
    out_dim : int
        Output feature size for each power. The layer emits ``out_dim * len(p)`` features.
    p: list, optional
        Distinct, non-negative powers of the adjacency matrix. Defaults: ``[0, 1, 2]``.
    dropout: float, optional
        Dropout rate on the output node features. Defaults: ``0``.
    activation: callable activation function/layer or None, optional
        If not None, applies an activation function to the updated node features.
        Default: ``None``.
    batchnorm: bool, optional
        If True, use batch normalization. Defaults: ``False``.

    Raises
    ------
    ValueError
        If ``p`` is empty, contains a negative power, or contains duplicates.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        p=None,
        dropout=0,
        activation=None,
        batchnorm=False,
    ):
        super(MixHopConv, self).__init__()
        if p is None:
            p = [0, 1, 2]
        # Validate early. Otherwise these cases only fail later with an opaque shape
        # mismatch (e.g. BatchNorm sized for len(p) blocks while forward emits fewer).
        if len(p) == 0:
            raise ValueError("p must contain at least one power")
        if any(j < 0 for j in p):
            raise ValueError("powers in p must be non-negative, got {}".format(p))
        if len(set(p)) != len(p):
            raise ValueError("powers in p must be distinct, got {}".format(p))

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.p = p
        self.activation = activation
        self.batchnorm = batchnorm

        # define dropout layer
        self.dropout = nn.Dropout(dropout)

        # define batch norm layer over the concatenated per-power outputs
        if self.batchnorm:
            self.bn = nn.BatchNorm1d(out_dim * len(p))

        # one bias-free linear map per power j (ModuleDict keys must be strings)
        self.weights = nn.ModuleDict(
            {str(j): nn.Linear(in_dim, out_dim, bias=False) for j in p}
        )

    def forward(self, graph, feats):
        """Apply the MixHop layer.

        Parameters
        ----------
        graph : DGLGraph
            Graph to propagate over. In-degrees are used on both sides of the
            symmetric normalization, so the graph is assumed to be undirected.
        feats : torch.Tensor
            Node features, one row per node. The last dimension is assumed to
            equal ``in_dim`` (TODO confirm at call sites).

        Returns
        -------
        torch.Tensor
            Per-power outputs concatenated along dim 1 (``out_dim * len(p)``
            columns), after optional batch norm and activation, then dropout.
        """
        with graph.local_scope():
            # assume that the graphs are undirected and graph.in_degrees() is the same as graph.out_degrees()
            degs = graph.in_degrees().float().clamp(min=1)
            norm = torch.pow(degs, -0.5).to(feats.device).unsqueeze(1)
            max_j = max(self.p)
            outputs = []
            for j in range(max_j + 1):
                # feats currently holds A_hat^j @ H
                if j in self.p:
                    outputs.append(self.weights[str(j)](feats))

                # The highest requested power has been consumed; another hop would be wasted work.
                if j == max_j:
                    break

                # One hop of D^{-1/2} A D^{-1/2}: scale by source norm, sum over in-edges, scale by dest norm.
                feats = feats * norm
                graph.ndata["h"] = feats
                graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
                feats = graph.ndata.pop("h") * norm

            final = torch.cat(outputs, dim=1)

            if self.batchnorm:
                final = self.bn(final)

            if self.activation is not None:
                final = self.activation(final)

            final = self.dropout(final)

            return final
class MixHop(nn.Module):
    r"""MixHop network: a stack of ``MixHopConv`` layers followed by a linear output layer.

    Parameters
    ----------
    in_dim : int
        Input node feature size.
    hid_dim : int
        Per-power output size of every MixHopConv layer. Each layer therefore
        emits ``hid_dim * len(p)`` features.
    out_dim : int
        Output size of the final bias-free linear layer (e.g. number of classes).
    num_layers : int, optional
        Layer count including the final linear layer: one input MixHopConv,
        ``num_layers - 2`` hidden MixHopConv layers, then the linear layer.
        Values below 2 still build the input MixHopConv. Defaults: ``2``.
    p : list, optional
        Powers of the adjacency matrix used by every MixHopConv layer.
        Defaults: ``[0, 1, 2]``.
    input_dropout : float, optional
        Dropout rate applied to the raw input features, and also used as the
        output dropout of the input MixHopConv layer. Defaults: ``0.0``.
    layer_dropout : float, optional
        Output dropout rate of each hidden MixHopConv layer. Defaults: ``0.0``.
    activation : callable or None, optional
        Activation passed to every MixHopConv layer. Defaults: ``None``.
    batchnorm : bool, optional
        Whether every MixHopConv layer uses batch normalization.
        Defaults: ``False``.
    """

    def __init__(
        self,
        in_dim,
        hid_dim,
        out_dim,
        num_layers=2,
        p=None,
        input_dropout=0.0,
        layer_dropout=0.0,
        activation=None,
        batchnorm=False,
    ):
        super(MixHop, self).__init__()
        if p is None:
            p = [0, 1, 2]
        self.in_dim = in_dim
        self.hid_dim = hid_dim
        self.out_dim = out_dim
        self.num_layers = num_layers
        self.p = p
        self.input_dropout = input_dropout
        self.layer_dropout = layer_dropout
        self.activation = activation
        self.batchnorm = batchnorm

        self.layers = nn.ModuleList()
        self.dropout = nn.Dropout(self.input_dropout)

        # Each MixHopConv concatenates one hid_dim block per power, so this is
        # the input width of every subsequent layer. Use self.p, not a global.
        concat_dim = self.hid_dim * len(self.p)

        # Input layer
        self.layers.append(
            MixHopConv(
                self.in_dim,
                self.hid_dim,
                p=self.p,
                dropout=self.input_dropout,
                activation=self.activation,
                batchnorm=self.batchnorm,
            )
        )

        # num_layers - 2 hidden MixHopConv layers (num_layers - 1 MixHopConv in total)
        for _ in range(self.num_layers - 2):
            self.layers.append(
                MixHopConv(
                    concat_dim,
                    self.hid_dim,
                    p=self.p,
                    dropout=self.layer_dropout,
                    activation=self.activation,
                    batchnorm=self.batchnorm,
                )
            )

        # Output layer
        self.fc_layers = nn.Linear(concat_dim, self.out_dim, bias=False)

    def forward(self, graph, feats):
        """Return per-node outputs of shape ``(num_nodes, out_dim)``.

        Input dropout is applied to ``feats``. The features then pass through
        every MixHopConv layer on ``graph`` and finally the linear layer.
        """
        feats = self.dropout(feats)
        for layer in self.layers:
            feats = layer(graph, feats)

        feats = self.fc_layers(feats)

        return feats
def _accuracy(logits, labels, idx):
    """Return the fraction of nodes in ``idx`` whose argmax over dim 1 of ``logits`` equals ``labels``."""
    correct = torch.sum(logits[idx].argmax(dim=1) == labels[idx]).item()
    return correct / len(idx)


def main(args):
    """Train a MixHop model on one citation dataset and return its test accuracy.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line options (dataset, gpu, optimisation and model
        hyperparameters).

    Returns
    -------
    float
        Test accuracy of the model snapshot with the best validation accuracy.

    Raises
    ------
    ValueError
        If ``args.dataset`` is not one of Cora, Citeseer or Pubmed.
    """
    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
    # Load from DGL dataset
    dataset_classes = {
        "Cora": CoraGraphDataset,
        "Citeseer": CiteseerGraphDataset,
        "Pubmed": PubmedGraphDataset,
    }
    if args.dataset not in dataset_classes:
        raise ValueError("Dataset {} is invalid.".format(args.dataset))
    dataset = dataset_classes[args.dataset]()

    graph = dataset[0]
    # MixHopConv normalizes by degree but does not add self-loops itself.
    graph = dgl.add_self_loop(graph)

    # check cuda
    if args.gpu >= 0 and torch.cuda.is_available():
        device = "cuda:{}".format(args.gpu)
    else:
        device = "cpu"

    # retrieve the number of classes
    n_classes = dataset.num_classes

    # retrieve labels of ground truth
    labels = graph.ndata.pop("label").to(device).long()

    # Extract node features
    feats = graph.ndata.pop("feat").to(device)
    n_features = feats.shape[-1]

    # retrieve masks for train/validation/test
    train_mask = graph.ndata.pop("train_mask")
    val_mask = graph.ndata.pop("val_mask")
    test_mask = graph.ndata.pop("test_mask")

    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze().to(device)
    val_idx = torch.nonzero(val_mask, as_tuple=False).squeeze().to(device)
    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze().to(device)

    graph = graph.to(device)

    # Step 2: Create model =================================================================== #
    model = MixHop(
        in_dim=n_features,
        hid_dim=args.hid_dim,
        out_dim=n_classes,
        num_layers=args.num_layers,
        p=args.p,
        input_dropout=args.input_dropout,
        layer_dropout=args.layer_dropout,
        activation=torch.tanh,
        batchnorm=True,
    )

    model = model.to(device)
    best_model = copy.deepcopy(model)

    # Step 3: Create training components ===================================================== #
    loss_fn = nn.CrossEntropyLoss()
    opt = optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.lamb)
    scheduler = optim.lr_scheduler.StepLR(opt, args.step_size, gamma=args.gamma)

    # Step 4: training epoches =============================================================== #
    acc = 0
    no_improvement = 0
    epochs = trange(args.epochs, desc="Accuracy & Loss")

    for _ in epochs:
        # Training using a full graph
        model.train()

        logits = model(graph, feats)

        # compute loss
        train_loss = loss_fn(logits[train_idx], labels[train_idx])
        train_acc = _accuracy(logits, labels, train_idx)

        # backward
        opt.zero_grad()
        train_loss.backward()
        opt.step()

        # Validation using a full graph. Re-run the *updated* model in eval
        # mode (dropout off, batch-norm running stats). This matches the
        # weights copied into best_model below; the training logits above were
        # computed with dropout active and before the optimizer step.
        model.eval()

        with torch.no_grad():
            logits = model(graph, feats)
            valid_loss = loss_fn(logits[val_idx], labels[val_idx])
            valid_acc = _accuracy(logits, labels, val_idx)

        # Print out performance
        epochs.set_description(
            "Train Acc {:.4f} | Train Loss {:.4f} | Val Acc {:.4f} | Val loss {:.4f}".format(
                train_acc, train_loss.item(), valid_acc, valid_loss.item()
            )
        )

        if valid_acc < acc:
            no_improvement += 1
            # >= so a patience of 0 stops at the first non-improving epoch instead of never.
            if no_improvement >= args.early_stopping:
                print("Early stop.")
                break
        else:
            no_improvement = 0
            acc = valid_acc
            best_model = copy.deepcopy(model)

        scheduler.step()

    # Inference only: no autograd graph needed.
    best_model.eval()
    with torch.no_grad():
        logits = best_model(graph, feats)
    test_acc = _accuracy(logits, labels, test_idx)

    print("Test Acc {:.4f}".format(test_acc))
    return test_acc
if __name__ == "__main__":
    # MixHop Model Hyperparameters
    parser = argparse.ArgumentParser(description="MixHop GCN")

    # data source params
    parser.add_argument(
        "--dataset", type=str, default="Cora", help="Name of dataset."
    )
    # cuda params
    parser.add_argument(
        "--gpu", type=int, default=-1, help="GPU index. Default: -1, using CPU."
    )
    # training params
    parser.add_argument(
        "--epochs", type=int, default=2000, help="Training epochs."
    )
    parser.add_argument(
        "--early-stopping",
        type=int,
        default=200,
        help="Patient epochs to wait before early stopping.",
    )
    parser.add_argument("--lr", type=float, default=0.5, help="Learning rate.")
    parser.add_argument("--lamb", type=float, default=5e-4, help="L2 reg.")
    parser.add_argument(
        "--step-size",
        type=int,
        default=40,
        help="Period of learning rate decay.",
    )
    parser.add_argument(
        "--gamma",
        type=float,
        default=0.01,
        help="Multiplicative factor of learning rate decay.",
    )
    parser.add_argument(
        "--runs",
        type=int,
        default=100,
        help="Number of independent training runs; statistics use the best half.",
    )
    # model params
    parser.add_argument(
        "--hid-dim", type=int, default=60, help="Hidden layer dimensionalities."
    )
    parser.add_argument(
        "--num-layers", type=int, default=4, help="Number of GNN layers."
    )
    parser.add_argument(
        "--input-dropout",
        type=float,
        default=0.7,
        help="Dropout applied at input layer.",
    )
    parser.add_argument(
        "--layer-dropout",
        type=float,
        default=0.9,
        help="Dropout applied at hidden layers.",
    )
    parser.add_argument(
        "--p",
        nargs="+",
        type=int,
        default=[0, 1, 2],
        help="List of powers of adjacency matrix.",
    )

    args = parser.parse_args()
    if args.runs < 1:
        parser.error("--runs must be a positive integer")
    print(args)

    acc_lists = [main(args) for _ in range(args.runs)]

    # Report statistics over the best half of the runs (for the default 100 runs: top 50).
    acc_lists.sort()
    acc_lists_top = np.array(acc_lists[args.runs // 2 :])

    mean = np.around(np.mean(acc_lists_top, axis=0), decimals=3)
    std = np.around(np.std(acc_lists_top, axis=0), decimals=3)
    print("Total acc: ", acc_lists)
    print("Top {} acc:".format(len(acc_lists_top)), acc_lists_top)
    print("mean", mean)
    print("std", std)