# citation.py: MoNet (GMMConv) on citation-network datasets.
import argparse
import time
3

4
import networkx as nx
5
import numpy as np
6
7
8
import torch
import torch.nn as nn
import torch.nn.functional as F
9

10
from dgl import DGLGraph
11
from dgl.data import load_data, register_data_args
12
13
14
15
from dgl.nn.pytorch.conv import GMMConv


class MoNet(nn.Module):
    """MoNet: a stack of GMMConv layers applied to one fixed graph.

    Every GMMConv layer gets its own learned projection (Linear + Tanh) of the
    raw edge pseudo-coordinates into ``dim`` dimensions.

    Args:
        g: DGL graph the layers run on; stored and reused by every forward call.
        in_feats: size of the input node features.
        n_hidden: size of the hidden node representations.
        out_feats: size of the output (the number of classes in ``main``).
        n_layers: number of hidden layers; for ``n_layers >= 1`` the model has
            ``n_layers + 1`` GMMConv layers in total (input, hidden..., output).
        dim: pseudo-coordinate dimension fed to each GMMConv (its ``dim``).
        n_kernels: number of Gaussian kernels per GMMConv layer.
        dropout: dropout probability applied before every layer except the first.
        pseudo_in_dim: width of the raw ``pseudo`` tensor passed to ``forward``
            (defaults to 2, matching the two degree-based coordinates in ``main``).
    """

    def __init__(
        self,
        g,
        in_feats,
        n_hidden,
        out_feats,
        n_layers,
        dim,
        n_kernels,
        dropout,
        pseudo_in_dim=2,
    ):
        super(MoNet, self).__init__()
        self.g = g
        self.layers = nn.ModuleList()
        self.pseudo_proj = nn.ModuleList()

        # (in, out) feature sizes of every GMMConv layer: input, hidden..., output.
        # A negative repeat count yields an empty list, like range(n_layers - 1).
        sizes = (
            [(in_feats, n_hidden)]
            + [(n_hidden, n_hidden)] * (n_layers - 1)
            + [(n_hidden, out_feats)]
        )
        for layer_in, layer_out in sizes:
            # Build each conv before its projection so parameters are created
            # (and randomly initialised) in the same order as layer-by-layer code.
            self.layers.append(GMMConv(layer_in, layer_out, dim, n_kernels))
            self.pseudo_proj.append(self._pseudo_projection(pseudo_in_dim, dim))

        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def _pseudo_projection(in_dim, out_dim):
        """Return a per-layer map from raw pseudo-coordinates to GMMConv's ``dim``."""
        return nn.Sequential(nn.Linear(in_dim, out_dim), nn.Tanh())

    def forward(self, feat, pseudo):
        """Run all GMMConv layers over ``self.g``.

        Args:
            feat: node features for the first layer.
            pseudo: raw per-edge pseudo-coordinates, projected separately for
                each layer (assumed one row per edge of ``self.g`` — confirm).

        Returns:
            The output of the last GMMConv layer.
        """
        h = feat
        for i, (conv, proj) in enumerate(zip(self.layers, self.pseudo_proj)):
            if i > 0:
                # No dropout on the raw input features, only between layers.
                h = self.dropout(h)
            h = conv(self.g, h, proj(pseudo))
        return h

54

55
56
57
58
59
60
61
62
63
64
def evaluate(model, features, pseudo, labels, mask):
    """Compute classification accuracy of ``model`` on the nodes selected by ``mask``.

    The model is switched to eval mode and run without gradient tracking.
    The prediction for each node is the index of its largest logit.

    Args:
        model: module called as ``model(features, pseudo)`` and returning logits.
        features: node features passed to the model.
        pseudo: edge pseudo-coordinates passed to the model.
        labels: ground-truth class index per node.
        mask: index/boolean mask selecting the nodes to score.

    Returns:
        Fraction (float) of selected nodes predicted correctly. An empty
        selection raises ZeroDivisionError.
    """
    model.eval()
    with torch.no_grad():
        selected_logits = model(features, pseudo)[mask]
        gold = labels[mask]
        predicted = selected_logits.max(dim=1).indices
        hits = (predicted == gold).sum().item()
        return hits * 1.0 / len(gold)

65

