"""Modeling Relational Data with Graph Convolutional Networks
Paper: https://arxiv.org/abs/1703.06103
Reference Code: https://github.com/tkipf/relational-gcn
"""
import argparse
import itertools
import time

import dgl

import numpy as np

import torch as th
import torch.nn.functional as F
from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset

from model import EntityClassify, RelGraphEmbed


def extract_embed(node_embed, input_nodes):
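    """Slice the per-node-type embedding tensors down to the input (source)
    nodes of a mini-batch."""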
    emb = {}
    for ntype, nid in input_nodes.items():
        emb[ntype] = node_embed[ntype][nid]
    return emb


def evaluate(model, loader, node_embed, labels, category, device):
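    """Run the model over a dataloader and return the mean loss and accuracy
    over all seed nodes of the target category."""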
    model.eval()
    total_loss = 0
    total_acc = 0
    count = 0
    with loader.enable_cpu_affinity():
        for input_nodes, seeds, blocks in loader:
            blocks = [blk.to(device) for blk in blocks]
            seeds = seeds[category]
            emb = extract_embed(node_embed, input_nodes)
            emb = {k: e.to(device) for k, e in emb.items()}
            lbl = labels[seeds].to(device)
            logits = model(emb, blocks)[category]
            loss = F.cross_entropy(logits, lbl)
            acc = th.sum(logits.argmax(dim=1) == lbl).item()
            total_loss += loss.item() * len(seeds)
            total_acc += acc
            count += len(seeds)
    return total_loss / count, total_acc / count


def main(args):
    # check cuda
    device = "cpu"
    use_cuda = args.gpu >= 0 and th.cuda.is_available()
    if use_cuda:
        th.cuda.set_device(args.gpu)
        device = "cuda:%d" % args.gpu

    # load graph data
    if args.dataset == "aifb":
        dataset = AIFBDataset()
    elif args.dataset == "mutag":
        dataset = MUTAGDataset()
    elif args.dataset == "bgs":
        dataset = BGSDataset()
    elif args.dataset == "am":
        dataset = AMDataset()
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))

    g = dataset[0]
    category = dataset.predict_category
    num_classes = dataset.num_classes
    train_mask = g.nodes[category].data.pop("train_mask")
    test_mask = g.nodes[category].data.pop("test_mask")
    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze()
    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze()
    labels = g.nodes[category].data.pop("labels")

    # split dataset into train, validate, test
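    # with --validation (the default), the first 20% of the training indices
    # are held out for validation; with --testing, the full training set is
    # reused as the validation set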
    if args.validation:
        val_idx = train_idx[: len(train_idx) // 5]
        train_idx = train_idx[len(train_idx) // 5 :]
    else:
        val_idx = train_idx

    # create embeddings
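    # the RDF datasets provide no input node features, so RelGraphEmbed
    # (defined in model.py) holds a learnable embedding vector per node of
    # every type; calling it returns a {ntype: embedding tensor} dict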
    embed_layer = RelGraphEmbed(g, args.n_hidden)

    if not args.data_cpu:
        labels = labels.to(device)
        embed_layer = embed_layer.to(device)

    if args.num_workers <= 0:
        raise ValueError(
            "The '--num_workers' parameter value is expected "
            "to be >0, but got {}.".format(args.num_workers)
        )

    node_embed = embed_layer()
    # create model
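    # args.n_layers counts all propagation rounds, so n_layers - 2 hidden
    # layers sit between the input and output layers built by EntityClassify
    # (see model.py)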
    model = EntityClassify(
        g,
        args.n_hidden,
        num_classes,
        num_bases=args.n_bases,
        num_hidden_layers=args.n_layers - 2,
        dropout=args.dropout,
        use_self_loop=args.use_self_loop,
    )

    if use_cuda:
        model.cuda()

    # train sampler
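    # sample args.fanout neighbors per edge type at each of the n_layers hops;
    # every mini-batch then yields one block (message flow graph) per layer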
    sampler = dgl.dataloading.MultiLayerNeighborSampler(
        [args.fanout] * args.n_layers
    )
    loader = dgl.dataloading.DataLoader(
        g,
        {category: train_idx},
        sampler,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
    )

    # validation sampler
    # we do not use full neighbor sampling here, to save computation resources
    val_sampler = dgl.dataloading.MultiLayerNeighborSampler(
        [args.fanout] * args.n_layers
    )
    val_loader = dgl.dataloading.DataLoader(
        g,
        {category: val_idx},
        val_sampler,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
    )

    # optimizer
    all_params = itertools.chain(model.parameters(), embed_layer.parameters())
    optimizer = th.optim.Adam(all_params, lr=args.lr, weight_decay=args.l2norm)

    # training loop
    print("start training...")
    mean = 0
    for epoch in range(args.n_epochs):
        model.train()
        optimizer.zero_grad()
        if epoch > 3:
            t0 = time.time()

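        # pin dataloader worker threads to CPU cores for steadier throughput;
        # this requires worker processes, hence the num_workers > 0 check above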
        with loader.enable_cpu_affinity():
            for i, (input_nodes, seeds, blocks) in enumerate(loader):
                blocks = [blk.to(device) for blk in blocks]
                seeds = seeds[
                    category
                ]  # we only predict nodes of the target category type
                batch_tic = time.time()
                emb = extract_embed(node_embed, input_nodes)
                lbl = labels[seeds]
                if use_cuda:
                    emb = {k: e.cuda() for k, e in emb.items()}
                    lbl = lbl.cuda()
                logits = model(emb, blocks)[category]
                loss = F.cross_entropy(logits, lbl)
                loss.backward()
                optimizer.step()

                train_acc = th.sum(logits.argmax(dim=1) == lbl).item() / len(
                    seeds
                )
                print(
                    f"Epoch {epoch:05d} | Batch {i:03d} | Train Acc: "
                    f"{train_acc:.4f} | Train Loss: {loss.item():.4f} | Time: "
                    f"{time.time() - batch_tic:.4f}"
                )

        if epoch > 3:
            mean = (mean * (epoch - 3) + (time.time() - t0)) / (epoch - 2)

            val_loss, val_acc = evaluate(
                model, val_loader, node_embed, labels, category, device
            )
            print(
                f"Epoch {epoch:05d} | Valid Acc: {val_acc:.4f} | Valid loss: "
                f"{val_loss:.4f} | Time: {mean:.4f}"
            )
    print()
    if args.model_path is not None:
        th.save(model.state_dict(), args.model_path)

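    # offline inference over the full graph (see model.inference in model.py)
    # computes predictions for every node; the held-out test split is then
    # scored below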
    output = model.inference(
        g,
        args.batch_size,
        "cuda" if use_cuda else "cpu",
        args.num_workers,
        node_embed,
    )
    test_pred = output[category][test_idx]
    test_labels = labels[test_idx].to(test_pred.device)
    test_acc = (test_pred.argmax(1) == test_labels).float().mean()
    print("Test Acc: {:.4f}".format(test_acc))
    print()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="RGCN")
    parser.add_argument(
        "--dropout", type=float, default=0, help="dropout probability"
    )
    parser.add_argument(
        "--n-hidden", type=int, default=16, help="number of hidden units"
    )
    parser.add_argument("--gpu", type=int, default=-1, help="gpu")
    parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
    parser.add_argument(
        "--n-bases",
        type=int,
        default=-1,
        help="number of filter weight matrices, default: -1 [use all]",
    )
    parser.add_argument(
        "--n-layers", type=int, default=2, help="number of propagation rounds"
    )
    parser.add_argument(
        "-e",
        "--n-epochs",
        type=int,
        default=20,
        help="number of training epochs",
    )
    parser.add_argument(
        "-d", "--dataset", type=str, required=True, help="dataset to use"
    )
    parser.add_argument(
        "--model_path", type=str, default=None, help="path for save the model"
    )
    parser.add_argument("--l2norm", type=float, default=0, help="l2 norm coef")
    parser.add_argument(
        "--use-self-loop",
        default=False,
        action="store_true",
        help="include self feature as a special relation",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=100,
        help="Mini-batch size. If -1, use full graph training.",
    )
    parser.add_argument(
        "--fanout", type=int, default=4, help="Fan-out of neighbor sampling."
    )
    parser.add_argument(
        "--data-cpu",
        action="store_true",
        help="By default the script puts all node features and labels "
        "on GPU when using it to save time for data copy. This may "
        "be undesired if they cannot fit in GPU memory at once. "
        "This flag disables that.",
    )
    parser.add_argument(
        "--num_workers", type=int, default=4, help="Number of node dataloader"
    )

    fp = parser.add_mutually_exclusive_group(required=False)
    fp.add_argument("--validation", dest="validation", action="store_true")
    fp.add_argument("--testing", dest="validation", action="store_false")
    parser.set_defaults(validation=True)

    args = parser.parse_args()
    print(args)
    main(args)