main.py 10.7 KB
Newer Older
KounianhuaDu's avatar
KounianhuaDu committed
1
import argparse
2
3
4
from time import time

import numpy as np
KounianhuaDu's avatar
KounianhuaDu committed
5
6
import torch as th
import torch.nn as nn
7
8
9
10
import torch.nn.functional as F
import torch.optim as optim
from data_loader import Data
from models import CompGCN_ConvE
KounianhuaDu's avatar
KounianhuaDu committed
11
12
from utils import in_out_norm

13
import dgl.function as fn
KounianhuaDu's avatar
KounianhuaDu committed
14
15


16
17
# predict the tail for (head, rel, -1) or head for (-1, rel, tail)
def predict(model, graph, device, data_iter, split="valid", mode="tail"):
    """Accumulate ranking statistics for one evaluation split and direction.

    Args:
        model: Callable as ``model(graph, sub, rel)`` returning a 2-D score
            tensor with one row per query and one column per candidate.
        graph: Graph object forwarded unchanged to ``model``.
        device: Device the batches are moved to before scoring.
        data_iter: Mapping of loader names to iterables; the loader used is
            ``data_iter["<split>_<mode>"]``. Each batch is indexed as
            ``batch[0]`` (triples with columns sub, rel, obj) and
            ``batch[1]`` (label tensor, same shape as the model output).
        split: Split name used to build the loader key.
        mode: Direction name (``"tail"`` or ``"head"``) used to build the
            loader key.

    Returns:
        Dict of *sums* (not averages) over all scored queries with keys
        ``"count"``, ``"mr"``, ``"mrr"``, ``"hits@1"``, ``"hits@3"`` and
        ``"hits@10"``. The dict is empty if the loader yields no batches.
    """
    model.eval()
    results = {}
    with th.no_grad():
        for batch in data_iter["{}_{}".format(split, mode)]:
            triple, label = batch[0].to(device), batch[1].to(device)
            sub, rel, obj = triple[:, 0], triple[:, 1], triple[:, 2]

            pred = model(graph, sub, rel)
            b_range = th.arange(pred.size(0), device=device)

            # Mask every labelled candidate with a very low score, then put
            # the target's own score back so only the *other* labelled
            # candidates are excluded from the ranking.
            # byte() keeps the original truncation of the labels; bool() is
            # needed because th.where requires a boolean condition tensor.
            target_pred = pred[b_range, obj]
            mask = label.byte().bool()
            pred = th.where(mask, th.full_like(pred, -10000000), pred)
            pred[b_range, obj] = target_pred

            # compute metrics
            # argsort of the descending argsort gives each candidate's
            # 0-based position in the ranking; +1 makes ranks 1-based.
            order = th.argsort(pred, dim=1, descending=True)
            positions = th.argsort(order, dim=1, descending=False)
            ranks = (1 + positions[b_range, obj]).float()

            results["count"] = th.numel(ranks) + results.get("count", 0.0)
            results["mr"] = th.sum(ranks).item() + results.get("mr", 0.0)
            results["mrr"] = th.sum(1.0 / ranks).item() + results.get(
                "mrr", 0.0
            )
            for k in [1, 3, 10]:
                key = "hits@{}".format(k)
                hits = int((ranks <= k).sum().item())
                results[key] = hits + results.get(key, 0.0)

    return results

59
60
61
62
63
64

# evaluation function, evaluate the head and tail prediction and then combine the results
def evaluate(model, graph, device, data_iter, split="valid"):
    """Run tail and head prediction on ``split`` and average the metrics.

    Args:
        model, graph, device, data_iter: Forwarded unchanged to ``predict``.
        split: Name of the split to evaluate (e.g. ``"valid"`` or ``"test"``).

    Returns:
        Dict of metrics rounded to 5 decimals. ``left_*`` entries come from
        the ``mode="tail"`` pass, ``right_*`` entries from the ``mode="head"``
        pass, and the unprefixed ``mr``, ``mrr`` and ``hits@{1,3,10}`` entries
        average over both passes.

    Raises:
        KeyError: If either pass scored no queries (``predict`` then returns
            an empty dict).
    """
    # predict for head and tail
    left_results = predict(model, graph, device, data_iter, split, mode="tail")
    right_results = predict(model, graph, device, data_iter, split, mode="head")

    # Normalise each direction by its own number of scored queries, so the
    # averages stay correct even if the two loaders differ in size.
    left_count = float(left_results["count"])
    right_count = float(right_results["count"])
    total_count = left_count + right_count

    results = {}

    # combine the head and tail prediction results
    # Metrics: MRR, MR, and Hit@k
    results["left_mr"] = round(left_results["mr"] / left_count, 5)
    results["left_mrr"] = round(left_results["mrr"] / left_count, 5)
    results["right_mr"] = round(right_results["mr"] / right_count, 5)
    results["right_mrr"] = round(right_results["mrr"] / right_count, 5)
    results["mr"] = round(
        (left_results["mr"] + right_results["mr"]) / total_count, 5
    )
    results["mrr"] = round(
        (left_results["mrr"] + right_results["mrr"]) / total_count, 5
    )
    for k in [1, 3, 10]:
        key = "hits@{}".format(k)
        results["left_" + key] = round(left_results[key] / left_count, 5)
        results["right_" + key] = round(right_results[key] / right_count, 5)
        results[key] = round(
            (left_results[key] + right_results[key]) / total_count, 5
        )
    return results

