entity_sample.py 5.62 KB
Newer Older
Mufei Li's avatar
Mufei Li committed
1
2
"""
Differences compared to tkipf/relation-gcn
* weight decay applied to all weights
* remove nodes that won't be touched
"""
import argparse
import torch as th
import torch.nn.functional as F
import dgl

from dgl.dataloading import MultiLayerNeighborSampler, NodeDataLoader
from torchmetrics.functional import accuracy
from tqdm import tqdm

from entity_utils import load_data
16
from model import RGCN
Mufei Li's avatar
Mufei Li committed
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67

def init_dataloaders(args, g, train_idx, test_idx, target_idx, device, use_ddp=False):
    """Build the train, validation and test mini-batch loaders.

    Parameters
    ----------
    args
        Parsed options; ``args.fanout`` (comma-separated integers, one per
        sampled layer) and ``args.batch_size`` are read here.
    g
        Graph handed to every ``NodeDataLoader``.
    train_idx, test_idx
        Indices into ``target_idx`` selecting the train / test nodes.
    target_idx
        Indexable holding the node ids the loaders iterate over.
    device
        Forwarded unchanged as the ``device`` argument of every loader.
    use_ddp : bool, optional
        Forwarded unchanged as the ``use_ddp`` argument of every loader.

    Returns
    -------
    tuple
        ``(train_loader, val_loader, test_loader)``.
    """
    layer_fanouts = list(map(int, args.fanout.split(',')))
    neighbor_sampler = MultiLayerNeighborSampler(layer_fanouts)

    # Settings every loader has in common.
    shared = dict(use_ddp=use_ddp, device=device, drop_last=False)

    train_loader = NodeDataLoader(
        g, target_idx[train_idx], neighbor_sampler,
        batch_size=args.batch_size, shuffle=True, **shared)

    # The datasets do not have a validation subset, so "validation" walks
    # the train subset again, this time without shuffling.
    val_loader = NodeDataLoader(
        g, target_idx[train_idx], neighbor_sampler,
        batch_size=args.batch_size, shuffle=False, **shared)

    # A fanout of -1 samples all neighbors; use it for every layer at test
    # time. The test batch size is fixed at 32 regardless of args.batch_size.
    full_sampler = MultiLayerNeighborSampler([-1 for _ in layer_fanouts])
    test_loader = NodeDataLoader(
        g, target_idx[test_idx], full_sampler,
        batch_size=32, shuffle=False, **shared)

    return train_loader, val_loader, test_loader

def process_batch(inv_target, batch):
    """Unpack a loader batch and prepare its blocks for the model.

    Parameters
    ----------
    inv_target
        Indexable that maps the seed node ids of a batch back to their
        type-specific ids (the ids used to look up target node labels).
    batch
        A 3-item sequence ``(_, seeds, blocks)``; the first item is ignored.

    Returns
    -------
    tuple
        ``(seeds, blocks)``: the mapped seed ids and the same blocks, each
        now carrying an ``edata['norm']`` entry.
    """
    _unused, seed_nodes, mfgs = batch

    # Translate the seeds first so a bad index fails before any block is
    # modified.
    target_ids = inv_target[seed_nodes]

    for mfg in mfgs:
        # Store the result of dgl.norm_by_dst with an extra axis inserted at
        # position 1.
        edge_norm = dgl.norm_by_dst(mfg)
        mfg.edata['norm'] = edge_norm.unsqueeze(1)

    return target_ids, mfgs

