entity.py 4.11 KB
Newer Older
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
1
2
3
import argparse

import dgl
4
5
import torch
import torch.nn as nn
Mufei Li's avatar
Mufei Li committed
6
import torch.nn.functional as F
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
7
from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset
8
from dgl.nn.pytorch import RelGraphConv
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
9
10
from torchmetrics.functional import accuracy

Mufei Li's avatar
Mufei Li committed
11

12
13
14
15
16
class RGCN(nn.Module):
    """Two-layer relational graph convolution network over learned node embeddings.

    The model is featureless: instead of taking node features as input, it
    owns an ``nn.Embedding`` table with one ``h_dim``-sized row per node and
    feeds the whole table into the first relational layer.

    Args:
        num_nodes: Number of rows in the embedding table.
        h_dim: Width of the embeddings and of the hidden layer.
        out_dim: Width of the output produced by the second layer.
        num_rels: Number of relation types; also used as the number of bases
            for the basis regularizer of both layers.
    """

    def __init__(self, num_nodes, h_dim, out_dim, num_rels):
        super().__init__()
        # Options common to both relational layers.
        layer_options = dict(
            regularizer="basis",
            num_bases=num_rels,
            self_loop=False,
        )
        # Modules are created in the order emb -> conv1 -> conv2.
        self.emb = nn.Embedding(num_nodes, h_dim)
        self.conv1 = RelGraphConv(h_dim, h_dim, num_rels, **layer_options)
        self.conv2 = RelGraphConv(h_dim, out_dim, num_rels, **layer_options)

    def forward(self, g):
        """Run both layers on ``g`` and return the second layer's output.

        ``g`` must carry the edge data ``dgl.ETYPE`` and ``"norm"``; both are
        passed unchanged to each layer.
        """
        edge_types = g.edata[dgl.ETYPE]
        edge_norm = g.edata["norm"]
        # The full embedding matrix acts as the input node features.
        hidden = self.conv1(g, self.emb.weight, edge_types, edge_norm)
        activated = F.relu(hidden)
        return self.conv2(g, activated, edge_types, edge_norm)
39

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
40

41
def evaluate(g, target_idx, labels, num_classes, test_mask, model):
    """Return the model's accuracy on the nodes selected by ``test_mask``.

    Args:
        g: Graph passed unchanged to ``model``.
        target_idx: Index used to pick the rows of the model output that
            correspond to the labelled nodes.
        labels: Ground-truth class ids, indexed the same way as the rows
            selected by ``target_idx``.
        num_classes: Number of classes, forwarded to ``accuracy``.
        test_mask: Mask over the labelled nodes; non-zero entries mark test
            nodes. Assumed to be 1-D -- TODO confirm against the dataset.
        model: Callable module mapping ``g`` to per-node logits.

    Returns:
        The multiclass accuracy as a Python float.

    Note:
        The model is switched to evaluation mode and is left in that mode.
    """
    # squeeze(1), not squeeze(): for a 1-D mask nonzero() yields shape (K, 1).
    # A bare squeeze() would collapse the K == 1 case to a 0-dim tensor, so
    # logits[test_idx] would become 1-D and argmax(dim=1) would raise.
    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze(1)
    model.eval()
    with torch.no_grad():
        logits = model(g)
    # Keep only the rows that belong to the labelled node category.
    logits = logits[target_idx]
    return accuracy(
        logits[test_idx].argmax(dim=1),
        labels[test_idx],
        task="multiclass",
        num_classes=num_classes,
    ).item()
Mufei Li's avatar
Mufei Li committed
53

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
54

