# entity.py
"""
Differences compared to tkipf/relation-gcn
* weight decay applied to all weights
"""
import argparse
import torch as th
import torch.nn.functional as F

from torchmetrics.functional import accuracy

from entity_utils import load_data
from model import RGCN

def main(args):
    """Train an RGCN entity classifier and print its test accuracy.

    The model is run on the whole graph each epoch; the loss is computed only
    on the training subset of the target nodes.

    Args:
        args: Parsed command-line namespace. This function reads
            ``dataset``, ``n_hidden``, ``n_bases``, ``gpu`` and ``wd``.
    """
    g, num_rels, num_classes, labels, train_idx, test_idx, target_idx = load_data(
        args.dataset, get_norm=True)

    model = RGCN(g.num_nodes(),
                 args.n_hidden,
                 num_classes,
                 num_rels,
                 num_bases=args.n_bases)

    # Use the requested GPU only if CUDA is actually available; otherwise CPU.
    if args.gpu >= 0 and th.cuda.is_available():
        device = th.device(args.gpu)
    else:
        device = th.device('cpu')
    labels = labels.to(device)
    model = model.to(device)
    g = g.int().to(device)

    # Weight decay is applied to every model parameter (see module docstring).
    optimizer = th.optim.Adam(model.parameters(), lr=1e-2, weight_decay=args.wd)

    model.train()
    for epoch in range(100):
        logits = model(g)
        # Keep only the rows for the target nodes before selecting the
        # training subset.
        logits = logits[target_idx]
        loss = F.cross_entropy(logits[train_idx], labels[train_idx])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accuracy is computed from the logits produced before this step's
        # parameter update.
        # NOTE(review): newer torchmetrics releases appear to require a `task`
        # argument for `accuracy` -- confirm against the pinned version.
        train_acc = accuracy(logits[train_idx].argmax(dim=1), labels[train_idx]).item()
        print("Epoch {:05d} | Train Accuracy: {:.4f} | Train Loss: {:.4f}".format(
            epoch, train_acc, loss.item()))
    print()

    model.eval()
    with th.no_grad():
        logits = model(g)
    logits = logits[target_idx]
    test_acc = accuracy(logits[test_idx].argmax(dim=1), labels[test_idx]).item()
    print("Test Accuracy: {:.4f}".format(test_acc))

if __name__ == '__main__':
    # Command-line entry point: parse the options, echo them, then train.
    parser = argparse.ArgumentParser(description='RGCN for entity classification')
    parser.add_argument("--n-hidden", type=int, default=16,
                        help="number of hidden units")
    parser.add_argument("--gpu", type=int, default=-1,
                        help="gpu")
    parser.add_argument("--n-bases", type=int, default=-1,
                        help="number of filter weight matrices, default: -1 [use all]")
    parser.add_argument("-d", "--dataset", type=str, required=True,
                        choices=['aifb', 'mutag', 'bgs', 'am'],
                        help="dataset to use")
    parser.add_argument("--wd", type=float, default=5e-4,
                        help="weight decay")

    args = parser.parse_args()
    # Print the resolved configuration so each run's settings are logged.
    print(args)
    main(args)