entity_sample.py 5.83 KB
Newer Older
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
1
2
3
import argparse

import dgl
4
5
import torch
import torch.nn as nn
Mufei Li's avatar
Mufei Li committed
6
import torch.nn.functional as F
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
7
8
from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset
from dgl.dataloading import DataLoader, MultiLayerNeighborSampler
9
from dgl.nn.pytorch import RelGraphConv
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
10
11
from torchmetrics.functional import accuracy

Mufei Li's avatar
Mufei Li committed
12

13
14
15
16
17
class RGCN(nn.Module):
    """Two-layer relational graph convolution network over node embeddings.

    Nodes carry no input features; each node ID is looked up in a learnable
    embedding table, and the result is passed through two ``RelGraphConv``
    layers with a ReLU in between.

    Args:
        num_nodes: Number of rows in the embedding table.
        h_dim: Embedding width and hidden width of the first layer.
        out_dim: Output width of the second layer.
        num_rels: Number of relation types; also used as the number of bases
            for the basis regularizer of both layers.
    """

    def __init__(self, num_nodes, h_dim, out_dim, num_rels):
        super().__init__()
        self.emb = nn.Embedding(num_nodes, h_dim)
        # Both layers use the same relation settings; only the output width
        # differs between them.
        rel_options = dict(
            regularizer="basis",
            num_bases=num_rels,
            self_loop=False,
        )
        self.conv1 = RelGraphConv(h_dim, h_dim, num_rels, **rel_options)
        self.conv2 = RelGraphConv(h_dim, out_dim, num_rels, **rel_options)

    def forward(self, g):
        """Run both layers over a pair of sampled blocks.

        Args:
            g: Indexable whose items 0 and 1 are the graphs for the first and
                second layer. Each must provide ``edata[dgl.ETYPE]`` and
                ``edata["norm"]``; item 0 must also provide
                ``srcdata[dgl.NID]``, which indexes the embedding table.

        Returns:
            The output of the second convolution (no activation applied).
        """
        first_block, second_block = g[0], g[1]
        embedded = self.emb(first_block.srcdata[dgl.NID])
        hidden = self.conv1(
            first_block,
            embedded,
            first_block.edata[dgl.ETYPE],
            first_block.edata["norm"],
        )
        hidden = F.relu(hidden)
        return self.conv2(
            second_block,
            hidden,
            second_block.edata[dgl.ETYPE],
            second_block.edata["norm"],
        )
42

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
43

44
def evaluate(model, labels, num_classes, dataloader, inv_target):
    """Compute multiclass accuracy of ``model`` over every batch of a loader.

    Args:
        model: Callable taking the list of sampled blocks and returning
            per-seed logits.
        labels: Label tensor indexed by the remapped seed IDs.
        num_classes: Number of classes, forwarded to ``accuracy``.
        dataloader: Iterable yielding ``(input_nodes, output_nodes, blocks)``.
        inv_target: Lookup tensor that remaps the loader's ``output_nodes``
            to the index space of ``labels``.

    Returns:
        Accuracy as a Python float.
    """
    model.eval()
    logit_chunks = []
    seed_chunks = []
    with torch.no_grad():
        for _, seeds, blocks in dataloader:
            # The model reads an edge "norm" field; attach it per block.
            for blk in blocks:
                blk.edata["norm"] = dgl.norm_by_dst(blk).unsqueeze(1)
            batch_logits = model(blocks)
            logit_chunks.append(batch_logits.detach().cpu())
            seed_chunks.append(inv_target[seeds].detach().cpu())
    predictions = torch.cat(logit_chunks).argmax(dim=1)
    targets = labels[torch.cat(seed_chunks)].cpu()
    score = accuracy(
        predictions,
        targets,
        task="multiclass",
        num_classes=num_classes,
    )
    return score.item()
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
64

65

