"""Modeling Relational Data with Graph Convolutional Networks
Paper: https://arxiv.org/abs/1703.06103
Reference Code: https://github.com/tkipf/relational-gcn
"""
import argparse
import itertools
import time

import dgl

import numpy as np
import torch as th
import torch.nn.functional as F
from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset
from model import EntityClassify, RelGraphEmbed


def extract_embed(node_embed, input_nodes):
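    """Slice out, for each node type, the embedding rows of the sampled input nodes."""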
    emb = {}
    for ntype, nid in input_nodes.items():
        emb[ntype] = node_embed[ntype][nid]
    return emb


@th.no_grad()
def evaluate(model, loader, node_embed, labels, category, device):
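    """Compute the average loss and accuracy over all mini-batches in loader."""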
    model.eval()
    total_loss = 0
    total_acc = 0
    count = 0
    for input_nodes, seeds, blocks in loader:
        blocks = [blk.to(device) for blk in blocks]
        seeds = seeds[category]
        emb = extract_embed(node_embed, input_nodes)
        emb = {k: e.to(device) for k, e in emb.items()}
        lbl = labels[seeds].to(device)
        logits = model(emb, blocks)[category]
        loss = F.cross_entropy(logits, lbl)
        acc = th.sum(logits.argmax(dim=1) == lbl).item()
        total_loss += loss.item() * len(seeds)
        total_acc += acc
        count += len(seeds)
    return total_loss / count, total_acc / count


def main(args):
    # check cuda
    device = "cpu"
    use_cuda = args.gpu >= 0 and th.cuda.is_available()
    if use_cuda:
        th.cuda.set_device(args.gpu)
        device = "cuda:%d" % args.gpu

    # load graph data
    if args.dataset == "aifb":
        dataset = AIFBDataset()
    elif args.dataset == "mutag":
        dataset = MUTAGDataset()
    elif args.dataset == "bgs":
        dataset = BGSDataset()
    elif args.dataset == "am":
        dataset = AMDataset()
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))

    g = dataset[0]
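    # only nodes of the dataset's predict_category carry labels and train/test masks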
    category = dataset.predict_category
    num_classes = dataset.num_classes
    train_mask = g.nodes[category].data.pop("train_mask")
    test_mask = g.nodes[category].data.pop("test_mask")
    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze()
    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze()
    labels = g.nodes[category].data.pop("labels")

    # split dataset into train, validate, test
    if args.validation:
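        # hold out the first 20% of the training nodes as a validation set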
        val_idx = train_idx[: len(train_idx) // 5]
        train_idx = train_idx[len(train_idx) // 5 :]
    else:
        val_idx = train_idx

    # create embeddings
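    # RelGraphEmbed (defined in model.py) holds a learnable n_hidden-dimensional
    # embedding for every node, since the RDF graphs have no input node features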
    embed_layer = RelGraphEmbed(g, args.n_hidden)

    if not args.data_cpu:
        labels = labels.to(device)
        embed_layer = embed_layer.to(device)

    node_embed = embed_layer()
    # create model
    model = EntityClassify(
        g,
        args.n_hidden,
        num_classes,
        num_bases=args.n_bases,
        num_hidden_layers=args.n_layers - 2,
        dropout=args.dropout,
        use_self_loop=args.use_self_loop,
    )

    if use_cuda:
        model.cuda()

    # train sampler
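    # sample up to args.fanout neighbors per edge type for each of the args.n_layers layers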
    sampler = dgl.dataloading.MultiLayerNeighborSampler(
        [args.fanout] * args.n_layers
    )
    loader = dgl.dataloading.DataLoader(
        g,
        {category: train_idx},
        sampler,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=0,
    )
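    # each iteration yields (input_nodes, output_nodes, blocks) for one mini-batch of seeds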

    # validation sampler
    # we do not use full neighbor sampling, to save computation resources
    val_sampler = dgl.dataloading.MultiLayerNeighborSampler(
        [args.fanout] * args.n_layers
    )
    val_loader = dgl.dataloading.DataLoader(
        g,
        {category: val_idx},
        val_sampler,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=0,
    )

    # optimizer
    all_params = itertools.chain(model.parameters(), embed_layer.parameters())
    optimizer = th.optim.Adam(all_params, lr=args.lr, weight_decay=args.l2norm)

    # training loop
    print("start training...")
    dur = []
    for epoch in range(args.n_epochs):
        model.train()
        optimizer.zero_grad()
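        # measure epoch time only after a short warm-up period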
        if epoch > 3:
            t0 = time.time()

        for i, (input_nodes, seeds, blocks) in enumerate(loader):
            blocks = [blk.to(device) for blk in blocks]
            # we only make predictions for nodes of the target category
            seeds = seeds[category]
            batch_tic = time.time()
            emb = extract_embed(node_embed, input_nodes)
            lbl = labels[seeds]
            if use_cuda:
                emb = {k: e.cuda() for k, e in emb.items()}
                lbl = lbl.cuda()
            logits = model(emb, blocks)[category]
            loss = F.cross_entropy(logits, lbl)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_acc = th.sum(logits.argmax(dim=1) == lbl).item() / len(seeds)
            print(
                "Epoch {:05d} | Batch {:03d} | Train Acc: {:.4f} | Train Loss: {:.4f} | Time: {:.4f}".format(
                    epoch, i, train_acc, loss.item(), time.time() - batch_tic
                )
            )

        if epoch > 3:
            dur.append(time.time() - t0)

        val_loss, val_acc = evaluate(
            model, val_loader, node_embed, labels, category, device
        )
        print(
            "Epoch {:05d} | Valid Acc: {:.4f} | Valid loss: {:.4f} | Time: {:.4f}".format(
                epoch, val_acc, val_loss, np.average(dur) if dur else 0
            )
        )
    print()
    if args.model_path is not None:
        th.save(model.state_dict(), args.model_path)

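    # full-graph inference (model.inference is defined in model.py) to obtain
    # predictions for every node of the target category, evaluated on the test set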
    output = model.inference(
        g, args.batch_size, "cuda" if use_cuda else "cpu", 0, node_embed
    )
    test_pred = output[category][test_idx]
    test_labels = labels[test_idx].to(test_pred.device)
    test_acc = (test_pred.argmax(1) == test_labels).float().mean()
    print("Test Acc: {:.4f}".format(test_acc))
    print()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="RGCN")
    parser.add_argument(
        "--dropout", type=float, default=0, help="dropout probability"
    )
    parser.add_argument(
        "--n-hidden", type=int, default=16, help="number of hidden units"
    )
    parser.add_argument("--gpu", type=int, default=-1, help="gpu")
    parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
    parser.add_argument(
        "--n-bases",
        type=int,
        default=-1,
        help="number of filter weight matrices, default: -1 [use all]",
    )
    parser.add_argument(
        "--n-layers", type=int, default=2, help="number of propagation rounds"
    )
    parser.add_argument(
        "-e",
        "--n-epochs",
        type=int,
        default=20,
        help="number of training epochs",
    )
    parser.add_argument(
        "-d", "--dataset", type=str, required=True, help="dataset to use"
    )
    parser.add_argument(
        "--model_path", type=str, default=None, help="path for save the model"
    )
    parser.add_argument("--l2norm", type=float, default=0, help="l2 norm coef")
    parser.add_argument(
        "--use-self-loop",
        default=False,
        action="store_true",
        help="include self feature as a special relation",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=100,
        help="Mini-batch size. If -1, use full graph training.",
    )
    parser.add_argument(
        "--fanout", type=int, default=4, help="Fan-out of neighbor sampling."
    )
    parser.add_argument(
        "--data-cpu",
        action="store_true",
        help="By default the script puts all node features and labels "
        "on GPU when using it to save time for data copy. This may "
        "be undesired if they cannot fit in GPU memory at once. "
        "This flag disables that.",
    )
    fp = parser.add_mutually_exclusive_group(required=False)
    fp.add_argument("--validation", dest="validation", action="store_true")
    fp.add_argument("--testing", dest="validation", action="store_false")
    parser.set_defaults(validation=True)

    args = parser.parse_args()
    print(args)
    main(args)