KounianhuaDu's avatar
KounianhuaDu committed
97
98
99
100
101
102

def main(args):
    """Train a CompGCN-ConvE model, validate every epoch, then test.

    The model state with the highest validation MRR is saved to
    ``"comp_link_<dataset>"`` in the working directory and reloaded before
    the final test evaluation. Training stops early once validation MRR has
    failed to improve for more than 100 consecutive epochs.

    Args:
        args: Parsed command-line namespace. Attributes read here: ``gpu``,
            ``dataset``, ``lbl_smooth``, ``num_workers``, ``batch_size``,
            ``num_bases``, ``init_dim``, ``layer_size``, ``opn``,
            ``dropout``, ``layer_dropout``, ``num_filt``, ``hid_drop``,
            ``feat_drop``, ``ker_sz``, ``k_w``, ``k_h``, ``lr``, ``l2`` and
            ``max_epochs``.
    """

    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
    # check cuda
    if args.gpu >= 0 and th.cuda.is_available():
        device = "cuda:{}".format(args.gpu)
    else:
        device = "cpu"

    # construct graph, split in/out edges and prepare train/validation/test data_loader
    data = Data(
        args.dataset, args.lbl_smooth, args.num_workers, args.batch_size
    )
    data_iter = data.data_iter  # train/validation/test data_loader
    graph = data.g.to(device)
    # Relation count is taken as the largest edge-type id plus one.
    num_rel = th.max(graph.edata["etype"]).item() + 1

    # Compute in/out edge norms and store in edata
    graph = in_out_norm(graph)

    # Step 2: Create model =================================================================== #
    compgcn_model = CompGCN_ConvE(
        num_bases=args.num_bases,
        num_rel=num_rel,
        num_ent=graph.num_nodes(),
        in_dim=args.init_dim,
        layer_size=args.layer_size,
        comp_fn=args.opn,
        batchnorm=True,
        dropout=args.dropout,
        layer_dropout=args.layer_dropout,
        num_filt=args.num_filt,
        hid_drop=args.hid_drop,
        feat_drop=args.feat_drop,
        ker_sz=args.ker_sz,
        k_w=args.k_w,
        k_h=args.k_h,
    )
    compgcn_model = compgcn_model.to(device)

    # Step 3: Create training components ===================================================== #
    loss_fn = th.nn.BCELoss()
    optimizer = optim.Adam(
        compgcn_model.parameters(), lr=args.lr, weight_decay=args.l2
    )

    # Step 4: training epochs ================================================================ #
    checkpoint_path = "comp_link" + "_" + args.dataset
    # Tracks whether THIS run wrote a checkpoint, so the test phase never
    # loads a missing file or a stale one left over from a previous run.
    checkpoint_saved = False
    best_mrr = 0.0
    kill_cnt = 0
    for epoch in range(args.max_epochs):
        # Training and validation using a full graph
        compgcn_model.train()
        train_loss = []
        t0 = time()
        for batch in data_iter["train"]:
            triple, label = batch[0].to(device), batch[1].to(device)
            sub, rel = triple[:, 0], triple[:, 1]
            logits = compgcn_model(graph, sub, rel)

            # compute loss
            tr_loss = loss_fn(logits, label)
            train_loss.append(tr_loss.item())

            # backward
            optimizer.zero_grad()
            tr_loss.backward()
            optimizer.step()

        # Reported loss is the sum of the per-batch losses of this epoch.
        train_loss = np.sum(train_loss)

        t1 = time()
        val_results = evaluate(
            compgcn_model, graph, device, data_iter, split="valid"
        )
        t2 = time()

        # validate
        if val_results["mrr"] > best_mrr:
            best_mrr = val_results["mrr"]
            th.save(compgcn_model.state_dict(), checkpoint_path)
            checkpoint_saved = True
            kill_cnt = 0
            print("saving model...")
        else:
            kill_cnt += 1
            if kill_cnt > 100:
                print("early stop.")
                break
        print(
            "In epoch {}, Train Loss: {:.4f}, Valid MRR: {:.5}\n, Train time: {}, Valid time: {}".format(
                epoch, train_loss, val_results["mrr"], t1 - t0, t2 - t1
            )
        )

    # test use the best model
    compgcn_model.eval()
    if checkpoint_saved:
        compgcn_model.load_state_dict(
            th.load(checkpoint_path, map_location=device)
        )
    else:
        # Validation MRR never exceeded 0.0 (or no epoch ran), so there is
        # no checkpoint from this run; evaluate the in-memory weights.
        print("no checkpoint saved in this run, testing the current model.")
    test_results = evaluate(
        compgcn_model, graph, device, data_iter, split="test"
    )
    print(
        "Test MRR: {:.5}\n, MR: {:.10}\n, H@10: {:.5}\n, H@3: {:.5}\n, H@1: {:.5}\n".format(
            test_results["mrr"],
            test_results["mr"],
            test_results["hits@10"],
            test_results["hits@3"],
            test_results["hits@1"],
        )
    )


