# utils.py
import os

import torch
from torch.utils.data import Dataset

# NOTE: trim_batch was later removed from the transformers core, so this
# import assumes an older transformers release.
from transformers.tokenization_utils import trim_batch


def encode_file(tokenizer, data_path, max_length, pad_to_max_length=True, return_tensors="pt"):
    """Tokenize a text file line by line, padding/truncating every line to max_length.

    Returns a list with one encoding dict per line; each dict holds
    1 x max_length tensors ("input_ids", "attention_mask", ...).
    """
    examples = []
    with open(data_path, "r") as f:
        for text in f.readlines():
            tokenized = tokenizer.batch_encode_plus(
                [text], max_length=max_length, pad_to_max_length=pad_to_max_length, return_tensors=return_tensors,
            )
            examples.append(tokenized)
    return examples
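

# A quick sketch of what encode_file returns (hypothetical call, not part of
# the original module; shapes assume pad_to_max_length=True):
#
#   encodings = encode_file(tokenizer, "val.source", max_length=1024)
#   encodings[0]["input_ids"].shape       # torch.Size([1, 1024])
#   encodings[0]["attention_mask"].shape  # torch.Size([1, 1024])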


class SummarizationDataset(Dataset):
    """Pairs line-aligned <type_path>.source / <type_path>.target files
    (e.g. train.source and train.target), tokenized up front."""

    def __init__(
        self,
        tokenizer,
        data_dir="./cnn-dailymail/cnn_dm/",
        type_path="train",
        max_source_length=1024,
        max_target_length=56,
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.source = encode_file(tokenizer, os.path.join(data_dir, type_path + ".source"), max_source_length)
        self.target = encode_file(tokenizer, os.path.join(data_dir, type_path + ".target"), max_target_length)

    def __len__(self):
        return len(self.source)

    def __getitem__(self, index):
        # Drop the dummy batch dimension added by batch_encode_plus.
        source_ids = self.source[index]["input_ids"].squeeze()
        target_ids = self.target[index]["input_ids"].squeeze()
        src_mask = self.source[index]["attention_mask"].squeeze()
        return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids}

    @staticmethod
    def trim_seq2seq_batch(batch, pad_token_id):
        # Trim an already-collated batch down to its longest non-pad sequence.
        y = trim_batch(batch["target_ids"], pad_token_id)
        source_ids, source_mask = trim_batch(batch["source_ids"], pad_token_id, attention_mask=batch["source_mask"])
        return source_ids, source_mask, y

    def collate_fn(self, batch):
        # Stack the fixed-length examples, then strip columns that are padding
        # in every row so the batch is only as wide as its longest sequence.
        input_ids = torch.stack([x["source_ids"] for x in batch])
        masks = torch.stack([x["source_mask"] for x in batch])
        target_ids = torch.stack([x["target_ids"] for x in batch])
        pad_token_id = self.tokenizer.pad_token_id
        y = trim_batch(target_ids, pad_token_id)
        source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
        return {"source_ids": source_ids, "source_mask": source_mask, "target_ids": y}
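

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes the
    # CNN/DailyMail files (val.source / val.target) exist under data_dir and
    # that a BART checkpoint is available; any seq2seq tokenizer with a pad
    # token should behave the same way.
    from torch.utils.data import DataLoader
    from transformers import BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
    dataset = SummarizationDataset(tokenizer, data_dir="./cnn-dailymail/cnn_dm/", type_path="val")
    loader = DataLoader(dataset, batch_size=4, collate_fn=dataset.collate_fn)
    batch = next(iter(loader))
    # After collate_fn, all three tensors are trimmed to the longest
    # non-pad length in the batch.
    print(batch["source_ids"].shape, batch["source_mask"].shape, batch["target_ids"].shape)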