"""
[RGCN: Relational Graph Convolutional Networks]
(https://arxiv.org/abs/1703.06103)

This example showcases the usage of `CuGraphRelGraphConv` via the entity
classification problem in the RGCN paper with mini-batch training. It offers
a 1.5~2x speed-up over `RelGraphConv` on cuda devices and only requires minimal
code changes from the current `entity_sample.py` example.
"""

import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.functional import accuracy

import dgl
from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset
from dgl.dataloading import DataLoader, MultiLayerNeighborSampler
from dgl.nn import CuGraphRelGraphConv


class RGCN(nn.Module):
    """Two-layer RGCN for entity classification using ``CuGraphRelGraphConv``.

    The graph is featureless, so a learnable embedding row per node stands
    in for input features.

    Parameters
    ----------
    num_nodes : int
        Number of nodes in the homogeneous graph (one embedding per node).
    h_dim : int
        Hidden feature size.
    out_dim : int
        Output size (number of classes).
    num_rels : int
        Number of relation (edge) types.
    num_bases : int
        Number of bases for the basis-decomposition regularizer.
    """

    def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases):
        super().__init__()
        self.emb = nn.Embedding(num_nodes, h_dim)
        # Two RGCN layers with basis regularization, self-loops and
        # normalization enabled.
        self.conv1 = CuGraphRelGraphConv(
            h_dim,
            h_dim,
            num_rels,
            regularizer="basis",
            num_bases=num_bases,
            self_loop=True,
            apply_norm=True,
        )
        self.conv2 = CuGraphRelGraphConv(
            h_dim,
            out_dim,
            num_rels,
            regularizer="basis",
            num_bases=num_bases,
            self_loop=True,
            apply_norm=True,
        )

    def forward(self, g, fanouts=(None, None)):
        """Run the two-layer RGCN on a pair of sampled blocks.

        Parameters
        ----------
        g : list of DGLBlock
            Two message-flow-graph blocks, one per layer.
        fanouts : sequence of (int or None), optional
            Per-layer fanout hint forwarded to each conv layer; ``None``
            means unbounded. NOTE: the default was a mutable list
            (``[None, None]``); it is now a tuple to avoid the shared
            mutable-default pitfall — indexing behavior is unchanged.

        Returns
        -------
        Tensor
            Logits for the seed (destination) nodes of the second block.
        """
        # Embeddings for the input (source) nodes of the first block.
        x = self.emb(g[0].srcdata[dgl.NID])
        h = F.relu(self.conv1(g[0], x, g[0].edata[dgl.ETYPE], fanouts[0]))
        h = self.conv2(g[1], h, g[1].edata[dgl.ETYPE], fanouts[1])
        return h


def evaluate(model, labels, dataloader, inv_target):
    """Compute multiclass accuracy of *model* over *dataloader*.

    Parameters
    ----------
    model : nn.Module
        Maps a list of blocks to per-seed-node logits.
    labels : Tensor
        Ground-truth label per type-specific target node.
    dataloader : iterable
        Yields ``(input_nodes, output_nodes, blocks)`` mini-batches.
    inv_target : Tensor
        Maps global (homogeneous-graph) node IDs to the type-specific IDs
        that index ``labels``.

    Returns
    -------
    float
        Fraction of correctly classified seed nodes, or 0.0 when the
        dataloader yields no batches (the original code crashed with an
        IndexError in that case).
    """
    model.eval()
    all_logits = []
    all_seeds = []
    with torch.no_grad():
        for _, output_nodes, blocks in dataloader:
            # Map global node IDs back to the type-specific IDs used by
            # `labels`.
            seeds = inv_target[output_nodes.type(torch.int64)]
            logits = model(blocks)
            # `.detach()` is unnecessary under no_grad; `.cpu()` suffices.
            all_logits.append(logits.cpu())
            all_seeds.append(seeds.cpu())
    if not all_logits:
        return 0.0
    preds = torch.cat(all_logits).argmax(dim=1)
    seeds = torch.cat(all_seeds)
    # Plain-torch multiclass accuracy: identical to
    # torchmetrics.functional.accuracy(preds, labels[seeds],
    #                                  task="multiclass", num_classes=C).
    return (preds == labels[seeds].cpu()).float().mean().item()


