# utils_summarization.py
# Author (per the GitLab blame view this file was copied from): Rémi Louf
from collections import deque
import os

import torch
from torch.utils.data import Dataset


# ------------
# Data loading
# ------------


class CNNDailyMailDataset(Dataset):
    """ Abstracts the dataset used to train seq2seq models.

    CNN/Daily Mail:

    The CNN/Daily Mail raw datasets are downloaded from [1]. The stories are
    stored in different files; the summary appears at the end of the story as
    sentences that are prefixed by the special `@highlight` line. To process
    the data, untar both datasets in the same folder, and pass the path to this
    folder as the `data_dir` argument. The formatting code was inspired by [2].

    Each item is the `(story_lines, summary_lines)` pair returned by
    `process_story` for one story file.

    [1] https://cs.nyu.edu/~kcho/
    [2] https://github.com/abisee/cnn-dailymail/
    """

    def __init__(self, data_dir="", prefix="train"):
        """ List the story files found under `data_dir`.

        Args:
            data_dir (str): folder containing the untarred `cnn/` and
                `dailymail/` datasets, each with a `stories/` subfolder.
            prefix (str): currently unused; kept for interface compatibility.

        Raises:
            NotADirectoryError: if `data_dir` is not an existing directory.
        """
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not os.path.isdir(data_dir):
            raise NotADirectoryError(
                "data_dir must be an existing directory, got {!r}".format(data_dir)
            )

        # We initialize the class by listing all the files that contain
        # stories and summaries. Files are not read in memory given
        # the size of the corpus.
        self.stories_path = []
        for dataset in ("cnn", "dailymail"):
            path_to_stories = os.path.join(data_dir, dataset, "stories")
            # os.listdir returns entries in arbitrary order; sorting makes the
            # index -> story mapping reproducible across machines and runs.
            for story_filename in sorted(os.listdir(path_to_stories)):
                path_to_story = os.path.join(path_to_stories, story_filename)
                if os.path.isfile(path_to_story):
                    self.stories_path.append(path_to_story)

    def __len__(self):
        """ Number of story files found at construction time. """
        return len(self.stories_path)

    def __getitem__(self, idx):
        """ Read the `idx`-th story file and split it into story and summary.

        Returns:
            tuple(list[str], list[str]): `(story_lines, summary_lines)`.
        """
        story_path = self.stories_path[idx]
        with open(story_path, encoding="utf-8") as source:
            raw_story = source.read()
        # Parsing happens after the file is closed; no need to hold it open.
        return process_story(raw_story)


def process_story(raw_story):
    """ Extract the story and summary from a story file.

    Every line is stripped and blank lines are dropped; lines that lack end
    punctuation get a trailing period. Lines before the first `@highlight`
    marker form the story; the lines after it, with the `@highlight` markers
    themselves removed, form the summary.

    Args:
        raw_story (str): content of the story file as a decoded string.

    Returns:
        tuple(list[str], list[str]): `(story_lines, summary_lines)`. If the
        text contains no `@highlight` marker, `summary_lines` is empty; an
        empty or blank input yields `([], [])`. No exception is raised for
        these cases.
    """
    stripped_lines = (line.strip() for line in raw_story.split("\n"))
    # for some unknown reason some lines miss a period, add it
    nonempty_lines = [_add_missing_period(line) for line in stripped_lines if line]

    # gather article lines: everything up to the first "@highlight" marker
    lines = deque(nonempty_lines)
    story_lines = []
    while lines:
        element = lines.popleft()
        if element.startswith("@highlight"):
            break
        story_lines.append(element)

    # gather summary lines: whatever remains, minus the "@highlight" markers
    # (empty when the loop above consumed every line without finding one)
    summary_lines = [line for line in lines if not line.startswith("@highlight")]

    return story_lines, summary_lines


def _add_missing_period(line):
    """ Append a period to `line` unless it already ends with end punctuation.

    `@highlight` marker lines and empty strings are returned unchanged.
    """
    # Closing quotes and brackets count as sentence ends. Both the right
    # single (\u2019) and right double (\u201d) quotation marks are included,
    # matching END_TOKENS in the abisee/cnn-dailymail preprocessing script.
    end_tokens = (".", "!", "?", "...", "'", "`", '"', "\u2019", "\u201d", ")")
    if not line or line.startswith("@highlight"):
        return line
    if line.endswith(end_tokens):
        return line
    return line + "."