68
69
def train(model, train_loader, inv_target,
          labels, optimizer):
    """Run one pass over ``train_loader``, stepping the optimizer per batch.

    Parameters
    ----------
    model
        Module called on the list of blocks of each batch; returns logits.
    train_loader
        Iterable of batches in the form expected by ``process_batch``.
    inv_target
        Mapping from batch seed ids to the ids used to index ``labels``.
    labels
        Ground-truth classes, indexed with the mapped seed ids.
    optimizer
        Optimizer whose gradients are zeroed and which is stepped once per
        batch.

    Returns
    -------
    tuple
        ``(train_acc, loss)`` as Python numbers, taken from the LAST batch
        only; they are not averaged over the epoch.

    Notes
    -----
    ``train_loader`` must yield at least one batch, otherwise there are no
    metrics to return and the final ``return`` fails.
    """
    model.train()

    for sample_data in train_loader:
        seeds, blocks = process_batch(inv_target, sample_data)
        # Look the labels up once; they feed both the loss and the accuracy.
        batch_labels = labels[seeds]

        # Call the module itself rather than ``.forward`` so that hooks
        # registered on the module are executed.
        logits = model(blocks)
        loss = F.cross_entropy(logits, batch_labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_acc = accuracy(logits.argmax(dim=1), batch_labels).item()

    return train_acc, loss.item()

85
def evaluate(model, eval_loader, inv_target):
    """Collect model outputs for every batch of ``eval_loader``.

    The model is switched to eval mode and run with gradients disabled.

    Parameters
    ----------
    model
        Module called on the list of blocks of each batch; returns logits.
    eval_loader
        Iterable of batches in the form expected by ``process_batch``; it is
        wrapped in ``tqdm`` to show progress.
    inv_target
        Mapping from batch seed ids to the ids used to index the labels.

    Returns
    -------
    tuple
        ``(eval_logits, eval_seeds)``: the per-batch logits and mapped seed
        ids, moved to CPU and concatenated in loader order.
    """
    model.eval()
    eval_logits = []
    eval_seeds = []

    with th.no_grad():
        for sample_data in tqdm(eval_loader):
            seeds, blocks = process_batch(inv_target, sample_data)
            # Call the module itself rather than ``.forward`` so that hooks
            # registered on the module are executed.
            logits = model(blocks)
            # Keep the results on CPU so they do not pile up on the device.
            eval_logits.append(logits.cpu().detach())
            eval_seeds.append(seeds.cpu().detach())

    eval_logits = th.cat(eval_logits)
    eval_seeds = th.cat(eval_seeds)

    return eval_logits, eval_seeds

def main(args):
    """Train an RGCN with neighbor sampling and report accuracies.

    Loads the dataset named by ``args.dataset``, trains for
    ``args.n_epochs`` epochs while printing the train metrics and the
    validation accuracy after each epoch, then prints the final test
    accuracy.

    Parameters
    ----------
    args
        Parsed command-line options (see the argument parser at the bottom
        of this file).
    """
    g, num_rels, num_classes, labels, train_idx, test_idx, target_idx, inv_target = load_data(
        args.dataset, inv_target=True)

    # Fall back to CPU when a negative id is given or CUDA is unavailable.
    if args.gpu >= 0 and th.cuda.is_available():
        device = th.device(args.gpu)
    else:
        device = th.device('cpu')

    # Hand the loaders the resolved device, not the raw ``args.gpu`` integer:
    # the raw value is invalid for ``--gpu -1`` and for a non-negative id on
    # a machine without CUDA.
    train_loader, val_loader, test_loader = init_dataloaders(
        args, g, train_idx, test_idx, target_idx, device)

    model = RGCN(g.num_nodes(),
                 args.n_hidden,
                 num_classes,
                 num_rels,
                 num_bases=args.n_bases,
                 dropout=args.dropout,
                 self_loop=args.use_self_loop,
                 ns_mode=True)
    labels = labels.to(device)
    model = model.to(device)
    # The seed ids produced by the loaders are used to index ``inv_target``,
    # so keep it on the same device as the batches.
    inv_target = inv_target.to(device)

    optimizer = th.optim.Adam(model.parameters(), lr=1e-2, weight_decay=args.wd)

    for epoch in range(args.n_epochs):
        train_acc, loss = train(model, train_loader, inv_target, labels, optimizer)
        print("Epoch {:05d}/{:05d} | Train Accuracy: {:.4f} | Train Loss: {:.4f}".format(
            epoch, args.n_epochs, train_acc, loss))

        # ``evaluate`` returns CPU tensors, so bring the selected labels to
        # CPU as well before comparing.
        val_logits, val_seeds = evaluate(model, val_loader, inv_target)
        val_acc = accuracy(val_logits.argmax(dim=1), labels[val_seeds].cpu()).item()
        print("Validation Accuracy: {:.4f}".format(val_acc))

    test_logits, test_seeds = evaluate(model, test_loader, inv_target)
    test_acc = accuracy(test_logits.argmax(dim=1), labels[test_seeds].cpu()).item()
    print("Final Test Accuracy: {:.4f}".format(test_acc))

if __name__ == '__main__':
    # Script entry point: parse the command-line options, echo them, and run
    # the training / evaluation pipeline.
    cli = argparse.ArgumentParser(
        description='RGCN for entity classification with sampling')

    # Model options.
    cli.add_argument(
        "--dropout", type=float, default=0,
        help="dropout probability")
    cli.add_argument(
        "--n-hidden", type=int, default=16,
        help="number of hidden units")
    cli.add_argument(
        "--gpu", type=int, default=0,
        help="gpu")
    cli.add_argument(
        "--n-bases", type=int, default=-1,
        help="number of filter weight matrices, default: -1 [use all]")

    # Training options.
    cli.add_argument(
        "--n-epochs", type=int, default=50,
        help="number of training epochs")
    cli.add_argument(
        "-d", "--dataset", type=str, required=True,
        choices=['aifb', 'mutag', 'bgs', 'am'],
        help="dataset to use")
    cli.add_argument(
        "--wd", type=float, default=5e-4,
        help="weight decay")

    # Sampling options.
    cli.add_argument(
        "--fanout", type=str, default="4, 4",
        help="Fan-out of neighbor sampling")
    cli.add_argument(
        "--use-self-loop", default=False, action='store_true',
        help="include self feature as a special relation")
    cli.add_argument(
        "--batch-size", type=int, default=100,
        help="Mini-batch size")

    args = cli.parse_args()

    print(args)
    main(args)