# entity_classify_mb.py
"""Modeling Relational Data with Graph Convolutional Networks
Paper: https://arxiv.org/abs/1703.06103
Reference Code: https://github.com/tkipf/relational-gcn
"""
import argparse
import itertools
import time

import numpy as np
import torch as th
import torch.nn.functional as F

import dgl
from dgl.data.rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset

from model import EntityClassify, RelGraphEmbed


def extract_embed(node_embed, input_nodes):
    """Gather the embedding rows needed by a mini-batch.

    Args:
        node_embed: mapping from node type to an indexable embedding table.
        input_nodes: mapping from node type to the node ids required for
            that type (as yielded by the DGL data loader).

    Returns:
        dict mapping every node type in ``input_nodes`` to
        ``node_embed[ntype][nid]``.
    """
    return {ntype: node_embed[ntype][nid] for ntype, nid in input_nodes.items()}


def evaluate(model, loader, node_embed, labels, category, device):
    """Compute the mean cross-entropy loss and accuracy over ``loader``.

    The model is switched to eval mode (and left in it); callers that keep
    training must call ``model.train()`` again.

    Args:
        model: classifier called as ``model(emb, blocks)``; its output is a
            mapping indexed by node type.
        loader: iterable of ``(input_nodes, seeds, blocks)`` triples.
        node_embed: per-node-type embedding tables, indexed as
            ``node_embed[ntype][nid]``.
        labels: labels of the ``category`` nodes, indexed by seed ids.
        category: node type whose predictions are scored.
        device: device the blocks, embeddings and labels are moved to.

    Returns:
        ``(loss, acc)``, both averaged per seed node over the whole loader.

    Raises:
        ValueError: if the loader yields no seed nodes.
    """
    model.eval()
    total_loss = 0
    total_acc = 0
    count = 0
    # No gradients are needed here; without no_grad every batch would build
    # an autograd graph through the model and the embedding parameters.
    with th.no_grad():
        for input_nodes, seeds, blocks in loader:
            blocks = [blk.to(device) for blk in blocks]
            seeds = seeds[category]
            emb = extract_embed(node_embed, input_nodes)
            emb = {k: e.to(device) for k, e in emb.items()}
            lbl = labels[seeds].to(device)
            logits = model(emb, blocks)[category]
            loss = F.cross_entropy(logits, lbl)
            acc = th.sum(logits.argmax(dim=1) == lbl).item()
            # Weight by batch size so the final value is a per-seed mean.
            total_loss += loss.item() * len(seeds)
            total_acc += acc
            count += len(seeds)
    if count == 0:
        raise ValueError("evaluation loader produced no seed nodes")
    return total_loss / count, total_acc / count



def _load_dataset(name):
    """Instantiate the RDF dataset selected by ``name``.

    Raises:
        ValueError: if ``name`` is not 'aifb', 'mutag', 'bgs' or 'am'.
    """
    dataset_classes = {
        'aifb': AIFBDataset,
        'mutag': MUTAGDataset,
        'bgs': BGSDataset,
        'am': AMDataset,
    }
    if name not in dataset_classes:
        raise ValueError("Unknown dataset {!r}; expected one of: {}".format(
            name, ', '.join(dataset_classes)))
    return dataset_classes[name]()


