"""
Modeling Relational Data with Graph Convolutional Networks
Paper: https://arxiv.org/abs/1703.06103
Code: https://github.com/tkipf/relational-gcn

Differences compared to tkipf/relational-gcn
* l2norm applied to all weights
* remove nodes that won't be touched
"""

import argparse
import time
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
13
14
15
from functools import partial

import dgl
16
17
import mxnet as mx
import mxnet.ndarray as F
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
18
19
import numpy as np
from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset
Minjie Wang's avatar
Minjie Wang committed
20
from dgl.nn.mxnet import RelGraphConv
21
22

from model import BaseRGCN
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
23
24
from mxnet import gluon

25
26
27

class EntityClassify(BaseRGCN):
    """RGCN entity classifier: a stack of basis-decomposition RelGraphConv layers.

    The input layer embeds raw node ids (num_nodes -> h_dim), hidden layers
    map h_dim -> h_dim, and the output layer produces class logits
    (h_dim -> out_dim) without an activation.
    """

    def _basis_conv(self, in_dim, out_dim, activation, dropout):
        """Build one RelGraphConv layer with the model-wide basis settings."""
        return RelGraphConv(
            in_dim,
            out_dim,
            self.num_rels,
            "basis",
            self.num_bases,
            activation=activation,
            self_loop=self.use_self_loop,
            dropout=dropout,
        )

    def build_input_layer(self):
        """Input layer: node-id embedding lookup followed by ReLU."""
        return self._basis_conv(self.num_nodes, self.h_dim, F.relu, self.dropout)

    def build_hidden_layer(self, idx):
        """Hidden layer `idx`: h_dim -> h_dim with ReLU."""
        return self._basis_conv(self.h_dim, self.h_dim, F.relu, self.dropout)

    def build_output_layer(self):
        """Output layer: h_dim -> out_dim logits, no activation, no dropout.

        dropout=0.0 matches RelGraphConv's default, which the original
        relied on by omitting the keyword.
        """
        return self._basis_conv(self.h_dim, self.out_dim, None, 0.0)

62
63
64

