"docs/vscode:/vscode.git/clone" did not exist on "defa292bc07cb61ad77d3589dec746c289c41426"
entity_sample.py 5.68 KB
Newer Older
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
1
2
3
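"""Entity classification with a two-layer RGCN and neighbor sampling.

Trains DGL's built-in RelGraphConv on one of the RDF node-classification
datasets (AIFB, MUTAG, BGS, AM) using mini-batch neighbor sampling.

Example invocation (assuming the script is saved as entity_sample.py):
    python entity_sample.py --dataset aifb
"""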
import argparse

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset
from dgl.dataloading import DataLoader, MultiLayerNeighborSampler
from dgl.nn.pytorch import RelGraphConv
from torchmetrics.functional import accuracy


class RGCN(nn.Module):
    def __init__(self, num_nodes, h_dim, out_dim, num_rels):
        super().__init__()
        self.emb = nn.Embedding(num_nodes, h_dim)
        # two-layer RGCN
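        # Both layers use basis-decomposition regularization over the relation
        # weights and disable the built-in self-loop term.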
        self.conv1 = RelGraphConv(
            h_dim,
            h_dim,
            num_rels,
            regularizer="basis",
            num_bases=num_rels,
            self_loop=False,
        )
        self.conv2 = RelGraphConv(
            h_dim,
            out_dim,
            num_rels,
            regularizer="basis",
            num_bases=num_rels,
            self_loop=False,
        )

    def forward(self, g):
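        # ``g`` is the list of sampled blocks (message flow graphs); the source
        # node IDs of the first block index the learnable node embeddings, so
        # the model is featureless apart from this embedding table.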
        x = self.emb(g[0].srcdata[dgl.NID])
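        # each block carries its edge types (dgl.ETYPE) and the per-edge
        # normalization ("norm") attached in the data-loading loop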
        h = F.relu(
            self.conv1(g[0], x, g[0].edata[dgl.ETYPE], g[0].edata["norm"])
        )
        h = self.conv2(g[1], h, g[1].edata[dgl.ETYPE], g[1].edata["norm"])
        return h


def evaluate(model, labels, dataloader, inv_target):
    model.eval()
    eval_logits = []
    eval_seeds = []
    with torch.no_grad():
        for input_nodes, output_nodes, blocks in dataloader:
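            # map seed nodes from homogeneous-graph IDs to type-specific IDs
            # so they can be used to index ``labels``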
            output_nodes = inv_target[output_nodes]
            for block in blocks:
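                # attach per-edge normalization based on destination-node degrees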
                block.edata["norm"] = dgl.norm_by_dst(block).unsqueeze(1)
            logits = model(blocks)
            eval_logits.append(logits.cpu().detach())
            eval_seeds.append(output_nodes.cpu().detach())
    eval_logits = torch.cat(eval_logits)
    eval_seeds = torch.cat(eval_seeds)
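    # note: this is the pre-0.11 torchmetrics API; newer torchmetrics versions
    # require task="multiclass" and num_classes to be passed explicitly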
    return accuracy(eval_logits.argmax(dim=1), labels[eval_seeds].cpu()).item()


def train(device, g, target_idx, labels, train_mask, model, inv_target):
    # define train idx, loss function and optimizer
    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
    # construct sampler and dataloader
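    # sample 4 neighbors per node for each of the two RGCN layers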
    sampler = MultiLayerNeighborSampler([4, 4])
    train_loader = DataLoader(
        g,
        target_idx[train_idx],
        sampler,
        device=device,
        batch_size=100,
        shuffle=True,
    )
    # no separate validation subset; reuse the training indices for validation
    val_loader = DataLoader(
        g,
        target_idx[train_idx],
        sampler,
        device=device,
        batch_size=100,
        shuffle=False,
    )
    for epoch in range(50):
        model.train()
        total_loss = 0
        for it, (input_nodes, output_nodes, blocks) in enumerate(train_loader):
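            # map seed nodes back to type-specific IDs before looking up labels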
            output_nodes = inv_target[output_nodes]
            for block in blocks:
                block.edata["norm"] = dgl.norm_by_dst(block).unsqueeze(1)
            logits = model(blocks)
            loss = loss_fcn(logits, labels[output_nodes])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        acc = evaluate(model, labels, val_loader, inv_target)
        print(
            "Epoch {:05d} | Loss {:.4f} | Val. Accuracy {:.4f} ".format(
                epoch, total_loss / (it + 1), acc
            )
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="RGCN for entity classification with sampling"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="aifb",
        help="Dataset name ('aifb', 'mutag', 'bgs', 'am').",
    )
    args = parser.parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training with DGL built-in RGCN module with sampling.")

    # load and preprocess dataset
    if args.dataset == "aifb":
        data = AIFBDataset()
    elif args.dataset == "mutag":
        data = MUTAGDataset()
    elif args.dataset == "bgs":
        data = BGSDataset()
    elif args.dataset == "am":
        data = AMDataset()
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))
    g = data[0]
    num_rels = len(g.canonical_etypes)
    category = data.predict_category
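    # the RDF datasets store labels and split masks on the prediction-target
    # node type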
    labels = g.nodes[category].data.pop("labels").to(device)
    train_mask = g.nodes[category].data.pop("train_mask")
    test_mask = g.nodes[category].data.pop("test_mask")
    # find target category and node id
    category_id = g.ntypes.index(category)
    g = dgl.to_homogeneous(g)
    node_ids = torch.arange(g.num_nodes())
    target_idx = node_ids[g.ndata[dgl.NTYPE] == category_id]
    # rename the fields as they can be changed by DataLoader
    g.ndata["ntype"] = g.ndata.pop(dgl.NTYPE)
    g.ndata["type_id"] = g.ndata.pop(dgl.NID)
    # find the mapping (inv_target) from global node IDs to type-specific node IDs
    inv_target = torch.empty((g.num_nodes(),), dtype=torch.int64).to(device)
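    # only target-category nodes get a valid entry; other positions stay
    # uninitialized (torch.empty) and must never be indexed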
    inv_target[target_idx] = torch.arange(
        0, target_idx.shape[0], dtype=inv_target.dtype
    ).to(device)

    # create RGCN model
    in_size = g.num_nodes()  # featureless with one-hot encoding
    out_size = data.num_classes
    model = RGCN(in_size, 16, out_size, num_rels).to(device)

    train(device, g, target_idx, labels, train_mask, model, inv_target)
    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()
    test_sampler = MultiLayerNeighborSampler(
        [-1, -1]
    )  # -1 for sampling all neighbors
    test_loader = DataLoader(
        g,
        target_idx[test_idx],
        test_sampler,
        device=device,
        batch_size=32,
        shuffle=False,
    )
    acc = evaluate(model, labels, test_loader, inv_target)
    print("Test accuracy {:.4f}".format(acc))