import os
from collections import deque

import torch
from torch.utils.data import Dataset


# ------------
# Data loading
# ------------


class SummarizationDataset(Dataset):
    """ Abstracts the dataset used to train seq2seq models.

    The class will process the documents that are located in the specified
    folder. The preprocessing will work on any document that is reasonably
    formatted. On the CNN/DailyMail dataset it will extract both the story
    and the summary.

    CNN/Daily News:

    The CNN/Daily News raw datasets are downloaded from [1]. The stories are
    stored in different files; the summary appears at the end of the story as
    sentences that are prefixed by the special `@highlight` line. To process
    the data, untar both datasets in the same folder, and pass the path to this
    folder as the "data_dir argument. The formatting code was inspired by [2].

    [1] https://cs.nyu.edu/~kcho/
    [2] https://github.com/abisee/cnn-dailymail/
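
    Example (illustrative sketch; "data/cnn_dm" is a hypothetical folder that
    contains the untarred CNN and Daily Mail stories):

        dataset = SummarizationDataset("data/cnn_dm")
        document_name, story_lines, summary_lines = dataset[0]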
    """

    def __init__(self, path="", prefix="train"):
        """ We initialize the class by listing all the documents to summarize.
        Files are not read into memory due to the size of some datasets (like CNN/DailyMail).
        """
        assert os.path.isdir(path)

        self.documents = []
        story_filenames_list = os.listdir(path)
        for story_filename in story_filenames_list:
            if "summary" in story_filename:
                continue
            path_to_story = os.path.join(path, story_filename)
            if not os.path.isfile(path_to_story):
                continue
            self.documents.append(path_to_story)

    def __len__(self):
        """ Returns the number of documents. """
        return len(self.documents)

    def __getitem__(self, idx):
        document_path = self.documents[idx]
        document_name = document_path.split("/")[-1]
        with open(document_path, encoding="utf-8") as source:
            raw_story = source.read()
            story_lines, summary_lines = process_story(raw_story)
        return document_name, story_lines, summary_lines


def process_story(raw_story):
    """ Extract the story and summary from a story file.

    Arguments:
        raw_story (str): content of the story file as an utf-8 encoded string.

    Raises:
        IndexError: If the story is empty or contains no highlights.
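
    Example (doctest-style sketch with made-up text; note the period appended
    to the summary line by `_add_missing_period`):

        >>> raw = "\\n".join(["First story sentence.", "@highlight", "One-line summary"])
        >>> process_story(raw)
        (['First story sentence.'], ['One-line summary.'])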
    """
    nonempty_lines = list(filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")]))

    # for some unknown reason some lines are missing a period; add it
    nonempty_lines = [_add_missing_period(line) for line in nonempty_lines]

    # gather article lines
    story_lines = []
    lines = deque(nonempty_lines)
    while True:
        try:
            element = lines.popleft()
            if element.startswith("@highlight"):
                break
            story_lines.append(element)
        except IndexError:
            # if "@highlight" is absent from the file we pop
Rémi Louf's avatar
Rémi Louf committed
87
            # all elements until there is None, raising an exception.
            return story_lines, []

    # gather summary lines
    summary_lines = list(filter(lambda t: not t.startswith("@highlight"), lines))

    return story_lines, summary_lines


def _add_missing_period(line):
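    """ Append a period to a line that does not already end with an
    end-of-sentence token; `@highlight` markers are returned unchanged.

    >>> _add_missing_period("An unfinished line")
    'An unfinished line.'
    """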
    END_TOKENS = [".", "!", "?", "...", "'", "`", '"', "\u2019", "\u201d", ")"]
    if line.startswith("@highlight"):
        return line
    if line[-1] in END_TOKENS:
        return line
    return line + "."


# --------------------------
# Encoding and preprocessing
# --------------------------


def fit_to_block_size(sequence, block_size, pad_token_id):
    """ Adapt the source and target sequences' lengths to the block size.
    If the sequence is longer than the block size it is truncated; if it is
    shorter, padding tokens are appended to the right of the sequence.
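
    Example (illustrative sketch, assuming a pad token id of 0):

        >>> fit_to_block_size([5, 6, 7], block_size=5, pad_token_id=0)
        [5, 6, 7, 0, 0]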
    """
    if len(sequence) > block_size:
        return sequence[:block_size]
    else:
        sequence.extend([pad_token_id] * (block_size - len(sequence)))
        return sequence


def build_mask(sequence, pad_token_id):
    """ Builds the mask. The attention mechanism will only attend to positions
    with value 1. """
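
    Example (illustrative sketch, assuming a pad token id of 0):

        >>> build_mask(torch.tensor([5, 6, 7, 0, 0]), pad_token_id=0)
        tensor([1, 1, 1, 0, 0])
    """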
    mask = torch.ones_like(sequence)
    idx_pad_tokens = sequence == pad_token_id
    mask[idx_pad_tokens] = 0
    return mask


def encode_for_summarization(story_lines, summary_lines, tokenizer):
    """ Encode the story and summary lines, and join them
    as specified in [1] by using `[SEP] [CLS]` tokens to separate
    sentences.
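
    Example (illustrative sketch; `tokenizer` stands for a BERT-like tokenizer
    whose `encode` already wraps each line with the special token ids, so the
    flattened story reads `[CLS] sent1 [SEP] [CLS] sent2 [SEP] ...`):

        story_token_ids, summary_token_ids = encode_for_summarization(
            story_lines, summary_lines, tokenizer
        )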
    """
    story_lines_token_ids = [tokenizer.encode(line) for line in story_lines]
    story_token_ids = [token for sentence in story_lines_token_ids for token in sentence]
    summary_lines_token_ids = [tokenizer.encode(line) for line in summary_lines]
    summary_token_ids = [token for sentence in summary_lines_token_ids for token in sentence]

    return story_token_ids, summary_token_ids


def compute_token_type_ids(batch, separator_token_id):
    """ Segment embeddings as described in [1]

    The values {0,1} were found in the repository [2].

    Arguments:
        batch: torch.Tensor, size [batch_size, block_size]
            Batch of input.
        separator_token_id: int
            The value of the token that separates the segments.

    [1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
        arXiv preprint arXiv:1908.08345 (2019).
    [2] https://github.com/nlpyang/PreSumm (/src/prepro/data_builder.py, commit fac1217)
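
    Example (illustrative sketch; 101 is an arbitrary stand-in for the
    separator token id):

        >>> batch = torch.tensor([[101, 8, 9, 101, 10, 11]])
        >>> compute_token_type_ids(batch, separator_token_id=101)
        tensor([[0, 0, 0, 1, 1, 1]])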
    """
    batch_embeddings = []
    for sequence in batch:
        sentence_num = -1
        embeddings = []
        for s in sequence:
            if s == separator_token_id:
                sentence_num += 1
            embeddings.append(sentence_num % 2)
        batch_embeddings.append(embeddings)
    return torch.tensor(batch_embeddings)