# utils_summarization.py
# Data loading and encoding helpers for seq2seq summarization.
from collections import deque
import os

import torch
from torch.utils.data import Dataset


# ------------
# Data loading
# ------------


class SummarizationDataset(Dataset):
    """ Abstracts the dataset used to train seq2seq models.

    The class lists the documents that are located in the specified folder
    and reads them lazily, one at a time, when they are indexed. The
    preprocessing will work on any document that is reasonably formatted.
    On the CNN/DailyMail dataset it will extract both the story and the
    summary.

    CNN/Daily Mail:

    The CNN/Daily Mail raw datasets are downloaded from [1]. The stories are
    stored in different files; the summary appears at the end of the story as
    sentences that are prefixed by the special `@highlight` line. To process
    the data, untar both datasets in the same folder, and pass the path to this
    folder as the `path` argument. The formatting code was inspired by [2].

    [1] https://cs.nyu.edu/~kcho/
    [2] https://github.com/abisee/cnn-dailymail/
    """

    def __init__(self, path="", prefix="train"):
        """ We initialize the class by listing all the documents to summarize.
        Files are not read in memory due to the size of some datasets (like CNN/DailyMail).

        Arguments:
            path (str): folder that contains the story files. Sub-folders and
                entries whose name contains "summary" are ignored.
            prefix (str): currently unused; kept so existing callers that
                pass it keep working.

        Raises:
            AssertionError: if `path` is not an existing directory.
        """
        assert os.path.isdir(path), "`path` must be an existing directory"

        self.documents = []
        # `os.listdir` returns entries in arbitrary order; sorting makes the
        # index -> document mapping reproducible across runs and filesystems.
        for story_filename in sorted(os.listdir(path)):
            if "summary" in story_filename:
                continue
            path_to_story = os.path.join(path, story_filename)
            if os.path.isfile(path_to_story):
                self.documents.append(path_to_story)

    def __len__(self):
        """ Returns the number of documents. """
        return len(self.documents)

    def __getitem__(self, idx):
        """ Reads and processes the document stored at position `idx`.

        Returns:
            A tuple `(document_name, story_lines, summary_lines)` where
            `document_name` is the file name without its folder and the two
            other items are the values returned by `process_story`.
        """
        document_path = self.documents[idx]
        # `os.path.basename` honours the platform separator, unlike a
        # hard-coded split on "/".
        document_name = os.path.basename(document_path)
        with open(document_path, encoding="utf-8") as source:
            raw_story = source.read()
        story_lines, summary_lines = process_story(raw_story)
        return document_name, story_lines, summary_lines


def process_story(raw_story):
    """ Extract the story and summary from a story file.

    Lines are stripped, empty lines are dropped and a period is appended to
    lines that lack final punctuation. Everything before the first line that
    starts with `@highlight` is the story; every later line that is not
    itself an `@highlight` marker belongs to the summary.

    Arguments:
        raw_story (str): content of the story file, already decoded.

    Returns:
        A tuple `(story_lines, summary_lines)` of lists of str. If the file
        contains no `@highlight` marker `summary_lines` is empty; if the file
        is empty both lists are empty. No exception is raised in either case.
    """
    stripped_lines = (line.strip() for line in raw_story.split("\n"))

    # for some unknown reason some lines miss a period, add it
    nonempty_lines = [_add_missing_period(line) for line in stripped_lines if line]

    story_lines = []
    summary_lines = []
    reached_highlights = False
    for line in nonempty_lines:
        if line.startswith("@highlight"):
            # the marker itself is never part of the output
            reached_highlights = True
        elif reached_highlights:
            summary_lines.append(line)
        else:
            story_lines.append(line)

    return story_lines, summary_lines


def _add_missing_period(line):
    """ Append a period to `line` unless it already ends with an end token.

    `@highlight` markers and empty strings are returned unchanged.
    """
    # "..." is covered by "."; the last two quotes are the right single
    # (U+2019) and right double (U+201D) quotation marks.
    END_TOKENS = (".", "!", "?", "'", "`", '"', u"\u2019", u"\u201d", ")")
    if not line or line.startswith("@highlight"):
        return line
    if line.endswith(END_TOKENS):
        return line
    return line + "."


# --------------------------
# Encoding and preprocessing
# --------------------------


def fit_to_block_size(sequence, block_size, pad_token_id):
    """ Adapt the source and target sequences' lengths to the block size.

    If the sequence is longer than `block_size` it is truncated on the right.
    If the sequence is shorter we append padding token to the right of the
    sequence. The input is never modified; a new sequence is returned.

    Arguments:
        sequence (list): token ids.
        block_size (int): length of the returned sequence.
        pad_token_id: value used to fill the missing positions.

    Returns:
        A sequence of exactly `block_size` elements.
    """
    if len(sequence) > block_size:
        return sequence[:block_size]

    # Build a new list rather than `extend`-ing the argument so the caller's
    # data is left untouched, consistently with the truncating branch.
    padding = [pad_token_id] * (block_size - len(sequence))
    return list(sequence) + padding


def build_mask(sequence, pad_token_id):
    """ Builds the mask. The attention mechanism will only attend to positions
    with value 1.

    Arguments:
        sequence (torch.Tensor): token ids.
        pad_token_id: value that marks the padded positions.

    Returns:
        A new tensor with the same shape and dtype as `sequence` that holds 0
        where `sequence` equals `pad_token_id` and 1 everywhere else. The
        input tensor is not modified.
    """
    is_real_token = sequence != pad_token_id
    return is_real_token.to(sequence.dtype)


def encode_for_summarization(story_lines, summary_lines, tokenizer):
    """ Encode the story and summary lines and flatten each of them into a
    single list of token ids.

    Every line is encoded on its own with `tokenizer.encode` and the results
    are concatenated in order. This function adds no separator itself; any
    special token between sentences (e.g. `[SEP] [CLS]`) has to come from
    `tokenizer.encode` -- NOTE(review): confirm against the tokenizer in use.

    Arguments:
        story_lines (list of str): lines of the story.
        summary_lines (list of str): lines of the summary.
        tokenizer: object with an `encode(line)` method that returns an
            iterable of token ids.

    Returns:
        A tuple `(story_token_ids, summary_token_ids)` of flat lists.
    """
    story_token_ids = [
        token for line in story_lines for token in tokenizer.encode(line)
    ]
    summary_token_ids = [
        token for line in summary_lines for token in tokenizer.encode(line)
    ]

    return story_token_ids, summary_token_ids


def compute_token_type_ids(batch, separator_token_id):
    """ Segment embeddings as described in [1]

    The values {0,1} were found in the repository [2].

    Every position receives `(k - 1) % 2`, where `k` is the number of
    separator tokens seen in its sequence up to and including that position.
    Positions that precede the first separator therefore receive 1, the
    first segment (separator included) receives 0, the next one 1, and so on.

    Arguments:
        batch: torch.Tensor, size [batch_size, block_size]
            Batch of input. Nested lists of ints are accepted as well.
        separator_token_id: int
            The value of the token that separates the segments.

    Returns:
        torch.Tensor of integers with the same shape as `batch`.

    [1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
        arXiv preprint arXiv:1908.08345 (2019).
    [2] https://github.com/nlpyang/PreSumm (/src/prepro/data_builder.py, commit fac1217)
    """
    batch = torch.as_tensor(batch)
    is_separator = (batch == separator_token_id).long()
    # Running count of separators along each sequence, shifted so that it is
    # -1 before the first separator; the remainder of -1 by 2 is 1.
    sentence_num = is_separator.cumsum(dim=-1) - 1
    return sentence_num % 2