def main(args):
    """Train and evaluate an RGCN entity classifier on an RDF dataset.

    Expected fields on ``args``: dataset, validation, gpu, n_hidden,
    n_bases, n_layers, dropout, use_self_loop, lr, l2norm, n_epochs.

    Pipeline: load the heterogeneous graph, compute per-edge-type
    normalization, flatten to a homogeneous graph, then train with
    full-graph forward passes and report train/val/test accuracy.
    """
    # Load graph data; fail loudly on an unrecognized dataset name.
    if args.dataset == "aifb":
        dataset = AIFBDataset()
    elif args.dataset == "mutag":
        dataset = MUTAGDataset()
    elif args.dataset == "bgs":
        dataset = BGSDataset()
    elif args.dataset == "am":
        dataset = AMDataset()
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))

    # Load from hetero-graph.
    hg = dataset[0]

    num_rels = len(hg.canonical_etypes)
    category = dataset.predict_category
    num_classes = dataset.num_classes
    # pop() removes the masks/labels from the graph so they are not carried
    # into the homogeneous conversion below.
    train_mask = hg.nodes[category].data.pop("train_mask")
    test_mask = hg.nodes[category].data.pop("test_mask")
    train_idx = mx.nd.array(np.nonzero(train_mask.asnumpy())[0], dtype="int64")
    test_idx = mx.nd.array(np.nonzero(test_mask.asnumpy())[0], dtype="int64")
    labels = mx.nd.array(hg.nodes[category].data.pop("labels"), dtype="int64")

    # Split training set: first 20% of the train indices become validation.
    if args.validation:
        val_idx = train_idx[: len(train_idx) // 5]
        train_idx = train_idx[len(train_idx) // 5 :]
    else:
        val_idx = train_idx

    # Calculate norm for each edge type and store it as edge data:
    # norm(e) = 1 / in-degree of e's destination node within this edge type.
    for canonical_etype in hg.canonical_etypes:
        _, v, eid = hg.all_edges(form="all", etype=canonical_etype)
        v = v.asnumpy()
        _, inverse_index, count = np.unique(
            v, return_inverse=True, return_counts=True
        )
        degrees = count[inverse_index]
        norm = np.ones(eid.shape[0]) / degrees
        hg.edges[canonical_etype].data["norm"] = mx.nd.expand_dims(
            mx.nd.array(norm), axis=1
        )

    # Get the integer id of the target node type (ntypes are unique, so the
    # first match is the only one).
    category_id = len(hg.ntypes)
    for i, ntype in enumerate(hg.ntypes):
        if ntype == category:
            category_id = i
            break

    # Flatten to a homogeneous graph, carrying the per-edge "norm" feature.
    g = dgl.to_homogeneous(hg, edata=["norm"])
    num_nodes = g.number_of_nodes()
    node_ids = mx.nd.arange(num_nodes)
    edge_norm = g.edata["norm"]
    edge_type = g.edata[dgl.ETYPE]

    # Find the ids (in g) of nodes belonging to the target category.
    node_tids = g.ndata[dgl.NTYPE]
    loc = node_tids == category_id
    loc = mx.nd.array(np.nonzero(loc.asnumpy())[0], dtype="int64")
    target_idx = node_ids[loc]

    # Since the nodes are featureless, the input feature is the node id,
    # which the input layer embeds.
    feats = mx.nd.arange(num_nodes, dtype="int32")

    # Check cuda and move training tensors to the chosen context.
    # NOTE(review): val_idx/test_idx/target_idx stay on CPU here, mirroring
    # the upstream example — confirm this is intended for GPU runs.
    use_cuda = args.gpu >= 0
    if use_cuda:
        ctx = mx.gpu(args.gpu)
        feats = feats.as_in_context(ctx)
        edge_type = edge_type.as_in_context(ctx)
        edge_norm = edge_norm.as_in_context(ctx)
        labels = labels.as_in_context(ctx)
        train_idx = train_idx.as_in_context(ctx)
        g = g.to(ctx)
    else:
        ctx = mx.cpu(0)

    # Create the model; n_layers counts input + output, hence "- 2" hidden.
    model = EntityClassify(
        num_nodes,
        args.n_hidden,
        num_classes,
        num_rels,
        num_bases=args.n_bases,
        num_hidden_layers=args.n_layers - 2,
        dropout=args.dropout,
        use_self_loop=args.use_self_loop,
        gpu_id=args.gpu,
    )
    model.initialize(ctx=ctx)

    # Optimizer: Adam with weight decay acting as the l2 penalty.
    trainer = gluon.Trainer(
        model.collect_params(),
        "adam",
        {"learning_rate": args.lr, "wd": args.l2norm},
    )
    loss_fcn = gluon.loss.SoftmaxCELoss(from_logits=False)

    # Training loop: full-graph forward, loss only on the train indices.
    print("start training...")
    forward_time = []
    backward_time = []
    for epoch in range(args.n_epochs):
        t0 = time.time()
        with mx.autograd.record():
            pred = model(g, feats, edge_type, edge_norm)
            pred = pred[target_idx]
            loss = loss_fcn(pred[train_idx], labels[train_idx])
        t1 = time.time()
        loss.backward()
        trainer.step(len(train_idx))
        t2 = time.time()

        forward_time.append(t1 - t0)
        backward_time.append(t2 - t1)
        print(
            "Epoch {:05d} | Train Forward Time(s) {:.4f} | Backward Time(s) {:.4f}".format(
                epoch, forward_time[-1], backward_time[-1]
            )
        )
        train_acc = (
            F.sum(
                mx.nd.cast(pred[train_idx].argmax(axis=1), "int64")
                == labels[train_idx]
            ).asscalar()
            / train_idx.shape[0]
        )
        val_acc = F.sum(
            mx.nd.cast(pred[val_idx].argmax(axis=1), "int64") == labels[val_idx]
        ).asscalar() / len(val_idx)
        print(
            "Train Accuracy: {:.4f} | Validation Accuracy: {:.4f}".format(
                train_acc, val_acc
            )
        )
    print()

    # Final evaluation on the test split. Use model(...) — not
    # model.forward(...) — for consistency with the training loop and so
    # gluon's __call__ hooks run.
    logits = model(g, feats, edge_type, edge_norm)
    logits = logits[target_idx]
    test_acc = F.sum(
        mx.nd.cast(logits[test_idx].argmax(axis=1), "int64") == labels[test_idx]
    ).asscalar() / len(test_idx)
    print("Test Accuracy: {:.4f}".format(test_acc))
    print()

    # Report mean timings, discarding the first quarter as warm-up.
    print(
        "Mean forward time: {:4f}".format(
            np.mean(forward_time[len(forward_time) // 4 :])
        )
    )
    print(
        "Mean backward time: {:4f}".format(
            np.mean(backward_time[len(backward_time) // 4 :])
        )
    )


if __name__ == "__main__":
    # Command-line interface for the RGCN entity-classification script.
    parser = argparse.ArgumentParser(description="RGCN")
    add = parser.add_argument
    add("--dropout", type=float, default=0, help="dropout probability")
    add("--n-hidden", type=int, default=16, help="number of hidden units")
    add("--gpu", type=int, default=-1, help="gpu")
    add("--lr", type=float, default=1e-2, help="learning rate")
    add(
        "--n-bases",
        type=int,
        default=-1,
        help="number of filter weight matrices, default: -1 [use all]",
    )
    add("--n-layers", type=int, default=2, help="number of propagation rounds")
    add(
        "-e",
        "--n-epochs",
        type=int,
        default=50,
        help="number of training epochs",
    )
    add("-d", "--dataset", type=str, required=True, help="dataset to use")
    add("--l2norm", type=float, default=0, help="l2 norm coef")
    add(
        "--use-self-loop",
        default=False,
        action="store_true",
        help="include self feature as a special relation",
    )

    # --validation and --testing toggle the same flag in opposite
    # directions; validation mode is the default.
    fp = parser.add_mutually_exclusive_group(required=False)
    fp.add_argument("--validation", dest="validation", action="store_true")
    fp.add_argument("--testing", dest="validation", action="store_false")
    parser.set_defaults(validation=True)

    args = parser.parse_args()
    print(args)
    args.bfs_level = args.n_layers + 1  # pruning used nodes for memory
    main(args)