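"""Entity classification on the DGL RDF datasets (AIFB, MUTAG, BGS, AM)
with a two-layer RGCN.

Example invocation (a sketch, assuming DGL, PyTorch and torchmetrics are
installed):

    python entity.py --dataset aifb
"""
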
import argparse

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset
from dgl.nn.pytorch import RelGraphConv
from torchmetrics.functional import accuracy


class RGCN(nn.Module):
    def __init__(self, num_nodes, h_dim, out_dim, num_rels):
        super().__init__()
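        # learnable per-node embeddings stand in for input features
        # (the RDF graphs are featureless; equivalent to one-hot inputs)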
        self.emb = nn.Embedding(num_nodes, h_dim)
        # two-layer RGCN
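        # "basis" expresses each relation's weight matrix as a combination of
        # shared bases; with num_bases=num_rels there is no compression, but
        # lowering num_bases would share parameters across relations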
        self.conv1 = RelGraphConv(
            h_dim,
            h_dim,
            num_rels,
            regularizer="basis",
            num_bases=num_rels,
            self_loop=False,
        )
        self.conv2 = RelGraphConv(
            h_dim,
            out_dim,
            num_rels,
            regularizer="basis",
            num_bases=num_rels,
            self_loop=False,
        )

    def forward(self, g):
        x = self.emb.weight
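        # g.edata[dgl.ETYPE] holds each edge's integer relation id and
        # g.edata["norm"] its normalization weight, both set up in __main__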
        h = F.relu(self.conv1(g, x, g.edata[dgl.ETYPE], g.edata["norm"]))
        h = self.conv2(g, h, g.edata[dgl.ETYPE], g.edata["norm"])
        return h


def evaluate(g, target_idx, labels, test_mask, model):
    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()
    model.eval()
    with torch.no_grad():
        logits = model(g)
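    # keep only the predictions for nodes of the target category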
    logits = logits[target_idx]
    return accuracy(logits[test_idx].argmax(dim=1), labels[test_idx]).item()


def train(g, target_idx, labels, train_mask, model):
    # define train idx, loss function and optimizer
    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)

    model.train()
    for epoch in range(50):
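        # full-graph forward pass; the loss is computed only on the
        # labeled training nodes of the target category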
        logits = model(g)
        logits = logits[target_idx]
        loss = loss_fcn(logits[train_idx], labels[train_idx])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        acc = accuracy(
            logits[train_idx].argmax(dim=1), labels[train_idx]
        ).item()
        print(
            "Epoch {:05d} | Loss {:.4f} | Train Accuracy {:.4f} ".format(
                epoch, loss.item(), acc
            )
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="RGCN for entity classification"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="aifb",
        help="Dataset name ('aifb', 'mutag', 'bgs', 'am').",
    )
    args = parser.parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training with DGL built-in RGCN module.")

    # load and preprocess dataset
    if args.dataset == "aifb":
        data = AIFBDataset()
    elif args.dataset == "mutag":
        data = MUTAGDataset()
    elif args.dataset == "bgs":
        data = BGSDataset()
    elif args.dataset == "am":
        data = AMDataset()
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))
    g = data[0]
    g = g.int().to(device)
    num_rels = len(g.canonical_etypes)
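    # predict_category is the node type whose labels this task classifies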
    category = data.predict_category
    labels = g.nodes[category].data.pop("labels")
    train_mask = g.nodes[category].data.pop("train_mask")
    test_mask = g.nodes[category].data.pop("test_mask")
    # calculate normalization weight for each edge,
    # and find the target category and node id
    for cetype in g.canonical_etypes:
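        # each edge gets a 1/in-degree weight w.r.t. its destination node,
        # computed per relation type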
        g.edges[cetype].data["norm"] = dgl.norm_by_dst(g, cetype).unsqueeze(1)
    category_id = g.ntypes.index(category)
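    # flatten to a homogeneous graph; original types are kept in
    # g.ndata[dgl.NTYPE] / g.edata[dgl.ETYPE] and "norm" is carried over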
    g = dgl.to_homogeneous(g, edata=["norm"])
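    # map the target category back to node ids in the homogeneous graph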
    node_ids = torch.arange(g.num_nodes()).to(device)
    target_idx = node_ids[g.ndata[dgl.NTYPE] == category_id]
    # create RGCN model
    in_size = g.num_nodes()  # featureless with one-hot encoding
    out_size = data.num_classes
    model = RGCN(in_size, 16, out_size, num_rels).to(device)

    train(g, target_idx, labels, train_mask, model)
    acc = evaluate(g, target_idx, labels, test_mask, model)
    print("Test accuracy {:.4f}".format(acc))