"""Inferring Relational Data with Graph Convolutional Networks.

Evaluate a trained RGCN entity-classification model on a dataset's test split.
"""
import argparse
from functools import partial

import torch as th
import torch.nn.functional as F
from dgl.data.rdf import AIFB, MUTAG, BGS, AM
from dgl.data.rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset

from entity_classify import EntityClassify

def _load_dataset(name):
    """Instantiate the RDF entity-classification dataset selected by ``name``.

    Parameters
    ----------
    name : str
        One of ``'aifb'``, ``'mutag'``, ``'bgs'`` or ``'am'``.

    Returns
    -------
    The constructed dataset object.

    Raises
    ------
    ValueError
        If ``name`` is not one of the supported dataset names.
    """
    dataset_classes = {
        'aifb': AIFBDataset,
        'mutag': MUTAGDataset,
        'bgs': BGSDataset,
        'am': AMDataset,
    }
    try:
        dataset_cls = dataset_classes[name]
    except KeyError:
        raise ValueError('Unknown dataset {!r}; expected one of: {}'.format(
            name, ', '.join(sorted(dataset_classes)))) from None
    return dataset_cls()


def main(args):
    """Evaluate a saved RGCN ``EntityClassify`` model on the test split.

    Loads the dataset named by ``args.dataset``, rebuilds the model with the
    architecture hyper-parameters in ``args``, restores its weights from
    ``args.model_path`` and prints test accuracy and cross-entropy loss for
    the dataset's prediction category.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments; reads ``dataset``, ``gpu``,
        ``n_hidden``, ``n_bases``, ``n_layers``, ``use_self_loop`` and
        ``model_path``.

    Raises
    ------
    ValueError
        If ``args.dataset`` is not a supported dataset name.
    """
    # load graph data
    dataset = _load_dataset(args.dataset)
    g = dataset[0]

    category = dataset.predict_category
    num_classes = dataset.num_classes
    test_mask = g.nodes[category].data.pop('test_mask')
    # view(-1) keeps the index 1-D even when the test split has exactly one
    # node; squeeze() would collapse it to a 0-d tensor and break len().
    test_idx = th.nonzero(test_mask, as_tuple=False).view(-1)
    labels = g.nodes[category].data.pop('labels')

    # check cuda
    use_cuda = args.gpu >= 0 and th.cuda.is_available()
    if use_cuda:
        th.cuda.set_device(args.gpu)
        labels = labels.cuda()
        test_idx = test_idx.cuda()
        g = g.to('cuda:%d' % args.gpu)

    # create model
    model = EntityClassify(g,
                           args.n_hidden,
                           num_classes,
                           num_bases=args.n_bases,
                           num_hidden_layers=args.n_layers - 2,
                           use_self_loop=args.use_self_loop)
    # map_location='cpu' lets a checkpoint saved from a GPU run be restored on
    # a CPU-only machine; the model is moved to the GPU below when requested.
    model.load_state_dict(th.load(args.model_path, map_location='cpu'))
    if use_cuda:
        model.cuda()

    print("start testing...")
    model.eval()
    # Inference only: disable autograd so no computation graph is kept.
    with th.no_grad():
        logits = model()[category]
        test_logits = logits[test_idx]
        test_labels = labels[test_idx]
        test_loss = F.cross_entropy(test_logits, test_labels)
        num_correct = th.sum(test_logits.argmax(dim=1) == test_labels).item()
        test_acc = num_correct / test_idx.numel()
    print("Test Acc: {:.4f} | Test loss: {:.4f}".format(test_acc, test_loss.item()))
    print()
78
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='RGCN')
    parser.add_argument("--n-hidden", type=int, default=16,
            help="number of hidden units")
    parser.add_argument("--gpu", type=int, default=-1,
            help="gpu")
    # Not read by main(); accepted so the same command line as training works.
    parser.add_argument("--lr", type=float, default=1e-2,
            help="learning rate")
    parser.add_argument("--n-bases", type=int, default=-1,
            help="number of filter weight matrices, default: -1 [use all]")
    parser.add_argument("--n-layers", type=int, default=2,
            help="number of propagation rounds")
    parser.add_argument("-d", "--dataset", type=str, required=True,
            help="dataset to use")
    # Required: main() passes this straight to th.load, which cannot load None.
    # The underscore spelling is kept for backward compatibility.
    parser.add_argument("--model-path", "--model_path", dest="model_path",
            type=str, required=True,
            help='path of the model to load from')
    parser.add_argument("--use-self-loop", default=False, action='store_true',
            help="include self feature as a special relation")

    args = parser.parse_args()
    print(args)
    main(args)