if __name__ == "__main__":
    # Script entry point: parse the command line, seed the RNGs, convert the
    # list-valued options from their string form and start training.
    # Imported here because only this entry point needs it.
    import ast

    parser = argparse.ArgumentParser(
        description="Parser For Arguments",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--data",
        dest="dataset",
        default="FB15k-237",
        help="Dataset to use, default: FB15k-237",
    )
    parser.add_argument(
        "--model", dest="model", default="compgcn", help="Model Name"
    )
    parser.add_argument(
        "--score_func",
        dest="score_func",
        default="conve",
        help="Score Function for Link prediction",
    )
    parser.add_argument(
        "--opn",
        dest="opn",
        default="ccorr",
        help="Composition Operation to be used in CompGCN",
    )

    parser.add_argument(
        "--batch", dest="batch_size", default=1024, type=int, help="Batch size"
    )
    parser.add_argument(
        "--gpu",
        type=int,
        default=0,  # int default, consistent with type=int
        help="Set GPU Ids : Eg: For CPU = -1, For Single GPU = 0",
    )
    parser.add_argument(
        "--epoch",
        dest="max_epochs",
        type=int,
        default=500,
        help="Number of epochs",
    )
    parser.add_argument(
        "--l2", type=float, default=0.0, help="L2 Regularization for Optimizer"
    )
    parser.add_argument(
        "--lr", type=float, default=0.001, help="Starting Learning Rate"
    )
    parser.add_argument(
        "--lbl_smooth",
        dest="lbl_smooth",
        type=float,
        default=0.1,
        help="Label Smoothing",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=10,
        help="Number of processes to construct batches",
    )
    parser.add_argument(
        "--seed",
        dest="seed",
        default=41504,
        type=int,
        help="Seed for randomization",
    )

    parser.add_argument(
        "--num_bases",
        dest="num_bases",
        default=-1,
        type=int,
        help="Number of basis relation vectors to use",
    )
    parser.add_argument(
        "--init_dim",
        dest="init_dim",
        default=100,
        type=int,
        help="Initial dimension size for entities and relations",
    )
    parser.add_argument(
        "--layer_size",
        nargs="?",
        default="[200]",
        help="List of output size for each compGCN layer",
    )
    parser.add_argument(
        "--gcn_drop",
        dest="dropout",
        default=0.1,
        type=float,
        help="Dropout to use in GCN Layer",
    )
    parser.add_argument(
        "--layer_dropout",
        nargs="?",
        default="[0.3]",
        help="List of dropout value after each compGCN layer",
    )

    # ConvE specific hyperparameters
    parser.add_argument(
        "--hid_drop",
        dest="hid_drop",
        default=0.3,
        type=float,
        help="ConvE: Hidden dropout",
    )
    parser.add_argument(
        "--feat_drop",
        dest="feat_drop",
        default=0.3,
        type=float,
        help="ConvE: Feature Dropout",
    )
    parser.add_argument(
        "--k_w", dest="k_w", default=10, type=int, help="ConvE: k_w"
    )
    parser.add_argument(
        "--k_h", dest="k_h", default=20, type=int, help="ConvE: k_h"
    )
    parser.add_argument(
        "--num_filt",
        dest="num_filt",
        default=200,
        type=int,
        help="ConvE: Number of filters in convolution",
    )
    parser.add_argument(
        "--ker_sz",
        dest="ker_sz",
        default=7,
        type=int,
        help="ConvE: Kernel size to use",
    )

    args = parser.parse_args()

    np.random.seed(args.seed)
    th.manual_seed(args.seed)

    print(args)

    # The list options arrive as strings such as "[200]". literal_eval only
    # accepts Python literals, so unlike eval() it cannot execute arbitrary
    # code passed on the command line.
    args.layer_size = ast.literal_eval(args.layer_size)
    args.layer_dropout = ast.literal_eval(args.layer_dropout)

    main(args)