"""
Data utils for processing bAbI datasets
"""

import os
import string

import torch
from torch.utils.data import DataLoader

import dgl
from dgl.data.utils import (
    _get_dgl_url,
    download,
    extract_archive,
    get_download_dir,
)


def get_babi_dataloaders(batch_size, train_size=50, task_id=4, q_type=0):
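    """Download the bAbI data and build dataloaders for one task.

    task_id selects the bAbI task: 4, 15 and 16 are node-selection tasks,
    18 is graph classification, and 19 is path finding. q_type picks which
    question (edge) type to keep; it is unused for task 19. train_size is
    the number of training examples (the dev set is a fixed 50 examples).

    Returns (train_dataloader, dev_dataloader, test_dataloaders), where
    test_dataloaders is a list of 10 loaders of 1,000 examples each.
    """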
    _download_babi_data()

    node_dict = dict(
        zip(list(string.ascii_uppercase), range(len(string.ascii_uppercase)))
    )

    if task_id == 4:
        edge_dict = {"n": 0, "s": 1, "w": 2, "e": 3}
        reverse_edge = {}
        return _ns_dataloader(
            train_size,
            q_type,
            batch_size,
            node_dict,
            edge_dict,
            reverse_edge,
            "04",
        )
    elif task_id == 15:
        edge_dict = {"is": 0, "has_fear": 1}
        reverse_edge = {}
        return _ns_dataloader(
            train_size,
            q_type,
            batch_size,
            node_dict,
            edge_dict,
            reverse_edge,
            "15",
        )
    elif task_id == 16:
        edge_dict = {"is": 0, "has_color": 1}
        reverse_edge = {0: 0}
        return _ns_dataloader(
            train_size,
            q_type,
            batch_size,
            node_dict,
            edge_dict,
            reverse_edge,
            "16",
        )
    elif task_id == 18:
        edge_dict = {">": 0, "<": 1}
        label_dict = {"false": 0, "true": 1}
        reverse_edge = {0: 1, 1: 0}
        return _gc_dataloader(
            train_size,
            q_type,
            batch_size,
            node_dict,
            edge_dict,
            label_dict,
            reverse_edge,
            "18",
        )
    elif task_id == 19:
        edge_dict = {"n": 0, "s": 1, "w": 2, "e": 3, "<end>": 4}
        reverse_edge = {0: 1, 1: 0, 2: 3, 3: 2}
        max_seq_length = 2
        return _path_finding_dataloader(
            train_size,
            batch_size,
            node_dict,
            edge_dict,
            reverse_edge,
            "19",
            max_seq_length,
        )
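    else:
        # guard: unsupported task ids would otherwise fall through and
        # silently return None
        raise ValueError("Unsupported task_id: {}".format(task_id))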


def _ns_dataloader(
    train_size, q_type, batch_size, node_dict, edge_dict, reverse_edge, path
):
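    """Build dataloaders for the node-selection tasks (bAbI 4, 15 and 16)."""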
    def _collate_fn(batch):
        graphs = []
        labels = []
        for d in batch:
            edges = d["edges"]

            node_ids = []
            for s, e, t in edges:
                if s not in node_ids:
                    node_ids.append(s)
                if t not in node_ids:
                    node_ids.append(t)
            g = dgl.DGLGraph()
            g.add_nodes(len(node_ids))
            g.ndata["node_id"] = torch.tensor(node_ids, dtype=torch.long)

            nid2idx = dict(zip(node_ids, list(range(len(node_ids)))))

            # convert label to node index
            label = d["eval"][2]
            label_idx = nid2idx[label]
            labels.append(label_idx)

            edge_types = []
            for s, e, t in edges:
                g.add_edge(nid2idx[s], nid2idx[t])
                edge_types.append(e)
                if e in reverse_edge:
                    g.add_edge(nid2idx[t], nid2idx[s])
                    edge_types.append(reverse_edge[e])
            g.edata["type"] = torch.tensor(edge_types, dtype=torch.long)

            # annotate the question's source node with a one-hot feature
            annotation = torch.zeros(len(node_ids), dtype=torch.long)
            annotation[nid2idx[d["eval"][0]]] = 1
            g.ndata["annotation"] = annotation.unsqueeze(-1)
            graphs.append(g)
        batch_graph = dgl.batch(graphs)
        labels = torch.tensor(labels, dtype=torch.long)
        return batch_graph, labels

    def _get_dataloader(data, shuffle):
        return DataLoader(
            dataset=data,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=_collate_fn,
        )

    train_set, dev_set, test_sets = _convert_ns_dataset(
        train_size, node_dict, edge_dict, path, q_type
    )
    train_dataloader = _get_dataloader(train_set, True)
    dev_dataloader = _get_dataloader(dev_set, False)
    test_dataloaders = []
    for d in test_sets:
        dl = _get_dataloader(d, False)
        test_dataloaders.append(dl)

    return train_dataloader, dev_dataloader, test_dataloaders


def _convert_ns_dataset(train_size, node_dict, edge_dict, path, q_type):
    total_num = 11000

    def convert(file):
        dataset = []
        d = dict()
        with open(file, "r") as f:
            for i, line in enumerate(f.readlines()):
                line = line.strip().split()
                if line[0] == "1" and len(d) > 0:
                    d = dict()
                if line[1] == "eval":
                    # (src, edge, label)
                    d["eval"] = (
                        node_dict[line[2]],
                        edge_dict[line[3]],
                        node_dict[line[4]],
                    )
                    if d["eval"][1] == q_type:
                        dataset.append(d)
                        if len(dataset) >= total_num:
                            break
                else:
                    if "edges" not in d:
                        d["edges"] = []
                    d["edges"].append(
                        (
                            node_dict[line[1]],
                            edge_dict[line[2]],
                            node_dict[line[3]],
                        )
                    )
        return dataset

    download_dir = get_download_dir()
    filename = os.path.join(download_dir, "babi_data", path, "data.txt")
    data = convert(filename)

    assert len(data) == total_num

    # split: first train_size examples for training, examples 950-1000 for
    # dev, then ten held-out test sets of 1,000 examples each
    train_set = data[:train_size]
    dev_set = data[950:1000]
    test_sets = []
    for i in range(10):
        test = data[1000 * (i + 1) : 1000 * (i + 2)]
        test_sets.append(test)

    return train_set, dev_set, test_sets


def _gc_dataloader(
    train_size,
    q_type,
    batch_size,
    node_dict,
    edge_dict,
    label_dict,
    reverse_edge,
    path,
):
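    """Build dataloaders for the graph-classification task (bAbI 18)."""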
    def _collate_fn(batch):
        graphs = []
        labels = []
        for d in batch:
            edges = d["edges"]

            node_ids = []
            for s, e, t in edges:
                if s not in node_ids:
                    node_ids.append(s)
                if t not in node_ids:
                    node_ids.append(t)
            g = dgl.DGLGraph()
            g.add_nodes(len(node_ids))
            g.ndata["node_id"] = torch.tensor(node_ids, dtype=torch.long)

            nid2idx = dict(zip(node_ids, list(range(len(node_ids)))))

            # the true/false class label is the last element of the eval tuple
            labels.append(d["eval"][-1])

            edge_types = []
            for s, e, t in edges:
                g.add_edge(nid2idx[s], nid2idx[t])
                edge_types.append(e)
                if e in reverse_edge:
                    g.add_edge(nid2idx[t], nid2idx[s])
                    edge_types.append(reverse_edge[e])
            g.edata["type"] = torch.tensor(edge_types, dtype=torch.long)

            # two annotation columns mark the two nodes named in the question
            annotation = torch.zeros([len(node_ids), 2], dtype=torch.long)
            annotation[nid2idx[d["eval"][0]]][0] = 1
            annotation[nid2idx[d["eval"][2]]][1] = 1
            g.ndata["annotation"] = annotation
            graphs.append(g)
        batch_graph = dgl.batch(graphs)
        labels = torch.tensor(labels, dtype=torch.long)
        return batch_graph, labels

    def _get_dataloader(data, shuffle):
        return DataLoader(
            dataset=data,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=_collate_fn,
        )

    train_set, dev_set, test_sets = _convert_gc_dataset(
        train_size, node_dict, edge_dict, label_dict, path, q_type
    )
    train_dataloader = _get_dataloader(train_set, True)
    dev_dataloader = _get_dataloader(dev_set, False)
    test_dataloaders = []
    for d in test_sets:
        dl = _get_dataloader(d, False)
        test_dataloaders.append(dl)

    return train_dataloader, dev_dataloader, test_dataloaders


def _convert_gc_dataset(
    train_size, node_dict, edge_dict, label_dict, path, q_type
):
    total_num = 11000

    def convert(file):
        dataset = []
        d = dict()
        with open(file, "r") as f:
            for i, line in enumerate(f.readlines()):
                line = line.strip().split()
                if line[0] == "1" and len(d) > 0:
                    d = dict()
                if line[1] == "eval":
                    # (src, edge, dst, label)
                    if "eval" not in d:
                        d["eval"] = (
                            node_dict[line[2]],
                            edge_dict[line[3]],
                            node_dict[line[4]],
                            label_dict[line[5]],
                        )
                        if d["eval"][1] == q_type:
                            dataset.append(d)
                            if len(dataset) >= total_num:
                                break
                else:
                    if "edges" not in d:
                        d["edges"] = []
                    d["edges"].append(
                        (
                            node_dict[line[1]],
                            edge_dict[line[2]],
                            node_dict[line[3]],
                        )
                    )
        return dataset

    download_dir = get_download_dir()
    filename = os.path.join(download_dir, "babi_data", path, "data.txt")
    data = convert(filename)

    assert len(data) == total_num

    train_set = data[:train_size]
    dev_set = data[950:1000]
    test_sets = []
    for i in range(10):
        test = data[1000 * (i + 1) : 1000 * (i + 2)]
        test_sets.append(test)

    return train_set, dev_set, test_sets


def _path_finding_dataloader(
    train_size,
    batch_size,
    node_dict,
    edge_dict,
    reverse_edge,
    path,
    max_seq_length,
):
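    """Build dataloaders for the path-finding task (bAbI 19)."""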
    def _collate_fn(batch):
        graphs = []
        ground_truths = []
        seq_lengths = []
        for d in batch:
            edges = d["edges"]

            node_ids = []
            for s, e, t in edges:
                if s not in node_ids:
                    node_ids.append(s)
                if t not in node_ids:
                    node_ids.append(t)
            g = dgl.DGLGraph()
            g.add_nodes(len(node_ids))
            g.ndata["node_id"] = torch.tensor(node_ids, dtype=torch.long)

            nid2idx = dict(zip(node_ids, list(range(len(node_ids)))))

            # pad the target edge-type sequence to max_seq_length with <end>
            truth = d["seq_out"] + [edge_dict["<end>"]] * (
                max_seq_length - len(d["seq_out"])
            )
            seq_len = len(d["seq_out"])
            ground_truths.append(truth)
            seq_lengths.append(seq_len)

            edge_types = []
            for s, e, t in edges:
                g.add_edge(nid2idx[s], nid2idx[t])
                edge_types.append(e)
                if e in reverse_edge:
                    g.add_edge(nid2idx[t], nid2idx[s])
                    edge_types.append(reverse_edge[e])
            g.edata["type"] = torch.tensor(edge_types, dtype=torch.long)

            # two annotation columns mark the start and end nodes of the path
            annotation = torch.zeros([len(node_ids), 2], dtype=torch.long)
            annotation[nid2idx[d["eval"][0]]][0] = 1
            annotation[nid2idx[d["eval"][1]]][1] = 1
            g.ndata["annotation"] = annotation
            graphs.append(g)
        batch_graph = dgl.batch(graphs)
        ground_truths = torch.tensor(ground_truths, dtype=torch.long)
        seq_lengths = torch.tensor(seq_lengths, dtype=torch.long)
        return batch_graph, ground_truths, seq_lengths

    def _get_dataloader(data, shuffle):
        return DataLoader(
            dataset=data,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=_collate_fn,
        )

    train_set, dev_set, test_sets = _convert_path_finding(
        train_size, node_dict, edge_dict, path
    )
    train_dataloader = _get_dataloader(train_set, True)
    dev_dataloader = _get_dataloader(dev_set, False)
    test_dataloaders = []
    for d in test_sets:
        dl = _get_dataloader(d, False)
        test_dataloaders.append(dl)

    return train_dataloader, dev_dataloader, test_dataloaders


def _convert_path_finding(train_size, node_dict, edge_dict, path):
    total_num = 11000

    def convert(file):
        dataset = []
        d = dict()
        with open(file, "r") as f:
            for line in f.readlines():
                line = line.strip().split()
                if line[0] == "1" and len(d) > 0:
                    d = dict()
                if line[1] == "eval":
                    # (src, dst) query pair; the answer is an edge-type sequence
                    d["eval"] = (node_dict[line[3]], node_dict[line[4]])
                    d["seq_out"] = []
                    seq_out = line[5].split(",")
                    for e in seq_out:
                        d["seq_out"].append(edge_dict[e])
                    dataset.append(d)
                    if len(dataset) >= total_num:
                        break
                else:
                    if "edges" not in d:
                        d["edges"] = []
                    d["edges"].append(
                        (
                            node_dict[line[1]],
                            edge_dict[line[2]],
                            node_dict[line[3]],
                        )
                    )
        return dataset

    download_dir = get_download_dir()
    filename = os.path.join(download_dir, "babi_data", path, "data.txt")
    data = convert(filename)

    assert len(data) == total_num

    train_set = data[:train_size]
    dev_set = data[950:1000]
    test_sets = []
    for i in range(10):
        test = data[1000 * (i + 1) : 1000 * (i + 2)]
        test_sets.append(test)

    return train_set, dev_set, test_sets


def _download_babi_data():
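    """Download the bAbI data archive via DGL's URL helper and extract it
    once into the download directory."""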
    download_dir = get_download_dir()
    zip_file_path = os.path.join(download_dir, "babi_data.zip")

    data_url = _get_dgl_url("models/ggnn_babi_data.zip")
    download(data_url, path=zip_file_path)

    extract_dir = os.path.join(download_dir, "babi_data")
    if not os.path.exists(extract_dir):
        extract_archive(zip_file_path, extract_dir)
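

# Minimal usage sketch (illustration only, not part of the module's public
# surface): build dataloaders for bAbI task 18 and peek at one training
# batch. The node-selection and graph-classification collate functions
# return (batched_graph, labels); task 19 additionally returns sequence
# lengths.
if __name__ == "__main__":
    train_loader, dev_loader, test_loaders = get_babi_dataloaders(
        batch_size=10, train_size=50, task_id=18, q_type=0
    )
    for batch_graph, labels in train_loader:
        print(batch_graph.batch_size, labels.shape)
        break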