# --------------------------
# Encoding and preprocessing
# --------------------------


def fit_to_block_size(sequence, block_size, pad_token_id):
    """ Adapt a sequence's length to the block size.

    If the sequence is longer than `block_size` it is truncated. If it is
    shorter, `pad_token_id` is appended to the right until it reaches
    `block_size`. The input sequence is never modified.

    Args:
        sequence (list[int]): token ids.
        block_size (int): target length.
        pad_token_id (int): id used for right padding.

    Returns:
        list[int]: a new sequence of exactly `block_size` tokens.
    """
    if len(sequence) > block_size:
        return sequence[:block_size]
    # Build a new list so the caller's list is not mutated in place.
    return list(sequence) + [pad_token_id] * (block_size - len(sequence))


def build_lm_labels(sequence, pad_token_id, ignore_index=-1):
    """ Build language-model labels from a tensor of token ids.

    Padding tokens are replaced by `ignore_index` so they are not taken into
    account in the loss computation. The loss function must be configured
    with the same ignore value. The default of -1 is kept for backward
    compatibility; note that `torch.nn.CrossEntropyLoss` ignores -100 by
    default.

    Args:
        sequence (torch.Tensor): token ids; left unmodified.
        pad_token_id (int): id of the padding token.
        ignore_index (int): value written in place of padding tokens.

    Returns:
        torch.Tensor: a new tensor with the same shape and dtype as `sequence`.
    """
    # masked_fill is out-of-place, equivalent to clone() + masked assignment.
    return sequence.masked_fill(sequence == pad_token_id, ignore_index)


def build_mask(sequence, pad_token_id):
    """ Builds the attention mask for `sequence`.

    Positions holding `pad_token_id` get 0 and every other position gets 1.
    The attention mechanism will only attend to positions with value 1.

    Args:
        sequence (torch.Tensor): token ids; left unmodified.
        pad_token_id (int): id of the padding token.

    Returns:
        torch.Tensor: same shape and dtype as `sequence`.
    """
    mask = torch.ones_like(sequence)
    mask[sequence == pad_token_id] = 0
    return mask


def encode_for_summarization(story_lines, summary_lines, tokenizer):
    """ Encode the story and summary lines into flat lists of token ids.

    Each line is encoded separately with `tokenizer.encode` and the resulting
    id lists are concatenated in order. No separator tokens are inserted here.
    Any special tokens in the output are whatever `tokenizer.encode` itself
    adds for each line; this depends on the tokenizer, so check its
    configuration.

    Args:
        story_lines (list[str]): sentences of the story.
        summary_lines (list[str]): sentences of the summary.
        tokenizer: object exposing an `encode(str)` method that returns a
            list of token ids.

    Returns:
        tuple(list[int], list[int]): `(story_token_ids, summary_token_ids)`.
    """

    def _encode_and_flatten(lines):
        # Encode every line, then flatten the per-line id lists into one list.
        return [token for line in lines for token in tokenizer.encode(line)]

    # Story lines are encoded before summary lines (tuple elements are
    # evaluated left to right).
    story_token_ids = _encode_and_flatten(story_lines)
    summary_token_ids = _encode_and_flatten(summary_lines)
    return story_token_ids, summary_token_ids


def compute_token_type_ids(batch, separator_token_id):
    """ Segment embeddings as described in [1]

    Each position gets the parity (0 or 1) of the number of separator tokens
    seen so far in its row, counting the position itself. The segment id
    therefore flips at every separator. The values {0,1} were found in the
    repository [2].

    Args:
        batch (torch.Tensor): size [batch_size, block_size]. Batch of input
            token ids. A list of equal-length lists is also accepted.
        separator_token_id (int): the value of the token that separates the
            segments.

    Returns:
        torch.Tensor: int64 tensor with the same shape (and device) as `batch`.

    [1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
        arXiv preprint arXiv:1908.08345 (2019).
    [2] https://github.com/nlpyang/PreSumm (/src/prepro/data_builder.py, commit fac1217)
    """
    batch = torch.as_tensor(batch)
    # An inclusive running count of separators along each row replaces the
    # per-element Python loop ("increment on separator, emit parity") with
    # a single vectorized operation.
    separators_seen = torch.cumsum(batch == separator_token_id, dim=-1, dtype=torch.long)
    return separators_seen % 2