66
def train(
    device,
    g,
    target_idx,
    labels,
    train_mask,
    num_classes,
    model,
    inv_target=None,
):
    """Train ``model`` for 50 epochs with neighbour sampling.

    After every epoch the accuracy on the training seeds is printed as a
    stand-in for validation accuracy (there is no separate validation split).

    Args:
        device: Device on which sampled blocks are placed.
        g: Homogeneous graph to sample from.
        target_idx: Global node IDs of the nodes belonging to the category
            being classified.
        labels: Labels indexed by type-specific node ID.
        train_mask: Mask over the target nodes selecting the training set
            (assumed 1-D; TODO confirm for custom datasets).
        num_classes: Number of classes, used for the accuracy metric.
        model: Model to optimise; called with the list of sampled blocks.
        inv_target: Optional lookup tensor mapping global node IDs to
            type-specific node IDs. When omitted it is built from ``g`` and
            ``target_idx``, so the function does not depend on a
            module-level variable of the same name.
    """
    if inv_target is None:
        # Only the entries at ``target_idx`` are ever read (seeds are always
        # target nodes), so the remaining entries may stay uninitialised.
        inv_target = torch.empty(
            (g.num_nodes(),), dtype=torch.int64, device=device
        )
        inv_target[target_idx.to(device)] = torch.arange(
            target_idx.shape[0], dtype=inv_target.dtype, device=device
        )

    # squeeze(1) rather than squeeze(): with exactly one training node a bare
    # squeeze() would collapse the index to a 0-d tensor.
    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze(1)
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)

    # construct sampler and dataloaders
    sampler = MultiLayerNeighborSampler([4, 4])
    train_seeds = target_idx[train_idx]
    train_loader = DataLoader(
        g,
        train_seeds,
        sampler,
        device=device,
        batch_size=100,
        shuffle=True,
    )
    # no separate validation subset, use train index instead for validation
    val_loader = DataLoader(
        g,
        train_seeds,
        sampler,
        device=device,
        batch_size=100,
        shuffle=False,
    )

    for epoch in range(50):
        model.train()
        total_loss = 0
        num_batches = 0
        for _, output_nodes, blocks in train_loader:
            # Map global seed IDs to the index space of ``labels``.
            output_nodes = inv_target[output_nodes]
            for block in blocks:
                block.edata["norm"] = dgl.norm_by_dst(block).unsqueeze(1)
            logits = model(blocks)
            loss = loss_fcn(logits, labels[output_nodes])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        acc = evaluate(model, labels, num_classes, val_loader, inv_target)
        print(
            "Epoch {:05d} | Loss {:.4f} | Val. Accuracy {:.4f} ".format(
                epoch, total_loss / max(num_batches, 1), acc
            )
        )

110

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
111
112
113
114
115
116
117
118
119
120
if __name__ == "__main__":
    # Entry point: load an RDF dataset, train the sampled RGCN and report
    # test accuracy.
    parser = argparse.ArgumentParser(
        description="RGCN for entity classification with sampling"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="aifb",
        help="Dataset name ('aifb', 'mutag', 'bgs', 'am').",
    )
    args = parser.parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Training with DGL built-in RGCN module with sampling.")

    # load and preprocess dataset
    dataset_classes = {
        "aifb": AIFBDataset,
        "mutag": MUTAGDataset,
        "bgs": BGSDataset,
        "am": AMDataset,
    }
    if args.dataset not in dataset_classes:
        raise ValueError("Unknown dataset: {}".format(args.dataset))
    data = dataset_classes[args.dataset]()
    g = data[0]
    num_rels = len(g.canonical_etypes)
    category = data.predict_category
    labels = g.nodes[category].data.pop("labels").to(device)
    train_mask = g.nodes[category].data.pop("train_mask")
    test_mask = g.nodes[category].data.pop("test_mask")

    # find target category and node id
    category_id = g.ntypes.index(category)
    g = dgl.to_homogeneous(g)
    node_ids = torch.arange(g.num_nodes())
    target_idx = node_ids[g.ndata[dgl.NTYPE] == category_id]
    # rename the fields as they can be changed by DataLoader
    g.ndata["ntype"] = g.ndata.pop(dgl.NTYPE)
    g.ndata["type_id"] = g.ndata.pop(dgl.NID)

    # find the mapping (inv_target) from global node IDs to type-specific
    # node IDs; entries for non-target nodes are left uninitialised
    inv_target = torch.empty((g.num_nodes(),), dtype=torch.int64).to(device)
    inv_target[target_idx] = torch.arange(
        0, target_idx.shape[0], dtype=inv_target.dtype
    ).to(device)

    # create RGCN model
    in_size = g.num_nodes()  # featureless with one-hot encoding
    num_classes = data.num_classes
    model = RGCN(in_size, 16, num_classes, num_rels).to(device)

    train(device, g, target_idx, labels, train_mask, num_classes, model)

    # squeeze(1) keeps the index 1-D even when there is exactly one test
    # node; a bare squeeze() would collapse it to a 0-d tensor.
    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze(1)
    test_sampler = MultiLayerNeighborSampler(
        [-1, -1]
    )  # -1 for sampling all neighbors
    test_loader = DataLoader(
        g,
        target_idx[test_idx],
        test_sampler,
        device=device,
        batch_size=32,
        shuffle=False,
    )
    acc = evaluate(model, labels, num_classes, test_loader, inv_target)
    print("Test accuracy {:.4f}".format(acc))