def train(device, g, target_idx, labels, train_mask, model, fanouts, inv_target=None):
    """Train *model* for 50 epochs with neighbor sampling, printing loss
    and validation accuracy per epoch.

    Parameters
    ----------
    device : torch.device
        Device the dataloaders copy blocks to.
    g : DGLGraph
        Homogeneous graph to sample from.
    target_idx : Tensor
        Global node IDs of the target-category nodes.
    labels : Tensor
        Label per type-specific target node.
    train_mask : Tensor
        Boolean mask selecting training nodes.
    model : nn.Module
        Model to optimize.
    fanouts : list of int
        Per-layer neighbor-sampling fanouts.
    inv_target : Tensor, optional
        Global-ID -> type-specific-ID mapping. Defaults to the
        module-level ``inv_target`` for backward compatibility: the
        original code silently read it as a global, which raised
        NameError when `train` was called outside this script.
    """
    if inv_target is None:
        # Backward-compat fallback to the global defined in __main__.
        inv_target = globals()["inv_target"]
    # Define train idx, loss function and optimizer.
    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
    # Construct sampler and dataloader.
    sampler = MultiLayerNeighborSampler(fanouts)
    train_loader = DataLoader(
        g,
        target_idx[train_idx].type(g.idtype),
        sampler,
        device=device,
        batch_size=100,
        shuffle=True,
    )
    # No separate validation subset, use train index instead for validation.
    val_loader = DataLoader(
        g,
        target_idx[train_idx].type(g.idtype),
        sampler,
        device=device,
        batch_size=100,
        shuffle=False,
    )
    for epoch in range(50):
        model.train()
        total_loss = 0.0
        num_batches = 0  # explicit count: `enumerate`'s `it` was unbound for an empty loader
        for _, output_nodes, blocks in train_loader:
            seeds = inv_target[output_nodes.type(torch.int64)]
            logits = model(blocks, fanouts=fanouts)
            loss = loss_fcn(logits, labels[seeds])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        acc = evaluate(model, labels, val_loader, inv_target)
        avg_loss = total_loss / max(num_batches, 1)
        print(
            f"Epoch {epoch:05d} | Loss {avg_loss:.4f} | "
            f"Val. Accuracy {acc:.4f}"
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="RGCN for entity classification with sampling"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="aifb",
        choices=["aifb", "mutag", "bgs", "am"],
    )
    args = parser.parse_args()
    # CuGraphRelGraphConv only runs on CUDA devices.
    device = torch.device("cuda")
    print("Training with DGL CuGraphRelGraphConv module with sampling.")

    # Load and preprocess dataset. argparse's `choices` guarantees the key
    # is present, so a dispatch table replaces the if/elif chain.
    dataset_classes = {
        "aifb": AIFBDataset,
        "mutag": MUTAGDataset,
        "bgs": BGSDataset,
        "am": AMDataset,
    }
    data = dataset_classes[args.dataset]()
    hg = data[0].to(device)
    num_rels = len(hg.canonical_etypes)
    category = data.predict_category

    labels = hg.nodes[category].data.pop("labels")
    train_mask = hg.nodes[category].data.pop("train_mask")
    test_mask = hg.nodes[category].data.pop("test_mask")

    # Find target category and node id.
    category_id = hg.ntypes.index(category)
    g = dgl.to_homogeneous(hg)
    node_ids = torch.arange(g.num_nodes()).to(device)
    target_idx = node_ids[g.ndata[dgl.NTYPE] == category_id]
    g.ndata["ntype"] = g.ndata.pop(dgl.NTYPE)
    g.ndata["type_id"] = g.ndata.pop(dgl.NID)

    # Find the mapping from global node IDs to type-specific node IDs.
    inv_target = torch.empty((g.num_nodes(),), dtype=torch.int64).to(device)
    inv_target[target_idx] = torch.arange(
        0, target_idx.shape[0], dtype=inv_target.dtype
    ).to(device)

    # Create RGCN model.
    in_size = g.num_nodes()  # featureless with one-hot encoding
    out_size = data.num_classes
    num_bases = 20
    fanouts = [4, 4]
    model = RGCN(in_size, 16, out_size, num_rels, num_bases).to(device)

    train(
        device,
        g,
        target_idx,
        labels,
        train_mask,
        model,
        fanouts,
    )

    # Test with full-neighborhood sampling (fanout -1 keeps all neighbors).
    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()
    test_sampler = MultiLayerNeighborSampler([-1, -1])
    test_loader = DataLoader(
        g,
        target_idx[test_idx].type(g.idtype),
        test_sampler,
        device=device,
        batch_size=32,
        shuffle=False,
    )
    acc = evaluate(model, labels, test_loader, inv_target)
    print(f"Test accuracy {acc:.4f}")