55
def train(
    g,
    target_idx,
    labels,
    num_classes,
    train_mask,
    model,
    *,
    num_epochs=50,
    lr=1e-2,
    weight_decay=5e-4,
):
    """Train ``model`` full-batch with Adam and cross-entropy, logging each epoch.

    Args:
        g: Graph passed unchanged to ``model``.
        target_idx: Index used to pick the rows of the model output that
            correspond to the labelled nodes.
        labels: Ground-truth class ids, indexed the same way as the rows
            selected by ``target_idx``.
        num_classes: Number of classes, forwarded to ``accuracy``.
        train_mask: Mask over the labelled nodes; non-zero entries mark
            training nodes. Assumed to be 1-D -- TODO confirm against the
            dataset.
        model: Module mapping ``g`` to per-node logits; updated in place.
        num_epochs: Number of optimisation steps (one full-batch step each).
        lr: Adam learning rate.
        weight_decay: Adam weight decay.

    Returns:
        None. Loss and training accuracy are printed once per epoch.
    """
    # squeeze(1), not squeeze(): for a 1-D mask nonzero() yields shape (K, 1).
    # A bare squeeze() would collapse the K == 1 case to a 0-dim tensor, so
    # logits[train_idx] would become 1-D and break the loss and argmax(dim=1).
    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze(1)
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=lr, weight_decay=weight_decay
    )

    model.train()
    for epoch in range(num_epochs):
        logits = model(g)
        # Keep only the rows that belong to the labelled node category.
        logits = logits[target_idx]
        loss = loss_fcn(logits[train_idx], labels[train_idx])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accuracy is measured on the logits computed before this step's
        # parameter update.
        acc = accuracy(
            logits[train_idx].argmax(dim=1),
            labels[train_idx],
            task="multiclass",
            num_classes=num_classes,
        ).item()
        print(
            "Epoch {:05d} | Loss {:.4f} | Train Accuracy {:.4f} ".format(
                epoch, loss.item(), acc
            )
        )


if __name__ == "__main__":
    # Command-line interface: only the dataset is selectable.
    parser = argparse.ArgumentParser(
        description="RGCN for entity classification"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="aifb",
        help="Dataset name ('aifb', 'mutag', 'bgs', 'am').",
    )
    args = parser.parse_args()

    # Use the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Training with DGL built-in RGCN module.")

    # Map the dataset name to its loader; unknown names are rejected.
    dataset_loaders = {
        "aifb": AIFBDataset,
        "mutag": MUTAGDataset,
        "bgs": BGSDataset,
        "am": AMDataset,
    }
    if args.dataset not in dataset_loaders:
        raise ValueError("Unknown dataset: {}".format(args.dataset))
    data = dataset_loaders[args.dataset]()

    # The dataset's first item is the heterogeneous graph; move it to the
    # device as int-indexed.
    hetero_g = data[0].int().to(device)
    num_rels = len(hetero_g.canonical_etypes)
    category = data.predict_category

    # Take labels and split masks off the node type being classified.
    category_data = hetero_g.nodes[category].data
    labels = category_data.pop("labels")
    train_mask = category_data.pop("train_mask")
    test_mask = category_data.pop("test_mask")

    # Attach a per-edge normalisation weight to every relation, as a column.
    for cetype in hetero_g.canonical_etypes:
        edge_norm = dgl.norm_by_dst(hetero_g, cetype).unsqueeze(1)
        hetero_g.edges[cetype].data["norm"] = edge_norm
    category_id = hetero_g.ntypes.index(category)

    # Flatten to a homogeneous graph, keeping "norm", and locate the nodes
    # whose type is the target category.
    g = dgl.to_homogeneous(hetero_g, edata=["norm"])
    node_ids = torch.arange(g.num_nodes()).to(device)
    is_target = g.ndata[dgl.NTYPE] == category_id
    target_idx = node_ids[is_target]

    # Featureless model: one embedding row per node of the homogeneous graph.
    in_size = g.num_nodes()
    num_classes = data.num_classes
    model = RGCN(in_size, 16, num_classes, num_rels)
    model = model.to(device)

    train(g, target_idx, labels, num_classes, train_mask, model)
    acc = evaluate(g, target_idx, labels, num_classes, test_mask, model)
    print("Test accuracy {:.4f}".format(acc))