# coding=utf-8
# Copyright 2018 The Microsoft Research team and The HuggingFace Inc. team.
# Copyright (c) 2018 Microsoft and The HuggingFace Inc.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning seq2seq models for sequence generation.

We use the procedure described in [1] to finetune models for sequence
generation. Let S1 and S2 be the source and target sequences respectively; we
pack them using the start of sequence [CLS] and end of sequence [EOS] tokens:

    [CLS] S1 [EOS] S2 [EOS]

We then mask a fixed percentage of tokens from S2 at random and learn to predict
the masked words. [EOS] can be masked during finetuning so the model learns to
terminate the generation process.

[1] Li Dong, Nan Yang, Wenhui Wang, Furu Wei, Xiaodong Liu, Yu Wang, Jianfeng
Gao, Ming Zhou, and Hsiao-Wuen Hon. “Unified Language Model Pre-Training for
Natural Language Understanding and Generation.” (May 2019) ArXiv:1905.03197
"""

import argparse
from collections import deque
import logging
import pickle
import random
import os

import numpy as np
import torch
from torch.utils.data import Dataset

from transformers import BertTokenizer

logger = logging.getLogger(__name__)


def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)


# ------------
# Load dataset
# ------------

class TextDataset(Dataset):
    """ Abstracts the dataset used to train seq2seq models.

    CNN/DailyMail:

    The CNN/DailyMail raw datasets are downloaded from [1]. Each story is
    stored in its own file; the summary appears at the end of the story as
    sentences that are prefixed by the special `@highlight` line. To process
    the data, untar both datasets in the same folder and pass the path to this
    folder as the `data_dir` argument. The formatting code was inspired by [2].

    [1] https://cs.nyu.edu/~kcho/
    [2] https://github.com/abisee/cnn-dailymail/
    """

    def __init__(self, tokenizer, data_dir="", block_size=512):
        assert os.path.isdir(data_dir)

        # Load features that have already been computed if present
        # Note: only the block size goes into the cache file name; embedding the
        # full data_dir path would produce an invalid file name.
        cached_features_file = os.path.join(
            data_dir, "cached_lm_{}".format(block_size)
        )
        if os.path.exists(cached_features_file):
            logger.info("Loading features from cached file %s", cached_features_file)
            with open(cached_features_file, "rb") as source:
                self.examples = pickle.load(source)
                return

        logger.info("Creating features from dataset at %s", data_dir)

        datasets = ["cnn", "dailymail"]
        for dataset in datasets:
            path_to_stories = os.path.join(data_dir, dataset, "stories")
            assert os.path.isdir(path_to_stories)

            story_filenames_list = os.listdir(path_to_stories)
            for story_filename in story_filenames_list:
                path_to_story = os.path.join(path_to_stories, story_filename)
                if not os.path.isfile(path_to_story):
                    continue

                with open(path_to_story, encoding="utf-8") as source:
                    try:
                        raw_story = source.read()
                        story, summary = process_story(raw_story)
                    except IndexError:  # skip ill-formed stories
                        continue

                story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
                summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
                fitted = _fit_to_block_size(story, summary, block_size)
                if fitted is None:  # example too short to fill the block; skip it
                    continue
                story_seq, summary_seq = fitted

                self.examples.append(
                    tokenizer.add_special_token_sequence_pair(story_seq, summary_seq)
                )

        logger.info("Saving features into cache file %s", cached_features_file)
        with open(cached_features_file, "wb") as sink:
            pickle.dump(self.examples, sink, protocol=pickle.HIGHEST_PROTOCOL)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, items):
        return torch.tensor(self.examples[items])
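
# Usage sketch (hedged; the paths are hypothetical). Assuming the CNN and
# DailyMail archives have been untarred under `<data_dir>/cnn/stories` and
# `<data_dir>/dailymail/stories`:
#
#     tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
#     dataset = TextDataset(tokenizer, data_dir="./data", block_size=512)
#     example = dataset[0]  # token ids for one packed story/summary pair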


def process_story(raw_story):
    """ Extract the story and summary from a story file.

    Args:
        raw_story (str): content of the story file as an utf-8 encoded string.

    Raises:
        IndexError: If the story is empty or contains no highlights.
    """
    file_lines = list(
        filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")])
    )

    # For some unknown reason, some lines are missing a period; add it
    file_lines = [_add_missing_period(line) for line in file_lines]

    # gather article lines
    story_lines = []
    lines = deque(file_lines)
    while True:
        try:
            element = lines.popleft()
            if element.startswith("@highlight"):
                break
            story_lines.append(element)
        except IndexError as ie:  # if "@highlight" absent from file
            raise ie

    # gather summary lines
    highlights_lines = list(filter(lambda t: not t.startswith("@highlight"), lines))

    # join the lines
    story = " ".join(story_lines)
    summary = " ".join(highlights_lines)

    return story, summary
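
# Small worked example for process_story (hedged; the story text is made up):
#
#     raw = "First sentence.\nSecond sentence.\n@highlight\nA summary line"
#     story, summary = process_story(raw)
#     # story   == "First sentence. Second sentence."
#     # summary == "A summary line."  (the missing period is added)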


def _add_missing_period(line):
    # sentence-ending punctuation, including the closing single and double quotes
    END_TOKENS = [".", "!", "?", "...", "'", "`", '"', u"\u2019", u"\u201d", ")"]
    if line.startswith("@highlight"):
        return line
    if line[-1] in END_TOKENS:
        return line
    return line + "."


def _fit_to_block_size(src_sequence, tgt_sequence, block_size):
    """ Adapt the source and target sequences' lengths to the block size.

    If the concatenated sequence (source + target + 3 special tokens) would be
    longer than the block size, we use the 75% / 25% rule followed in [1]. For a
    block size of 512 this means capping the source sequence (with its special
    tokens) at 384 tokens and the target sequence at 128 tokens.

    Args:
        src_sequence (list): a list of ids that maps to the tokens of the
            source sequence.
        tgt_sequence (list): a list of ids that maps to the tokens of the
            target sequence.
        block_size (int): the model's block size.

    [1] Dong, Li, et al. "Unified Language Model Pre-training for Natural
    Language Understanding and Generation." arXiv preprint arXiv:1905.03197 (2019).
    """
    SRC_MAX_LENGTH = int(0.75 * block_size) - 2  # CLS and EOS token
    TGT_MAX_LENGTH = block_size - (SRC_MAX_LENGTH + 2) - 1  # EOS token

    # We drop the examples that are too short to fill the block, for the sake
    # of simplicity. You can modify this by adding model-specific padding.
    if len(src_sequence) + len(tgt_sequence) + 3 < block_size:
        return None

    if len(src_sequence) > SRC_MAX_LENGTH:
        if len(tgt_sequence) > TGT_MAX_LENGTH:
            src_sequence = src_sequence[:SRC_MAX_LENGTH]
            tgt_sequence = tgt_sequence[:TGT_MAX_LENGTH]
        else:
            remain_size = block_size - len(tgt_sequence) - 3
            src_sequence = src_sequence[:remain_size]
    else:
        if len(tgt_sequence) > TGT_MAX_LENGTH:
            remain_size = block_size - len(src_sequence) - 3
            tgt_sequence = tgt_sequence[:remain_size]

    return src_sequence, tgt_sequence
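
# Numeric check of the 75% / 25% split (hedged example): with block_size=512,
# SRC_MAX_LENGTH = int(0.75 * 512) - 2 = 382 and TGT_MAX_LENGTH = 512 - 384 - 1
# = 127, so a 600-token story and a 200-token summary are truncated to 382 and
# 127 tokens; 382 + 127 + 3 special tokens fill the 512-token block exactly.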


def load_and_cache_examples(args, tokenizer):
    dataset = TextDataset(tokenizer, data_dir=args.data_dir)
    return dataset


# ------------
# Train
# ------------


def train(args, train_dataset, model, tokenizer):
    """ Fine-tune the pretrained model on the corpus. """
    raise NotImplementedError


def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input training data file (a text file).",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )

    # Optional parameters
    parser.add_argument(
        "--model_name_or_path",
        default="bert-base-cased",
        type=str,
        help="The model checkpoint for weights initialization.",
    )
    parser.add_argument("--seed", default=42, type=int)
    args = parser.parse_args()

    # Set up training device
    # device = torch.device("cpu")

    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    tokenizer_class = BertTokenizer
    # config = config_class.from_pretrained(args.model_name_or_path)
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    # model = model_class.from_pretrained(args.model_name_or_path, config=config)
    # model.to(device)

    logger.info("Training/evaluation parameters %s", args)

    # Training
    _ = load_and_cache_examples(args, tokenizer)
    # global_step, tr_loss = train(args, train_dataset, model, tokenizer)
    # logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)


if __name__ == "__main__":
    main()
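
# Example invocation (hedged; only the data preprocessing path is implemented
# so far, so this mainly builds and caches the dataset features):
#
#     python run_seq2seq_finetuning.py --data_dir ./data --output_dir ./output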