def main(args):
    """Train R-GCN with neighbor sampling on an RDF dataset, then test it.

    ``args`` is the namespace built by the script's argument parser. The
    fields read here are gpu, dataset, validation, n_hidden, data_cpu,
    n_bases, n_layers, dropout, use_self_loop, fanout, batch_size, lr,
    l2norm, n_epochs and model_path.
    """
    # check cuda
    device = 'cpu'
    use_cuda = args.gpu >= 0 and th.cuda.is_available()
    if use_cuda:
        th.cuda.set_device(args.gpu)
        device = 'cuda:%d' % args.gpu

    # load graph data
    dataset = _load_dataset(args.dataset)
    g = dataset[0]
    category = dataset.predict_category
    num_classes = dataset.num_classes

    # pop() removes the masks and labels from the graph's node data.
    train_mask = g.nodes[category].data.pop('train_mask')
    test_mask = g.nodes[category].data.pop('test_mask')
    # squeeze(1) rather than squeeze(): a bare squeeze() collapses a
    # single-element index to a 0-d tensor, breaking len() and slicing below.
    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze(1)
    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze(1)
    labels = g.nodes[category].data.pop('labels')

    # split dataset into train, validate, test
    if args.validation:
        # hold out the first fifth of the training nodes for validation
        n_val = len(train_idx) // 5
        val_idx = train_idx[:n_val]
        train_idx = train_idx[n_val:]
    else:
        # no held-out split: "validation" metrics are computed on train nodes
        val_idx = train_idx

    # create embeddings
    embed_layer = RelGraphEmbed(g, args.n_hidden)
    if not args.data_cpu:
        labels = labels.to(device)
        embed_layer = embed_layer.to(device)
    # Computed once and reused by every batch; this presumably relies on
    # RelGraphEmbed returning its parameters directly so gradients reach
    # them — TODO confirm against model.py.
    node_embed = embed_layer()

    # create model
    model = EntityClassify(g,
                           args.n_hidden,
                           num_classes,
                           num_bases=args.n_bases,
                           num_hidden_layers=args.n_layers - 2,
                           dropout=args.dropout,
                           use_self_loop=args.use_self_loop)
    if use_cuda:
        model.cuda()

    # train sampler
    sampler = dgl.dataloading.MultiLayerNeighborSampler([args.fanout] * args.n_layers)
    loader = dgl.dataloading.DataLoader(
        g, {category: train_idx}, sampler,
        batch_size=args.batch_size, shuffle=True, num_workers=0)

    # validation sampler
    # we do not use full neighbor to save computation resources
    val_sampler = dgl.dataloading.MultiLayerNeighborSampler([args.fanout] * args.n_layers)
    # evaluation does not need shuffled batches
    val_loader = dgl.dataloading.DataLoader(
        g, {category: val_idx}, val_sampler,
        batch_size=args.batch_size, shuffle=False, num_workers=0)

    # optimizer
    all_params = itertools.chain(model.parameters(), embed_layer.parameters())
    optimizer = th.optim.Adam(all_params, lr=args.lr, weight_decay=args.l2norm)

    # training loop
    print("start training...")
    dur = []
    for epoch in range(args.n_epochs):
        model.train()
        # epochs 0-3 are not included in the timing statistics
        if epoch > 3:
            t0 = time.time()

        for i, (input_nodes, seeds, blocks) in enumerate(loader):
            blocks = [blk.to(device) for blk in blocks]
            seeds = seeds[category]     # we only predict the nodes with type "category"
            batch_tic = time.time()
            emb = extract_embed(node_embed, input_nodes)
            lbl = labels[seeds]
            if use_cuda:
                emb = {k: e.cuda() for k, e in emb.items()}
                lbl = lbl.cuda()
            # optimizer.step() runs once per batch, so gradients must be
            # cleared once per batch too; clearing only once per epoch would
            # accumulate the gradients of every earlier batch in the epoch.
            optimizer.zero_grad()
            logits = model(emb, blocks)[category]
            loss = F.cross_entropy(logits, lbl)
            loss.backward()
            optimizer.step()

            train_acc = th.sum(logits.argmax(dim=1) == lbl).item() / len(seeds)
            print("Epoch {:05d} | Batch {:03d} | Train Acc: {:.4f} | Train Loss: {:.4f} | Time: {:.4f}".
                  format(epoch, i, train_acc, loss.item(), time.time() - batch_tic))

        if epoch > 3:
            dur.append(time.time() - t0)

        val_loss, val_acc = evaluate(model, val_loader, node_embed, labels, category, device)
        # np.average of an empty list warns and returns nan; report nan
        # explicitly until timing starts.
        mean_dur = np.mean(dur) if dur else float('nan')
        print("Epoch {:05d} | Valid Acc: {:.4f} | Valid loss: {:.4f} | Time: {:.4f}".
              format(epoch, val_acc, val_loss, mean_dur))
    print()
    if args.model_path is not None:
        # Only the classifier's state_dict is saved; embed_layer is not.
        th.save(model.state_dict(), args.model_path)

    # Test-time inference in eval mode with autograd disabled.
    model.eval()
    with th.no_grad():
        output = model.inference(
            g, args.batch_size, 'cuda' if use_cuda else 'cpu', 0, node_embed)
    test_pred = output[category][test_idx]
    test_labels = labels[test_idx].to(test_pred.device)
    test_acc = (test_pred.argmax(1) == test_labels).float().mean().item()
    print("Test Acc: {:.4f}".format(test_acc))
    print()

def build_parser():
    """Return the command-line argument parser for this script."""
    parser = argparse.ArgumentParser(description='RGCN')
    parser.add_argument("--dropout", type=float, default=0,
            help="dropout probability")
    parser.add_argument("--n-hidden", type=int, default=16,
            help="number of hidden units")
    parser.add_argument("--gpu", type=int, default=-1,
            help="gpu")
    parser.add_argument("--lr", type=float, default=1e-2,
            help="learning rate")
    parser.add_argument("--n-bases", type=int, default=-1,
            help="number of filter weight matrices, default: -1 [use all]")
    parser.add_argument("--n-layers", type=int, default=2,
            help="number of propagation rounds")
    parser.add_argument("-e", "--n-epochs", type=int, default=20,
            help="number of training epochs")
    parser.add_argument("-d", "--dataset", type=str, required=True,
            help="dataset to use")
    # "--model-path" is accepted as an alias, matching the other hyphenated
    # flags; "--model_path" keeps working for existing invocations.
    parser.add_argument("--model_path", "--model-path", dest="model_path",
            type=str, default=None,
            help='path to save the model')
    parser.add_argument("--l2norm", type=float, default=0,
            help="l2 norm coef")
    parser.add_argument("--use-self-loop", default=False, action='store_true',
            help="include self feature as a special relation")
    # The value is passed straight to the DataLoader as batch_size; there is
    # no full-graph mode, so the help no longer advertises one.
    parser.add_argument("--batch-size", type=int, default=100,
            help="Mini-batch size (number of seed nodes per batch).")
    parser.add_argument("--fanout", type=int, default=4,
            help="Fan-out of neighbor sampling.")
    parser.add_argument('--data-cpu', action='store_true',
            help="By default the script puts all node features and labels "
                 "on GPU when using it to save time for data copy. This may "
                 "be undesired if they cannot fit in GPU memory at once. "
                 "This flag disables that.")
    fp = parser.add_mutually_exclusive_group(required=False)
    fp.add_argument('--validation', dest='validation', action='store_true',
            help="hold out the first fifth of the training nodes for "
                 "validation (default)")
    fp.add_argument('--testing', dest='validation', action='store_false',
            help="validate on the full training set (no held-out split)")
    parser.set_defaults(validation=True)
    return parser


if __name__ == '__main__':
    args = build_parser().parse_args()
    print(args)
    main(args)