"git@developer.sourcefind.cn:change/sglang.git" did not exist on "0edda32001938b578976409216bc6f9f36f719df"
entity_classify.py 9.04 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
"""
Modeling Relational Data with Graph Convolutional Networks
Paper: https://arxiv.org/abs/1703.06103
Code: https://github.com/tkipf/relational-gcn

Differences compared to tkipf/relational-gcn:
* L2 norm applied to all weights
* removes nodes that won't be touched
"""

import argparse
import time
from functools import partial

import numpy as np
import tensorflow as tf
from model import BaseRGCN
from tensorflow.keras import layers

import dgl
from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset
from dgl.nn.tensorflow import RelGraphConv


class EntityClassify(BaseRGCN):
    def create_features(self):
        features = tf.range(self.num_nodes)
        return features

    def build_input_layer(self):
        return RelGraphConv(
            self.num_nodes,
            self.h_dim,
            self.num_rels,
            "basis",
            self.num_bases,
            activation=tf.nn.relu,
            self_loop=self.use_self_loop,
            dropout=self.dropout,
        )

    def build_hidden_layer(self, idx):
        return RelGraphConv(
            self.h_dim,
            self.h_dim,
            self.num_rels,
            "basis",
            self.num_bases,
            activation=tf.nn.relu,
            self_loop=self.use_self_loop,
            dropout=self.dropout,
        )

    def build_output_layer(self):
        return RelGraphConv(
            self.h_dim,
            self.out_dim,
            self.num_rels,
            "basis",
            self.num_bases,
            activation=partial(tf.nn.softmax, axis=1),
            self_loop=self.use_self_loop,
        )


def acc(logits, labels, mask):
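    """Accuracy of argmax predictions against labels on the given index set."""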
    logits = tf.gather(logits, mask)
    labels = tf.gather(labels, mask)
    indices = tf.math.argmax(logits, axis=1)
    acc = tf.reduce_mean(tf.cast(indices == labels, dtype=tf.float32))
    return acc


def main(args):
    # load graph data
    if args.dataset == "aifb":
        dataset = AIFBDataset()
    elif args.dataset == "mutag":
        dataset = MUTAGDataset()
    elif args.dataset == "bgs":
        dataset = BGSDataset()
    elif args.dataset == "am":
        dataset = AMDataset()
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))

    # preprocessing on CPU
    with tf.device("/cpu:0"):
        # Load from hetero-graph
        hg = dataset[0]

        num_rels = len(hg.canonical_etypes)
        category = dataset.predict_category
        num_classes = dataset.num_classes
        train_mask = hg.nodes[category].data.pop("train_mask")
        test_mask = hg.nodes[category].data.pop("test_mask")
        train_idx = tf.squeeze(tf.where(train_mask))
        test_idx = tf.squeeze(tf.where(test_mask))
        labels = hg.nodes[category].data.pop("labels")

        # split dataset into train, validate, test
        if args.validation:
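            # hold out the first fifth of the training indices for validation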
            val_idx = train_idx[: len(train_idx) // 5]
            train_idx = train_idx[len(train_idx) // 5 :]
        else:
            val_idx = train_idx

        # calculate norm for each edge type and store in edge
        for canonical_etype in hg.canonical_etypes:
            u, v, eid = hg.all_edges(form="all", etype=canonical_etype)
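            # norm = 1 / in-degree of the destination node, one value per edge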
            _, inverse_index, count = tf.unique_with_counts(v)
            degrees = tf.gather(count, inverse_index)
            norm = tf.ones(eid.shape[0]) / tf.cast(degrees, tf.float32)
            norm = tf.expand_dims(norm, 1)
            hg.edges[canonical_etype].data["norm"] = norm

        # get target category id
        category_id = len(hg.ntypes)
        for i, ntype in enumerate(hg.ntypes):
            if ntype == category:
                category_id = i

        # edge type and normalization factor
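        # convert to a homogeneous graph; node/edge type ids are stored in
        # g.ndata[dgl.NTYPE] and g.edata[dgl.ETYPE] and used below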
        g = dgl.to_homogeneous(hg, edata=["norm"])

    # check cuda
    if args.gpu < 0:
        device = "/cpu:0"
        use_cuda = False
    else:
        device = "/gpu:{}".format(args.gpu)
        g = g.to(device)
        use_cuda = True
    num_nodes = g.number_of_nodes()
    node_ids = tf.range(num_nodes, dtype=tf.int64)
    edge_norm = g.edata["norm"]
    edge_type = tf.cast(g.edata[dgl.ETYPE], tf.int64)

    # find out the target node ids in g
    node_tids = g.ndata[dgl.NTYPE]
    loc = node_tids == category_id
    target_idx = tf.squeeze(tf.where(loc))

    # since the nodes are featureless, the node id is used as the input feature
    feats = tf.range(num_nodes, dtype=tf.int64)

    with tf.device(device):
        # create model
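        # num_hidden_layers excludes the input and output RelGraphConv layers,
        # hence args.n_layers - 2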
        model = EntityClassify(
            num_nodes,
            args.n_hidden,
            num_classes,
            num_rels,
            num_bases=args.n_bases,
            num_hidden_layers=args.n_layers - 2,
            dropout=args.dropout,
            use_self_loop=args.use_self_loop,
            use_cuda=use_cuda,
        )

        # optimizer
        optimizer = tf.keras.optimizers.Adam(learning_rate=args.lr)
        # training loop
        print("start training...")
        forward_time = []
        backward_time = []
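        # the output layer already applies softmax, so the loss expects
        # probabilities (from_logits=False)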
        loss_fcn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=False
        )
        for epoch in range(args.n_epochs):
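            # t0..t1 times the forward pass and loss; t1..t2 times the gradient
            # computation and the optimizer update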
            t0 = time.time()
            with tf.GradientTape() as tape:
                logits = model(g, feats, edge_type, edge_norm)
                logits = tf.gather(logits, target_idx)
                loss = loss_fcn(
                    tf.gather(labels, train_idx), tf.gather(logits, train_idx)
                )
                # Manual weight decay.
                # We found that TensorFlow's Adam(W) optimizer implements weight decay
                # differently from PyTorch, which leads to worse results. Manually adding
                # the L2 penalty of the weights to the loss solves this problem.
                for weight in model.trainable_weights:
                    loss = loss + args.l2norm * tf.nn.l2_loss(weight)
                t1 = time.time()
                grads = tape.gradient(loss, model.trainable_weights)
                optimizer.apply_gradients(zip(grads, model.trainable_weights))
                t2 = time.time()

            forward_time.append(t1 - t0)
            backward_time.append(t2 - t1)
            print(
                "Epoch {:05d} | Train Forward Time(s) {:.4f} | Backward Time(s) {:.4f}".format(
                    epoch, forward_time[-1], backward_time[-1]
                )
            )
            train_acc = acc(logits, labels, train_idx)
            val_loss = loss_fcn(
                tf.gather(labels, val_idx), tf.gather(logits, val_idx)
            )
            val_acc = acc(logits, labels, val_idx)
            print(
                "Train Accuracy: {:.4f} | Train Loss: {:.4f} | Validation Accuracy: {:.4f} | Validation loss: {:.4f}".format(
                    train_acc,
                    loss.numpy().item(),
                    val_acc,
                    val_loss.numpy().item(),
                )
            )
        print()

        logits = model(g, feats, edge_type, edge_norm)
        logits = tf.gather(logits, target_idx)
        test_loss = loss_fcn(
            tf.gather(labels, test_idx), tf.gather(logits, test_idx)
        )
        test_acc = acc(logits, labels, test_idx)
        print(
            "Test Accuracy: {:.4f} | Test loss: {:.4f}".format(
                test_acc, test_loss.numpy().item()
            )
        )
        print()

        print(
            "Mean forward time: {:.4f}".format(
                np.mean(forward_time[len(forward_time) // 4 :])
            )
        )
        print(
            "Mean backward time: {:.4f}".format(
                np.mean(backward_time[len(backward_time) // 4 :])
            )
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="RGCN")
    parser.add_argument(
        "--dropout", type=float, default=0, help="dropout probability"
    )
    parser.add_argument(
        "--n-hidden", type=int, default=16, help="number of hidden units"
    )
    parser.add_argument("--gpu", type=int, default=-1, help="gpu")
    parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
    parser.add_argument(
        "--n-bases",
        type=int,
        default=-1,
        help="number of filter weight matrices, default: -1 [use all]",
    )
    parser.add_argument(
        "--n-layers", type=int, default=2, help="number of propagation rounds"
    )
    parser.add_argument(
        "-e",
        "--n-epochs",
        type=int,
        default=50,
        help="number of training epochs",
    )
    parser.add_argument(
        "-d", "--dataset", type=str, required=True, help="dataset to use"
    )
    parser.add_argument("--l2norm", type=float, default=0, help="l2 norm coef")
    parser.add_argument(
        "--use-self-loop",
        default=False,
        action="store_true",
        help="include self feature as a special relation",
    )
    fp = parser.add_mutually_exclusive_group(required=False)
    fp.add_argument("--validation", dest="validation", action="store_true")
    fp.add_argument("--testing", dest="validation", action="store_false")
    parser.set_defaults(validation=True)

    args = parser.parse_args()
    print(args)
    args.bfs_level = args.n_layers + 1  # pruning used nodes for memory
    main(args)