"""
[RGCN: Relational Graph Convolutional Networks]
(https://arxiv.org/abs/1703.06103)

This example showcases the usage of `CuGraphRelGraphConv` via the entity
classification problem in the RGCN paper with mini-batch training. It offers
a 1.5~2x speed-up over `RelGraphConv` on cuda devices and only requires minimal
code changes from the current `entity_sample.py` example.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.functional import accuracy
import dgl
from dgl.data.rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset
from dgl.dataloading import MultiLayerNeighborSampler, DataLoader
from dgl.nn import CuGraphRelGraphConv
import argparse


class RGCN(nn.Module):
    """Two-layer relational GCN over sampled message-flow graphs (blocks).

    Node features come from a learned embedding table (the graph is
    featureless); both graph-conv layers are `CuGraphRelGraphConv` with
    basis regularization and no self-loops.
    """

    def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases, fanouts):
        super().__init__()
        self.emb = nn.Embedding(num_nodes, h_dim)
        # Options shared by both relational conv layers.
        shared = dict(
            regularizer="basis",
            num_bases=num_bases,
            self_loop=False,
        )
        # cugraph-ops kernels take max_in_degree as a perf hint; tie each
        # layer's bound to the corresponding sampling fanout.
        self.conv1 = CuGraphRelGraphConv(
            h_dim, h_dim, num_rels, max_in_degree=fanouts[0], **shared
        )
        self.conv2 = CuGraphRelGraphConv(
            h_dim, out_dim, num_rels, max_in_degree=fanouts[1], **shared
        )

    def forward(self, g):
        # `g` is a 2-element list of blocks (one per layer); look up the
        # embeddings of the outermost block's source nodes.
        block0, block1 = g
        h = self.emb(block0.srcdata[dgl.NID])
        h = F.relu(
            self.conv1(block0, h, block0.edata[dgl.ETYPE], norm=block0.edata["norm"])
        )
        # Final layer emits raw logits (no activation).
        return self.conv2(block1, h, block1.edata[dgl.ETYPE], norm=block1.edata["norm"])

    def update_max_in_degree(self, fanouts):
        # Re-sync the per-layer degree hints with a new sampler's fanouts
        # (needed when evaluation uses larger fanouts than training).
        for conv, fanout in zip((self.conv1, self.conv2), fanouts):
            conv.max_in_degree = fanout


def evaluate(model, labels, dataloader, inv_target):
    """Compute accuracy of `model` over every batch yielded by `dataloader`.

    `inv_target` maps global node IDs back to type-specific IDs so the seed
    nodes can index into `labels`. Logits from all batches are concatenated
    and scored once on CPU; returns the accuracy as a Python float.
    """
    model.eval()
    logits_batches, seed_batches = [], []
    with torch.no_grad():
        for _, output_nodes, blocks in dataloader:
            seeds = inv_target[output_nodes.type(torch.int64)]
            # Attach the per-edge normalization the conv layers expect.
            for blk in blocks:
                blk.edata["norm"] = dgl.norm_by_dst(blk).unsqueeze(1)
            out = model(blocks)
            logits_batches.append(out.cpu().detach())
            seed_batches.append(seeds.cpu().detach())
    all_logits = torch.cat(logits_batches)
    all_seeds = torch.cat(seed_batches)
    return accuracy(all_logits.argmax(dim=1), labels[all_seeds].cpu()).item()


def train(device, g, target_idx, labels, train_mask, model, fanouts, inv_target=None):
    """Run 100 epochs of mini-batch training and print per-epoch stats.

    Parameters
    ----------
    device : torch.device
        Device the sampled blocks are placed on by the DataLoader.
    g : DGLGraph
        Homogeneous graph to sample neighborhoods from.
    target_idx : Tensor
        Global node IDs of the target-category nodes.
    labels : Tensor
        Class labels indexed by type-specific (category-local) node ID.
    train_mask : Tensor (bool)
        Mask over target-category nodes selecting the training subset.
    model : nn.Module
        Maps a list of sampled blocks to per-seed logits.
    fanouts : list[int]
        Per-layer neighbor-sampling fanouts.
    inv_target : Tensor, optional
        Global-ID -> type-specific-ID mapping. The original implementation
        read a module-level global of the same name; this parameter makes
        that dependency explicit while keeping the old call sites working.
    """
    if inv_target is None:
        # Backward-compatible fallback to the module global defined in the
        # __main__ block (the original code relied on it implicitly).
        inv_target = globals()["inv_target"]
    # Training indices, loss and optimizer.
    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
    # Sampler and loaders over the training seeds.
    sampler = MultiLayerNeighborSampler(fanouts)
    train_loader = DataLoader(
        g,
        target_idx[train_idx].type(g.idtype),
        sampler,
        device=device,
        batch_size=100,
        shuffle=True,
    )
    # No separate validation subset; reuse the training nodes for validation.
    val_loader = DataLoader(
        g,
        target_idx[train_idx].type(g.idtype),
        sampler,
        device=device,
        batch_size=100,
        shuffle=False,
    )
    for epoch in range(100):
        model.train()
        total_loss = 0.0
        # Explicit batch counter: the original used the loop variable `it`
        # after the loop, which raises NameError on an empty loader.
        num_batches = 0
        for _, output_nodes, blocks in train_loader:
            # Map global seed IDs back to category-local IDs for label lookup.
            seeds = inv_target[output_nodes.type(torch.int64)]
            for block in blocks:
                # Per-edge normalization expected by CuGraphRelGraphConv.
                block.edata["norm"] = dgl.norm_by_dst(block).unsqueeze(1)
            logits = model(blocks)
            loss = loss_fcn(logits, labels[seeds])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        acc = evaluate(model, labels, val_loader, inv_target)
        print(
            f"Epoch {epoch:05d} | Loss {total_loss / max(num_batches, 1):.4f} | "
            f"Val. Accuracy {acc:.4f}"
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="RGCN for entity classification with sampling"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="aifb",
        choices=['aifb', 'mutag', 'bgs', 'am'],
    )
    args = parser.parse_args()
    # CuGraphRelGraphConv is CUDA-only, so the example runs on GPU
    # unconditionally (no CPU fallback).
    device = torch.device("cuda")
    print(f"Training with DGL CuGraphRelGraphConv module with sampling.")

    # Load and preprocess dataset.
    if args.dataset == "aifb":
        data = AIFBDataset()
    elif args.dataset == "mutag":
        data = MUTAGDataset()
    elif args.dataset == "bgs":
        data = BGSDataset()
    elif args.dataset == "am":
        data = AMDataset()
    else:
        # Unreachable in practice: argparse `choices` already rejects others.
        raise ValueError(f"Unknown dataset: {args.dataset}")
    hg = data[0].to(device)
    num_rels = len(hg.canonical_etypes)
    category = data.predict_category

    # Pop labels/masks off the heterograph BEFORE homogenization so they do
    # not get carried into the homogeneous graph's node data.
    labels = hg.nodes[category].data.pop("labels")
    train_mask = hg.nodes[category].data.pop("train_mask")
    test_mask = hg.nodes[category].data.pop("test_mask")

    # Find target category and node id.
    category_id = hg.ntypes.index(category)
    g = dgl.to_homogeneous(hg)
    node_ids = torch.arange(g.num_nodes()).to(device)
    # Global IDs of the nodes belonging to the prediction category.
    target_idx = node_ids[g.ndata[dgl.NTYPE] == category_id]
    # Rename the reserved NTYPE/NID fields so samplers do not clash with them.
    g.ndata["ntype"] = g.ndata.pop(dgl.NTYPE)
    g.ndata["type_id"] = g.ndata.pop(dgl.NID)

    # Find the mapping from global node IDs to type-specific node IDs.
    # Entries for non-target nodes are left uninitialized and must never be
    # read; only indices in `target_idx` are written below.
    inv_target = torch.empty((g.num_nodes(),), dtype=torch.int64).to(device)
    inv_target[target_idx] = torch.arange(
        0, target_idx.shape[0], dtype=inv_target.dtype
    ).to(device)

    # Create RGCN model.
    in_size = g.num_nodes()  # featureless with one-hot encoding
    out_size = data.num_classes
    num_bases = 20
    fanouts = [4, 4]
    model = RGCN(in_size, 16, out_size, num_rels, num_bases, fanouts).to(device)

    train(device, g, target_idx, labels, train_mask, model, fanouts)
    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()
    # Note: cugraph-ops aggregators are designed for sampled graphs (MFGs) and
    # expect max_in_degree as input for performance considerations. Hence, we
    # have to update max_in_degree with the fanouts of test_sampler.
    test_sampler = MultiLayerNeighborSampler([500, 500])
    model.update_max_in_degree(test_sampler.fanouts)
    test_loader = DataLoader(
        g,
        target_idx[test_idx].type(g.idtype),
        test_sampler,
        device=device,
        batch_size=32,
        shuffle=False,
    )
    acc = evaluate(model, labels, test_loader, inv_target)
    print(f"Test accuracy {acc:.4f}")