"""
Differences compared to tkipf/relation-gcn
* weight decay applied to all weights
* remove nodes that won't be touched
"""
import argparse
import torch as th
import torch.nn.functional as F
import dgl

from dgl.dataloading import MultiLayerNeighborSampler, DataLoader
from torchmetrics.functional import accuracy
from tqdm import tqdm

from entity_utils import load_data
from model import RGCN

def init_dataloaders(args, g, train_idx, test_idx, target_idx, device, use_ddp=False):
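    """Create neighbor-sampling dataloaders for the train/val/test splits.

    ``args.fanout`` is a comma-separated list with one fanout per GNN layer,
    e.g. the default "4, 4" samples up to 4 neighbors per node at each of two
    layers. ``target_idx`` holds the node ids of the target category in the
    homogenized graph, so ``target_idx[train_idx]`` / ``target_idx[test_idx]``
    give the seed nodes to sample around.
    """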
    fanouts = [int(fanout) for fanout in args.fanout.split(',')]
    sampler = MultiLayerNeighborSampler(fanouts)

    train_loader = DataLoader(
        g,
        target_idx[train_idx],
        sampler,
        use_ddp=use_ddp,
        device=device,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False)

    # The datasets do not have a validation subset, use the train subset
    val_loader = DataLoader(
        g,
        target_idx[train_idx],
        sampler,
        use_ddp=use_ddp,
        device=device,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False)

    # -1 for sampling all neighbors
    test_sampler = MultiLayerNeighborSampler([-1] * len(fanouts))
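    # Sampling every neighbor at test time removes the randomness of
    # neighbor sampling from the final evaluation.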
    test_loader = DataLoader(
        g,
        target_idx[test_idx],
        test_sampler,
        use_ddp=use_ddp,
        device=device,
        batch_size=32,
        shuffle=False,
        drop_last=False)

    return train_loader, val_loader, test_loader

def process_batch(inv_target, batch):
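    """Prepare one sampled mini-batch: recover the type-specific ids of the
    seed nodes and attach per-edge normalization coefficients to each block.
    """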
    _, seeds, blocks = batch
    # map the seed nodes back to their type-specific ids,
    # in order to get the target node labels
    seeds = inv_target[seeds]

    for blc in blocks:
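        # dgl.norm_by_dst assigns each edge the coefficient
        # 1 / in-degree of its destination node; the RGCN layers are expected
        # to read it from edata['norm'] when normalizing aggregated messages.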
        blc.edata['norm'] = dgl.norm_by_dst(blc).unsqueeze(1)

    return seeds, blocks

def train(model, train_loader, inv_target,
          labels, optimizer):
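    """Train the model for one epoch over the sampled mini-batches."""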
    model.train()

    for sample_data in train_loader:
        seeds, blocks = process_batch(inv_target, sample_data)
        logits = model(blocks)
        loss = F.cross_entropy(logits, labels[seeds])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_acc = accuracy(logits.argmax(dim=1), labels[seeds]).item()

    # Note that only the last mini-batch's accuracy and loss are returned.
    return train_acc, loss.item()

def evaluate(model, eval_loader, inv_target):
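    """Run inference over eval_loader, returning logits and seed ids on CPU."""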
    model.eval()
    eval_logits = []
    eval_seeds = []

    with th.no_grad():
        for sample_data in tqdm(eval_loader):
            seeds, blocks = process_batch(inv_target, sample_data)
            logits = model(blocks)
            eval_logits.append(logits.cpu().detach())
            eval_seeds.append(seeds.cpu().detach())

    eval_logits = th.cat(eval_logits)
    eval_seeds = th.cat(eval_seeds)

    return eval_logits, eval_seeds

def main(args):
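    # entity_utils.load_data is assumed to return a homogenized graph in which
    # target_idx maps type-specific ids of the target node category to graph
    # node ids, with inv_target as the inverse mapping (hence inv_target=True).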
    g, num_rels, num_classes, labels, train_idx, test_idx, target_idx, inv_target = load_data(
        args.dataset, inv_target=True)

    if args.gpu >= 0 and th.cuda.is_available():
        device = th.device(args.gpu)
    else:
        device = th.device('cpu')

    train_loader, val_loader, test_loader = init_dataloaders(
        args, g, train_idx, test_idx, target_idx, device)

    model = RGCN(g.num_nodes(),
                 args.n_hidden,
                 num_classes,
                 num_rels,
                 num_bases=args.n_bases,
                 dropout=args.dropout,
                 self_loop=args.use_self_loop,
                 ns_mode=True)
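    # ns_mode=True presumably switches the RGCN defined in model.py to operate
    # on sampled blocks (MFGs) rather than on the full graph.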
    labels = labels.to(device)
    model = model.to(device)

    # Weight decay is applied to all model parameters (see the module docstring).
    optimizer = th.optim.Adam(model.parameters(), lr=1e-2, weight_decay=args.wd)

    for epoch in range(args.n_epochs):
        train_acc, loss = train(model, train_loader, inv_target, labels, optimizer)
        print("Epoch {:05d}/{:05d} | Train Accuracy: {:.4f} | Train Loss: {:.4f}".format(
            epoch, args.n_epochs, train_acc, loss))

        val_logits, val_seeds = evaluate(model, val_loader, inv_target)
        val_acc = accuracy(val_logits.argmax(dim=1), labels[val_seeds].cpu()).item()
        print("Validation Accuracy: {:.4f}".format(val_acc))

    test_logits, test_seeds = evaluate(model, test_loader, inv_target)
    test_acc = accuracy(test_logits.argmax(dim=1), labels[test_seeds].cpu()).item()
    print("Final Test Accuracy: {:.4f}".format(test_acc))

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='RGCN for entity classification with sampling')
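    # Example invocation:
    #   python entity_sample.py -d aifb --gpu 0 --fanout "4,4" --batch-size 100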
    parser.add_argument("--dropout", type=float, default=0,
                        help="dropout probability")
    parser.add_argument("--n-hidden", type=int, default=16,
                        help="number of hidden units")
    parser.add_argument("--gpu", type=int, default=0,
                        help="gpu")
    parser.add_argument("--n-bases", type=int, default=-1,
                        help="number of filter weight matrices, default: -1 [use all]")
    parser.add_argument("--n-epochs", type=int, default=50,
                        help="number of training epochs")
    parser.add_argument("-d", "--dataset", type=str, required=True,
                        choices=['aifb', 'mutag', 'bgs', 'am'],
                        help="dataset to use")
    parser.add_argument("--wd", type=float, default=5e-4,
                        help="weight decay")
    parser.add_argument("--fanout", type=str, default="4, 4",
                        help="Fan-out of neighbor sampling")
    parser.add_argument("--use-self-loop", default=False, action='store_true',
                        help="include self feature as a special relation")
    parser.add_argument("--batch-size", type=int, default=100,
                        help="Mini-batch size")
    args = parser.parse_args()

    print(args)
    main(args)