66
67
68
def _edge_pseudo_coords(g):
    """Return per-edge pseudo-coordinates ``(1/sqrt(deg(u)), 1/sqrt(deg(v)))``.

    ``deg`` is the in-degree. ``u`` and ``v`` are the source and destination of
    each edge, taken in edge-ID order, so row ``i`` belongs to edge ``i``. The
    caller adds self-loops first, so every in-degree is at least 1.
    """
    us, vs = g.edges(order="eid")
    udeg = 1 / torch.sqrt(g.in_degrees(us).float())
    vdeg = 1 / torch.sqrt(g.in_degrees(vs).float())
    return torch.stack([udeg, vdeg], dim=1)


def _mean_or_nan(values):
    """Return ``np.mean(values)``, or NaN when ``values`` is empty (no numpy warning)."""
    return np.mean(values) if values else np.nan


def main(args):
    """Train MoNet on a citation dataset and report validation/test accuracy.

    Args:
        args: parsed command-line namespace. It must carry the dataset
            arguments used by ``load_data`` plus ``gpu``, ``n_hidden``,
            ``n_layers``, ``pseudo_dim``, ``n_kernels``, ``dropout``, ``lr``,
            ``weight_decay`` and ``n_epochs``.
    """
    # load and preprocess dataset
    data = load_data(args)
    g = data[0]
    cuda = args.gpu >= 0
    if cuda:
        g = g.to(args.gpu)

    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]
    in_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = g.number_of_edges()
    print(
        """----Data statistics------
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d"""
        % (
            n_edges,
            n_classes,
            train_mask.sum().item(),
            val_mask.sum().item(),
            test_mask.sum().item(),
        )
    )

    # Normalise self-loops to exactly one per node so every in-degree is >= 1
    # before taking 1/sqrt(deg) for the edge pseudo-coordinates.
    g = g.remove_self_loop().add_self_loop()
    n_edges = g.number_of_edges()
    pseudo = _edge_pseudo_coords(g)

    # create MoNet model
    model = MoNet(
        g,
        in_feats,
        args.n_hidden,
        n_classes,
        args.n_layers,
        args.pseudo_dim,
        args.n_kernels,
        args.dropout,
    )

    if cuda:
        # Put the model on the same GPU as the graph. A bare .cuda() would use
        # the current default CUDA device, which differs when --gpu != 0.
        model.cuda(args.gpu)
    loss_fcn = torch.nn.CrossEntropyLoss()

    # use optimizer
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    # training loop; the first 3 epochs are excluded from timing as warm-up
    dur = []
    for epoch in range(args.n_epochs):
        model.train()
        if epoch >= 3:
            t0 = time.perf_counter()
        # forward
        logits = model(features, pseudo)
        loss = loss_fcn(logits[train_mask], labels[train_mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch >= 3:
            dur.append(time.perf_counter() - t0)

        acc = evaluate(model, features, pseudo, labels, val_mask)
        # NaN during warm-up epochs, when no durations have been recorded yet.
        mean_dur = _mean_or_nan(dur)
        print(
            "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
            "ETputs(KTEPS) {:.2f}".format(
                epoch,
                mean_dur,
                loss.item(),
                acc,
                n_edges / mean_dur / 1000,
            )
        )

    print()
    acc = evaluate(model, features, pseudo, labels, test_mask)
    print("Test Accuracy {:.4f}".format(acc))


163
164
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MoNet on citation network")
    # Dataset-selection arguments; main() passes the parsed args to load_data().
    register_data_args(parser)
    parser.add_argument(
        "--dropout", type=float, default=0.5, help="dropout probability"
    )
    parser.add_argument(
        "--gpu",
        type=int,
        default=-1,
        help="GPU index to use; a negative value runs on CPU",
    )
    parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
    parser.add_argument(
        "--n-epochs", type=int, default=200, help="number of training epochs"
    )
    parser.add_argument(
        "--n-hidden", type=int, default=16, help="number of hidden MoNet units"
    )
    parser.add_argument(
        "--n-layers", type=int, default=1, help="number of hidden MoNet layers"
    )
    parser.add_argument(
        "--pseudo-dim",
        type=int,
        default=2,
        help="Pseudo coordinate dimensions in GMMConv, 2 for cora and 3 for pubmed",
    )
    parser.add_argument(
        "--n-kernels",
        type=int,
        default=3,
        help="Number of kernels in GMMConv layer",
    )
    parser.add_argument(
        "--weight-decay", type=float, default=5e-4, help="Weight for L2 loss"
    )
    args = parser.parse_args()
    print(args)

    main(args)