run_seq2seq_finetuning.py 8.9 KB
Newer Older
Rémi Louf's avatar
Rémi Louf committed
1
# coding=utf-8
2
# Copyright 2018 The Microsoft Research team and The HuggingFace Inc. team.
Rémi Louf's avatar
Rémi Louf committed
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# Copyright (c) 2018 Microsoft and The HuggingFace Inc.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning seq2seq models for sequence generation.

We use the procedure described in [1] to finetune models for sequence
generation. Let S1 and S2 be the source and target sequence respectively; we
pack them using the start of sequence [CLS] and end of sequence [EOS] tokens:
Rémi Louf's avatar
Rémi Louf committed
21

22
    [CLS] S1 [EOS] S2 [EOS]
Rémi Louf's avatar
Rémi Louf committed
23
24
25
26
27
28
29
30
31
32

We then mask a fixed percentage of tokens from S2 at random and learn to predict
the masked words. [EOS] can be masked during finetuning so the model learns to
terminate the generation process.

[1] Dong Li, Nan Yang, Wenhui Wang, Furu Wei, Xiaodong Liu, Yu Wang, Jianfeng
Gao, Ming Zhou, and Hsiao-Wuen Hon.  “Unified Language Model Pre-Training for
Natural Language Understanding and Generation.” (May 2019) ArXiv:1905.03197
"""

33
import argparse
34
import dequeue
Rémi Louf's avatar
Rémi Louf committed
35
import logging
36
import pickle
Rémi Louf's avatar
Rémi Louf committed
37
import random
38
import os
Rémi Louf's avatar
Rémi Louf committed
39
40
41

import numpy as np
import torch
42
from torch.utils.data import Dataset
Rémi Louf's avatar
Rémi Louf committed
43

44
45
from transformers import BertConfig, Bert2Rnd, BertTokenizer

Rémi Louf's avatar
Rémi Louf committed
46
47
48
49
50
51
52
53
54
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def set_seed(args):
    """ Seed the `random`, NumPy and PyTorch generators with `args.seed`. """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(args.seed)


55
class TextDataset(Dataset):
    """ Abstracts the dataset used to train seq2seq models.

    CNN/Daily Mail:

    The CNN/Daily Mail raw datasets are downloaded from [1]. They consist in stories stored
    in different files where the summary sentences are indicated by the special `@highlight` token.
    To process the data, untar both datasets in the same folder, and pass the path to this
    folder as the "data_dir" argument. The formatting code was inspired by [2].

    [1] https://cs.nyu.edu/~kcho/
    [2] https://github.com/abisee/cnn-dailymail/

    Arguments:
        tokenizer: object exposing `tokenize`, `convert_tokens_to_ids` and
            `add_special_token_sequence_pair`.
        data_dir: existing directory containing `cnn/stories` and
            `dailymail/stories`. The feature cache is written there too.
        block_size: total length the (story, summary) pairs are fitted to.

    Attributes:
        examples: list of token-id sequences, one per retained story.
    """

    def __init__(self, tokenizer, data_dir='', block_size=512):
        assert os.path.isdir(data_dir)

        self.examples = []

        # Load features that have already been computed if present. Only the
        # last component of `data_dir` goes in the file name so that it never
        # contains a path separator.
        cache_name = "cached_lm_{}_{}".format(
            block_size, os.path.basename(os.path.normpath(data_dir))
        )
        cached_features_file = os.path.join(data_dir, cache_name)
        if os.path.exists(cached_features_file):
            logger.info("Loading features from cached file %s", cached_features_file)
            # NOTE: pickle executes arbitrary code on load; only point
            # `data_dir` at directories whose cache file you trust.
            with open(cached_features_file, "rb") as source:
                self.examples = pickle.load(source)
            return

        logger.info("Creating features from dataset at %s", data_dir)

        datasets = ['cnn', 'dailymail']
        for dataset in datasets:
            path_to_stories = os.path.join(data_dir, dataset, "stories")
            assert os.path.isdir(path_to_stories)

            for story_file in os.listdir(path_to_stories):
                path_to_story = os.path.join(path_to_stories, story_file)
                if not os.path.isfile(path_to_story):
                    continue

                with open(path_to_story, encoding="utf-8") as source:
                    try:
                        story, summary = process_story(source)
                    except IndexError:  # no "@highlight" marker in the file
                        continue

                story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
                summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))

                # `_fit_to_block_size` returns None for pairs that are too
                # short to fill a block; those examples are dropped.
                fitted = _fit_to_block_size(story, summary, block_size)
                if fitted is None:
                    continue
                story_seq, summary_seq = fitted

                # NOTE(review): confirm this method name against the tokenizer
                # API of the installed `transformers` version.
                example = tokenizer.add_special_token_sequence_pair(story_seq, summary_seq)
                self.examples.append(example)

        logger.info("Saving features into cache file %s", cached_features_file)
        with open(cached_features_file, "wb") as sink:
            pickle.dump(self.examples, sink, protocol=pickle.HIGHEST_PROTOCOL)

    def __len__(self):
        """ Return the number of examples. """
        return len(self.examples)

    def __getitem__(self, items):
        """ Return the example(s) selected by `items` as a tensor. """
        return torch.tensor(self.examples[items])


115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
def process_story(story_file):
    """ Process the text contained in a story file.

    Arguments:
        story_file: iterable of text lines (e.g. an open file object).

    Returns:
        A `(story, summary)` tuple of strings. `story` joins the lines that
        precede the first "@highlight" marker, `summary` joins the non-marker
        lines that follow it; both are joined with single spaces.

    Raises:
        IndexError: if no line starts with "@highlight".
    """
    # Imported locally: the module-level `import dequeue` does not provide
    # the `deque` class used below.
    from collections import deque

    # strip surrounding whitespace and drop the blank lines
    stripped_lines = (line.strip() for line in story_file)
    file_lines = [line for line in stripped_lines if line]

    # for some unknown reason some lines miss a period, add it
    file_lines = [_add_missing_period(line) for line in file_lines]

    # gather article lines
    story_lines = []
    lines = deque(file_lines)
    while True:
        # `popleft` raises IndexError once the queue is exhausted, i.e. when
        # "@highlight" is absent from the file; we let it propagate.
        element = lines.popleft()
        if element.startswith("@highlight"):
            break
        story_lines.append(element)

    # gather summary lines
    highlights_lines = [line for line in lines if not line.startswith("@highlight")]

    # join the lines
    story = " ".join(story_lines)
    summary = " ".join(highlights_lines)

    return story, summary


def _add_missing_period(line):
    """ Append " ." to `line` unless it already ends with a closing token.

    Empty lines and the "@highlight" marker are returned unchanged.
    """
    # '\u2019' and '\u201d' are the closing single and double curly quotes.
    END_TOKENS = ('.', '!', '?', '...', "'", "`", '"', u'\u2019', u'\u201d', ")")
    if not line or line == "@highlight":
        return line
    if line.endswith(END_TOKENS):
        return line
    return line + " ."


155
def _fit_to_block_size(src_sequence, tgt_sequence, block_size):
    """ Adapt the lengths of the source and target sequences to the block size.

    Following [1] we truncate the source and target sequences so that, together
    with the 3 special tokens of `[CLS] S1 [EOS] S2 [EOS]`, they fit in the
    block size. When both sequences are too long we follow the 75%/25% rule
    in [1]: for a block size of 512 the source (with its 2 special tokens) is
    limited to 384 positions and the target (with its special token) to 128.
    When only one sequence is too long it is truncated to whatever room the
    other one leaves.

    [1] Dong, Li, et al. "Unified Language Model Pre-training for Natural
    Language Understanding and Generation." arXiv preprint arXiv:1905.03197 (2019).

    Arguments:
        src_sequence: sliceable sequence of source tokens.
        tgt_sequence: sliceable sequence of target tokens.
        block_size: total number of positions available, special tokens included.

    Returns:
        `None` if the pair is too short to fill the block, otherwise the
        `(src_sequence, tgt_sequence)` tuple, truncated where necessary.
    """
    SRC_MAX_LENGTH = int(0.75 * block_size) - 2  # CLS and EOS token
    TGT_MAX_LENGTH = block_size - (SRC_MAX_LENGTH + 2) - 1  # EOS token

    # we dump the examples that are too small to fit in the block size for the
    # sake of simplicity. You can modify this by adding model-specific padding.
    if len(src_sequence) + len(tgt_sequence) + 3 < block_size:
        return None

    # the source sequence has `[SEP_i]` special tokens with i \in [0,9]. We keep them for now.
    if len(src_sequence) > SRC_MAX_LENGTH:
        if len(tgt_sequence) > TGT_MAX_LENGTH:
            src_sequence = src_sequence[:SRC_MAX_LENGTH]
            tgt_sequence = tgt_sequence[:TGT_MAX_LENGTH]
        else:
            # the target is short: give the source all the remaining room
            src_sequence = src_sequence[:block_size - len(tgt_sequence) - 3]
    elif len(tgt_sequence) > TGT_MAX_LENGTH:
        # the source is short: give the target all the remaining room
        tgt_sequence = tgt_sequence[:block_size - len(src_sequence) - 3]

    return src_sequence, tgt_sequence
186
187
188



189
def load_and_cache_examples(args, tokenizer):
    """ Build the training dataset from the directory in `args.train_data_file`.

    `TextDataset` names its data-location parameter `data_dir`, so the path is
    passed under that keyword.
    """
    dataset = TextDataset(tokenizer, data_dir=args.train_data_file)
    return dataset
Rémi Louf's avatar
Rémi Louf committed
192
193


194
195
def train(args, train_dataset, model, tokenizer):
    """ Fine-tune the pretrained model on the corpus.

    Raises:
        NotImplementedError: always; the training loop is not written yet.
    """
    raise NotImplementedError


def main():
    """ Parse the command-line arguments, load the model and run the training. """
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--train_data_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The input training data directory (contains the untarred `cnn` and `dailymail` folders).")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")

    # Optional parameters
    parser.add_argument("--model_name_or_path",
                        default="bert-base-cased",
                        type=str,
                        help="The model checkpoint for weights initialization.")
    parser.add_argument("--seed", default=42, type=int)
    args = parser.parse_args()

    # Without a configured handler the INFO messages below would be dropped.
    logging.basicConfig(level=logging.INFO)

    # Set up training device
    device = torch.device("cpu")

    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    config_class, model_class, tokenizer_class = BertConfig, Bert2Rnd, BertTokenizer
    config = config_class.from_pretrained(args.model_name_or_path)
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    model = model_class.from_pretrained(args.model_name_or_path, config=config)
    model.to(device)

    logger.info("Training/evaluation parameters %s", args)

    # Training
    train_dataset = load_and_cache_examples(args, tokenizer)
    global_step, tr_loss = train(args, train_dataset, model, tokenizer)
    logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
Rémi Louf's avatar
Rémi Louf committed
241
242


243
# Run the finetuning only when the file is executed as a script.
if __name__ == "__main__":
    main()