import json
import pickle
import random

import dgl

import numpy as np
import torch

NODE_TYPE = {"entity": 0, "root": 1, "relation": 2}


def write_txt(batch, seqs, w_file, args):
    # convert the model's predicted token ids back into text, restoring copied entities
    ret = []
    for b, seq in enumerate(seqs):
        txt = []
        for token in seq:
            # copy the entity
            if token >= len(args.text_vocab):
                ent_text = batch["raw_ent_text"][b][
                    token - len(args.text_vocab)
                ]
                ent_text = filter(lambda x: x != "<PAD>", ent_text)
                txt.extend(ent_text)
            else:
                if int(token) not in [
                    args.text_vocab(x) for x in ["<PAD>", "<BOS>", "<EOS>"]
                ]:
                    txt.append(args.text_vocab(int(token)))
            if int(token) == args.text_vocab("<EOS>"):
                break
        w_file.write(" ".join([str(x) for x in txt]) + "\n")
        ret.append([" ".join([str(x) for x in txt])])
    return ret
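
# Illustrative note (not part of the original code): write_txt assumes the decoder
# vocabulary is extended with copy indices. For example, if len(args.text_vocab) were
# 5000, a predicted token id of 5003 would be read as a copy of
# batch["raw_ent_text"][b][3], i.e. the fourth entity of that sample, while ids below
# 5000 are looked up in the text vocabulary as usual.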


def replace_ent(x, ent, V):
    # replace copy indices (ids >= V) with the corresponding entity representations
    mask = x >= V
    if mask.sum() == 0:
        return x
    nz = mask.nonzero()
    fill_ent = ent[nz, x[mask] - V]
    x = x.masked_scatter(mask, fill_ent)
    return x
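
# A minimal illustration of replace_ent (not from the original code), assuming V is
# the size of the regular vocabulary: any id >= V is a copy index, and masked_scatter
# overwrites exactly those positions with values gathered from ent, e.g.
#
#   x = [3, V + 1, 5]  ->  the middle position is filled from ent, while the ordinary
#                          vocabulary ids 3 and 5 are left untouched.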


def len2mask(lens, device):
    max_len = max(lens)
    mask = (
        torch.arange(max_len, device=device)
        .unsqueeze(0)
        .expand(len(lens), max_len)
    )
    mask = mask >= torch.LongTensor(lens).to(mask).unsqueeze(1)
    return mask
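
# A small worked example for len2mask (illustrative only): with lens = [1, 3] the
# function builds a (2, 3) position grid and marks padding positions, giving
#
#   [[False,  True,  True],
#    [False, False, False]]
#
# i.e. True marks positions at or beyond each sequence's length.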


def pad(var_len_list, out_type="list", flatten=False):
    if flatten:
        lens = [len(x) for x in var_len_list]
        var_len_list = sum(var_len_list, [])
    max_len = max([len(x) for x in var_len_list])
    if out_type == "list":
        if flatten:
            return [
                x + ["<PAD>"] * (max_len - len(x)) for x in var_len_list
            ], lens
        else:
            return [x + ["<PAD>"] * (max_len - len(x)) for x in var_len_list]
    if out_type == "tensor":
        if flatten:
            return (
                torch.stack(
                    [
                        torch.cat(
                            [
                                x,
                                torch.zeros(
                                    [max_len - len(x)] + list(x.shape[1:])
                                ).type_as(x),
                            ],
                            0,
                        )
                        for x in var_len_list
                    ],
                    0,
                ),
                lens,
            )
        else:
            return torch.stack(
                [
                    torch.cat(
                        [
                            x,
                            torch.zeros(
                                [max_len - len(x)] + list(x.shape[1:])
                            ).type_as(x),
                        ],
                        0,
                    )
                    for x in var_len_list
                ],
                0,
            )
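
# A short usage sketch for pad (illustrative, list mode only):
#
#   pad([["a"], ["b", "c"]])                  -> [["a", "<PAD>"], ["b", "c"]]
#   pad([[["a"], ["b", "c"]]], flatten=True)  -> ([["a", "<PAD>"], ["b", "c"]], [2])
#
# In tensor mode the same idea applies, except sequences are padded with zeros and
# stacked into a single tensor.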


class Vocab(object):
    def __init__(
        self,
        max_vocab=2**31,
        min_freq=-1,
        sp=["<PAD>", "<BOS>", "<EOS>", "<UNK>"],
    ):
        self.i2s = []
        self.s2i = {}
        self.wf = {}
        self.max_vocab, self.min_freq, self.sp = max_vocab, min_freq, sp

    def __len__(self):
        return len(self.i2s)

    def __str__(self):
        return "Total " + str(len(self.i2s)) + str(self.i2s[:10])

    def update(self, token):
        if isinstance(token, list):
            for t in token:
                self.update(t)
        else:
            self.wf[token] = self.wf.get(token, 0) + 1

    def build(self):
        self.i2s.extend(self.sp)
        sort_kv = sorted(self.wf.items(), key=lambda x: x[1], reverse=True)
        for k, v in sort_kv:
            if (
                len(self.i2s) < self.max_vocab
                and v >= self.min_freq
                and k not in self.sp
            ):
                self.i2s.append(k)
        self.s2i.update(list(zip(self.i2s, range(len(self.i2s)))))

    def __call__(self, x):
        if isinstance(x, int):
            return self.i2s[x]
        else:
            return self.s2i.get(x, self.s2i["<UNK>"])

    def save(self, fname):
        pass

    def load(self, fname):
        pass
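
# A brief usage sketch for Vocab (illustrative only): counts are accumulated with
# update(), the index is frozen with build(), and __call__ maps in both directions.
#
#   v = Vocab(min_freq=1)
#   v.update(["graph", "writer", "graph"])
#   v.build()
#   v("graph")   # -> integer id of "graph"
#   v(4)         # -> the token stored at index 4 ("graph", after the 4 special tokens)
#   v("unseen")  # -> id of "<UNK>"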


def at_least(x):
    # guard against illegal/empty data by falling back to <UNK>
    if len(x) == 0:
        return ["<UNK>"]
    else:
        return x


class Example(object):
    def __init__(self, title, ent_text, ent_type, rel, text):
        # one object corresponds to a data sample
        self.raw_title = title.split()
        self.raw_ent_text = [at_least(x.split()) for x in ent_text]
        assert min([len(x) for x in self.raw_ent_text]) > 0, str(
            self.raw_ent_text
        )
        self.raw_ent_type = ent_type.split()  # <method> .. <>
        self.raw_rel = []
        for r in rel:
            rel_list = r.split()
            for i in range(len(rel_list)):
                if (
                    i > 0
                    and i < len(rel_list) - 1
                    and rel_list[i - 1] == "--"
                    and rel_list[i] != rel_list[i].lower()
                    and rel_list[i + 1] == "--"
                ):
                    self.raw_rel.append(
                        [
                            rel_list[: i - 1],
                            rel_list[i - 1] + rel_list[i] + rel_list[i + 1],
                            rel_list[i + 2 :],
                        ]
                    )
                    break
        self.raw_text = text.split()
        self.graph = self.build_graph()

    def __str__(self):
        return "\n".join(
            [str(k) + ":\t" + str(v) for k, v in self.__dict__.items()]
        )

    def __len__(self):
        return len(self.raw_text)

    @staticmethod
    def from_json(json_data):
        return Example(
            json_data["title"],
            json_data["entities"],
            json_data["types"],
            json_data["relations"],
            json_data["abstract"],
        )

    def build_graph(self):
        graph = dgl.DGLGraph()
        ent_len = len(self.raw_ent_text)
        # repeated relations are treated as distinct nodes, following the author's code
        rel_len = len(self.raw_rel)

        graph.add_nodes(
            ent_len, {"type": torch.ones(ent_len) * NODE_TYPE["entity"]}
        )
        graph.add_nodes(1, {"type": torch.ones(1) * NODE_TYPE["root"]})
        graph.add_nodes(
            rel_len * 2,
            {"type": torch.ones(rel_len * 2) * NODE_TYPE["relation"]},
        )
        graph.add_edges(ent_len, torch.arange(ent_len))
        graph.add_edges(torch.arange(ent_len), ent_len)
        graph.add_edges(
            torch.arange(ent_len + 1 + rel_len * 2),
            torch.arange(ent_len + 1 + rel_len * 2),
        )
        adj_edges = []
        for i, r in enumerate(self.raw_rel):
            assert len(r) == 3, str(r)
            st, rt, ed = r
            st_ent, ed_ent = self.raw_ent_text.index(
                st
            ), self.raw_ent_text.index(ed)
            # the edges are added in reverse so that the edge_softmax operator aggregates over the reversed graph
            adj_edges.append([ent_len + 1 + 2 * i, st_ent])
            adj_edges.append([ed_ent, ent_len + 1 + 2 * i])
            adj_edges.append([ent_len + 1 + 2 * i + 1, ed_ent])
            adj_edges.append([st_ent, ent_len + 1 + 2 * i + 1])

        if len(adj_edges) > 0:
            graph.add_edges(*list(map(list, zip(*adj_edges))))
        return graph
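
    # Node layout produced by build_graph (illustrative summary, not original code):
    #   0 .. ent_len - 1                        entity nodes
    #   ent_len                                 the single root node
    #   ent_len + 1 + 2*i, ent_len + 2 + 2*i    forward / inverse nodes of relation i
    # Each relation thus contributes two nodes, linked to its head and tail entities
    # with reversed edges, and every node carries a self-loop.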

    def get_tensor(
        self, ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab
    ):
        if hasattr(self, "_cached_tensor"):
            return self._cached_tensor
        else:
            title_data = ["<BOS>"] + self.raw_title + ["<EOS>"]
            title = [title_vocab(x) for x in title_data]
            ent_text = [
                [ent_text_vocab(y) for y in x] for x in self.raw_ent_text
            ]
            ent_type = [
                text_vocab(x) for x in self.raw_ent_type
            ]  # for inference
            rel_data = ["--root--"] + sum(
                [[x[1], x[1] + "_INV"] for x in self.raw_rel], []
            )
            rel = [rel_vocab(x) for x in rel_data]

            text_data = ["<BOS>"] + self.raw_text + ["<EOS>"]
            text = [text_vocab(x) for x in text_data]
            tgt_text = []
            # the input text and the decoding target differ because of the copy mechanism:
            # placeholders become type tokens on the input side and entity pointers on the target side
            for i, str1 in enumerate(text_data):
                if str1[0] == "<" and str1[-1] == ">" and "_" in str1:
                    a, b = str1[1:-1].split("_")
                    text[i] = text_vocab("<" + a + ">")
                    tgt_text.append(len(text_vocab) + int(b))
                else:
                    tgt_text.append(text[i])
            self._cached_tensor = {
                "title": torch.LongTensor(title),
                "ent_text": [torch.LongTensor(x) for x in ent_text],
                "ent_type": torch.LongTensor(ent_type),
                "rel": torch.LongTensor(rel),
                "text": torch.LongTensor(text[:-1]),
                "tgt_text": torch.LongTensor(tgt_text[1:]),
                "graph": self.graph,
                "raw_ent_text": self.raw_ent_text,
            }
            return self._cached_tensor
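
    # How the copy placeholders are handled above (illustrative example, not original
    # code): a target token such as "<method_3>" is split into the type "method" and
    # the entity index 3; the decoder input then sees the type token "<method>", while
    # the target id becomes len(text_vocab) + 3, i.e. a pointer to entity 3 that the
    # copy mechanism can resolve back to its surface form.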

    def update_vocab(
        self, ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab
    ):
        ent_vocab.update(self.raw_ent_type)
        ent_text_vocab.update(self.raw_ent_text)
        title_vocab.update(self.raw_title)
        rel_vocab.update(
            ["--root--"]
            + [x[1] for x in self.raw_rel]
            + [x[1] + "_INV" for x in self.raw_rel]
        )
        text_vocab.update(self.raw_ent_type)
        text_vocab.update(self.raw_text)


class BucketSampler(torch.utils.data.Sampler):
    def __init__(self, data_source, batch_size=32, bucket=3):
        self.data_source = data_source
        self.bucket = bucket
        self.batch_size = batch_size

    def __iter__(self):
        # the magic numbers (length thresholds and batch sizes) come from the author's code
        perm = torch.randperm(len(self.data_source))
        lens = torch.Tensor([len(x) for x in self.data_source])
        lens = lens[perm]
        t1 = []
        t2 = []
        t3 = []
        for i, l in enumerate(lens):
            if l < 100:
                t1.append(perm[i])
            elif l < 220:
                t2.append(perm[i])
            else:
                t3.append(perm[i])
        datas = [t1, t2, t3]
        random.shuffle(datas)
        idxs = sum(datas, [])
        batch = []

        lens = torch.Tensor([len(x) for x in self.data_source])
        for idx in idxs:
            batch.append(idx)
            mlen = max([0] + [lens[x] for x in batch])
            if (
                (mlen < 100 and len(batch) == 32)
                or (100 <= mlen < 220 and len(batch) >= 24)
                or (mlen >= 220 and len(batch) >= 8)
                or len(batch) == 32
            ):
                yield batch
                batch = []
        if len(batch) > 0:
            yield batch

    def __len__(self):
        return (len(self.data_source) + self.batch_size - 1) // self.batch_size
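
# Bucketing summary for BucketSampler (illustrative, thresholds taken from the code
# above): samples are grouped by target length into three shuffled buckets, short
# (< 100), medium (100-220) and long (>= 220), and batches are emitted with roughly
# 32, 24 and 8 samples respectively, so that long sequences do not blow up memory.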


class GWdataset(torch.utils.data.Dataset):
    def __init__(
        self,
        exs,
        ent_vocab=None,
        rel_vocab=None,
        text_vocab=None,
        ent_text_vocab=None,
        title_vocab=None,
        device=None,
    ):
        super(GWdataset, self).__init__()
        self.exs = exs
        (
            self.ent_vocab,
            self.rel_vocab,
            self.text_vocab,
            self.ent_text_vocab,
            self.title_vocab,
            self.device,
        ) = (
            ent_vocab,
            rel_vocab,
            text_vocab,
            ent_text_vocab,
            title_vocab,
            device,
        )

    def __iter__(self):
        return iter(self.exs)

    def __getitem__(self, index):
        return self.exs[index]

    def __len__(self):
        return len(self.exs)

    def batch_fn(self, batch_ex):
        (
            batch_title,
            batch_ent_text,
            batch_ent_type,
            batch_rel,
            batch_text,
            batch_tgt_text,
            batch_graph,
        ) = ([], [], [], [], [], [], [])
        batch_raw_ent_text = []
        for ex in batch_ex:
            ex_data = ex.get_tensor(
                self.ent_vocab,
                self.rel_vocab,
                self.text_vocab,
                self.ent_text_vocab,
                self.title_vocab,
            )
            batch_title.append(ex_data["title"])
            batch_ent_text.append(ex_data["ent_text"])
            batch_ent_type.append(ex_data["ent_type"])
            batch_rel.append(ex_data["rel"])
            batch_text.append(ex_data["text"])
            batch_tgt_text.append(ex_data["tgt_text"])
            batch_graph.append(ex_data["graph"])
            batch_raw_ent_text.append(ex_data["raw_ent_text"])
        batch_title = pad(batch_title, out_type="tensor")
        batch_ent_text, ent_len = pad(
            batch_ent_text, out_type="tensor", flatten=True
        )
        batch_ent_type = pad(batch_ent_type, out_type="tensor")
        batch_rel = pad(batch_rel, out_type="tensor")
        batch_text = pad(batch_text, out_type="tensor")
        batch_tgt_text = pad(batch_tgt_text, out_type="tensor")
        batch_graph = dgl.batch(batch_graph)
        # keep the returned graph: .to() is not in-place in newer DGL versions
        batch_graph = batch_graph.to(self.device)
        return {
            "title": batch_title.to(self.device),
            "ent_text": batch_ent_text.to(self.device),
            "ent_len": ent_len,
            "ent_type": batch_ent_type.to(self.device),
            "rel": batch_rel.to(self.device),
            "text": batch_text.to(self.device),
            "tgt_text": batch_tgt_text.to(self.device),
            "graph": batch_graph,
            "raw_ent_text": batch_raw_ent_text,
        }
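
# A possible way to wire the pieces together (a sketch under the assumption that
# batch_fn is used as the collate function; this is not part of the original file):
#
#   train_set = get_datasets([...])[0]
#   loader = torch.utils.data.DataLoader(
#       train_set,
#       batch_sampler=BucketSampler(train_set),
#       collate_fn=train_set.batch_fn,
#   )
#   for batch in loader:
#       ...  # batch["text"], batch["graph"], etc. are ready for the model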


def get_datasets(
    fnames,
    min_freq=-1,
    sep=";",
    joint_vocab=True,
    device=None,
    save="tmp.pickle",
):
    # min_freq : not supported yet, since it is very sensitive to the final results; you can still set it by passing min_freq to the Vocab class.
    # sep : not supported yet
    # joint_vocab : not supported yet
    ent_vocab = Vocab(sp=["<PAD>", "<UNK>"])
    title_vocab = Vocab(min_freq=5)
    rel_vocab = Vocab(sp=["<PAD>", "<UNK>"])
    text_vocab = Vocab(min_freq=5)
    ent_text_vocab = Vocab(sp=["<PAD>", "<UNK>"])
    datasets = []
    for fname in fnames:
        exs = []
        with open(fname) as f:
            json_datas = json.load(f)
        for json_data in json_datas:
            # construct one data example
            ex = Example.from_json(json_data)
            if fname == fnames[0]:  # only the first file (the training set) updates the vocabularies
                ex.update_vocab(
                    ent_vocab,
                    rel_vocab,
                    text_vocab,
                    ent_text_vocab,
                    title_vocab,
                )
            exs.append(ex)
        datasets.append(exs)
    ent_vocab.build()
    rel_vocab.build()
    text_vocab.build()
    ent_text_vocab.build()
    title_vocab.build()
    datasets = [
        GWdataset(
            exs,
            ent_vocab,
            rel_vocab,
            text_vocab,
            ent_text_vocab,
            title_vocab,
            device,
        )
        for exs in datasets
    ]
    with open(save, "wb") as f:
        pickle.dump(datasets, f)
    return datasets


if __name__ == "__main__":
    ds = get_datasets(
        [
            "data/unprocessed.val.json",
            "data/unprocessed.val.json",
            "data/unprocessed.test.json",
        ]
    )
    print(ds[0].exs[0])
    print(
        ds[0]
        .exs[0]
        .get_tensor(
            ds[0].ent_vocab,
            ds[0].rel_vocab,
            ds[0].text_vocab,
            ds[0].ent_text_vocab,
            ds[0].title_vocab,
        )
    )