"""Inferring Relational Data with Graph Convolutional Networks
"""
import argparse
from functools import partial

6
7
import torch as th
import torch.nn.functional as F
8
9
from entity_classify import EntityClassify

10
11
12
from dgl.data.rdf import AIFB, AM, BGS, MUTAG


13
14
def _load_dataset(name):
    """Build the RDF dataset selected on the command line.

    Args:
        name: One of ``"aifb"``, ``"mutag"``, ``"bgs"`` or ``"am"``.

    Returns:
        A freshly constructed dataset object of the matching class.

    Raises:
        ValueError: If ``name`` is not one of the supported dataset names.
    """
    class_names = {
        "aifb": "AIFBDataset",
        "mutag": "MUTAGDataset",
        "bgs": "BGSDataset",
        "am": "AMDataset",
    }
    if name not in class_names:
        raise ValueError(
            "unknown dataset {!r}; expected one of {}".format(
                name, sorted(class_names)
            )
        )

    # The ``*Dataset`` classes are not among the names imported at module
    # level, so they are resolved here. Importing only after the name has
    # been validated keeps the bad-name error independent of DGL.
    from dgl.data import rdf

    return getattr(rdf, class_names[name])()


def main(args):
    """Evaluate a saved ``EntityClassify`` model on a dataset's test nodes.

    Loads the dataset, restores the model weights from ``args.model_path``
    and prints the test accuracy and cross-entropy loss.

    Args:
        args: Parsed command-line namespace. Reads ``dataset``, ``gpu``,
            ``n_hidden``, ``n_bases``, ``n_layers``, ``use_self_loop`` and
            ``model_path``.

    Raises:
        ValueError: If ``args.dataset`` is not a supported dataset name.
    """
    # load graph data
    dataset = _load_dataset(args.dataset)

    g = dataset[0]
    category = dataset.predict_category
    num_classes = dataset.num_classes
    test_mask = g.nodes[category].data.pop("test_mask")
    # flatten() instead of squeeze(): with exactly one test node squeeze()
    # would yield a 0-dim tensor, breaking len() and argmax(dim=1) below.
    test_idx = th.nonzero(test_mask, as_tuple=False).flatten()
    labels = g.nodes[category].data.pop("labels")

    # check cuda
    use_cuda = args.gpu >= 0 and th.cuda.is_available()
    if use_cuda:
        th.cuda.set_device(args.gpu)
        labels = labels.cuda()
        test_idx = test_idx.cuda()
        g = g.to("cuda:%d" % args.gpu)

    # create model
    model = EntityClassify(
        g,
        args.n_hidden,
        num_classes,
        num_bases=args.n_bases,
        num_hidden_layers=args.n_layers - 2,
        use_self_loop=args.use_self_loop,
    )

    # Map the checkpoint to CPU first so that weights saved from a GPU run
    # can still be restored on a CPU-only host; the model is moved to the
    # GPU afterwards when requested.
    # NOTE(security): th.load unpickles the file -- only load checkpoints
    # from a trusted source.
    state_dict = th.load(args.model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    if use_cuda:
        model.cuda()

    print("start testing...")
    model.eval()
    # Pure inference: no need to record an autograd graph.
    with th.no_grad():
        logits = model.forward()[category]
        test_logits = logits[test_idx]
        test_labels = labels[test_idx]
        test_loss = F.cross_entropy(test_logits, test_labels)
        num_correct = th.sum(test_logits.argmax(dim=1) == test_labels).item()
    test_acc = num_correct / len(test_idx)
    print(
        "Test Acc: {:.4f} | Test loss: {:.4f}".format(
            test_acc, test_loss.item()
        )
    )
    print()
67

68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96

# Script entry point: parse the command-line options and run the evaluation.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="RGCN")
    parser.add_argument(
        "--n-hidden", type=int, default=16, help="number of hidden units"
    )
    parser.add_argument("--gpu", type=int, default=-1, help="gpu")
    # NOTE(review): --lr is not read by main() in this file; presumably kept
    # so the same command line as the training script can be reused.
    parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
    parser.add_argument(
        "--n-bases",
        type=int,
        default=-1,
        help="number of filter weight matrices, default: -1 [use all]",
    )
    parser.add_argument(
        "--n-layers", type=int, default=2, help="number of propagation rounds"
    )
    parser.add_argument(
        "-d",
        "--dataset",
        type=str,
        required=True,
        # Reject unsupported names at parse time with a usage message.
        choices=["aifb", "mutag", "bgs", "am"],
        help="dataset to use",
    )
    # main() always loads this path, so it is required. The hyphenated
    # spelling matches the other flags; the original underscore spelling is
    # kept as an alias and both store into ``args.model_path``.
    parser.add_argument(
        "--model-path",
        "--model_path",
        type=str,
        required=True,
        help="path of the model to load from",
    )
    parser.add_argument(
        "--use-self-loop",
        default=False,
        action="store_true",
        help="include self feature as a special relation",
    )

    args = parser.parse_args()
    print(args)
    main(args)