# coding=utf-8
# Copyright 2018 The Microsoft Research team and The HuggingFace Inc. team.
# Copyright (c) 2018 Microsoft and The HuggingFace Inc.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning seq2seq models for sequence generation."""

import argparse
import logging
import os
import pickle
import random
from collections import deque

import numpy as np
import torch
from torch.utils.data import Dataset

from transformers import AutoTokenizer, Model2Model


logger = logging.getLogger(__name__)


def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)


# ------------
# Load dataset
# ------------


class TextDataset(Dataset):
    """ Abstracts the dataset used to train seq2seq models.

    CNN/DailyMail:

    The CNN/DailyMail raw datasets are downloaded from [1]. The stories are
    stored in different files; the summary appears at the end of the story as
    sentences that are prefixed by the special `@highlight` line. To process
    the data, untar both datasets in the same folder, and pass the path to this
    folder as the "data_dir" argument. The formatting code was inspired by [2].

    [1] https://cs.nyu.edu/~kcho/
    [2] https://github.com/abisee/cnn-dailymail/
    """

    def __init__(self, tokenizer_src, tokenizer_tgt, data_dir="", block_size=512):
        assert os.path.isdir(data_dir)

        # Load features that have already been computed, if present
        cached_features_file = os.path.join(data_dir, "cached_lm_{}".format(block_size))
        if os.path.exists(cached_features_file):
            logger.info("Loading features from cached file %s", cached_features_file)
            with open(cached_features_file, "rb") as source:
                self.examples = pickle.load(source)
                return

        logger.info("Creating features from dataset at %s", data_dir)
        self.examples = []

        datasets = ["cnn", "dailymail"]
        for dataset in datasets:
            path_to_stories = os.path.join(data_dir, dataset, "stories")
            assert os.path.isdir(path_to_stories)

            story_filenames_list = os.listdir(path_to_stories)
            for story_filename in story_filenames_list:
                path_to_story = os.path.join(path_to_stories, story_filename)
                if not os.path.isfile(path_to_story):
                    continue

                with open(path_to_story, encoding="utf-8") as source:
                    try:
                        raw_story = source.read()
                        story, summary = process_story(raw_story)
                    except IndexError:  # skip ill-formed stories
                        continue

                story = tokenizer_src.convert_tokens_to_ids(tokenizer_src.tokenize(story))
                story_seq = _fit_to_block_size(story, block_size)

                summary = tokenizer_tgt.convert_tokens_to_ids(tokenizer_tgt.tokenize(summary))
                summary_seq = _fit_to_block_size(summary, block_size)

                self.examples.append((story_seq, summary_seq))

        logger.info("Saving features into cache file %s", cached_features_file)
        with open(cached_features_file, "wb") as sink:
            pickle.dump(self.examples, sink, protocol=pickle.HIGHEST_PROTOCOL)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, items):
        return torch.tensor(self.examples[items])
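
# Example usage (illustrative; the data path is a placeholder):
#     >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
#     >>> dataset = TextDataset(tokenizer, tokenizer, data_dir="/path/to/cnn_dailymail")
#     >>> story_ids, summary_ids = dataset[0]  # two tensors of length block_size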


def process_story(raw_story):
    """ Extract the story and summary from a story file.

    Args:
        raw_story (str): content of the story file as an utf-8 encoded string.

    Raises:
        IndexError: If the story is empty or contains no highlights.
    """
    file_lines = list(
        filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")])
    )

    # for some unknown reason some lines miss a period, add it
    file_lines = [_add_missing_period(line) for line in file_lines]

    # gather article lines
    story_lines = []
    lines = deque(file_lines)
    while True:
        try:
            element = lines.popleft()
            if element.startswith("@highlight"):
                break
            story_lines.append(element)
        except IndexError as ie:  # if "@highlight" absent from file
            raise ie

    # gather summary lines
    highlights_lines = list(filter(lambda t: not t.startswith("@highlight"), lines))

    # join the lines
    story = " ".join(story_lines)
    summary = " ".join(highlights_lines)

    return story, summary
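
# Example (illustrative): a raw story contains the article followed by summary
# sentences, each prefixed by an "@highlight" line:
#     >>> process_story("The fox jumped over the fence.\n\n@highlight\n\nFox jumps fence")
#     ('The fox jumped over the fence.', 'Fox jumps fence.')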


def _add_missing_period(line):
    END_TOKENS = [".", "!", "?", "...", "'", "`", '"', u"\u2019", u"\u201d", ")"]
    if line.startswith("@highlight"):
        return line
    if line[-1] in END_TOKENS:
        return line
    return line + "."


def _fit_to_block_size(sequence, block_size):
    """ Adapt the source and target sequences' lengths to the block size.

    If the sequence is shorter than the block size we pad it with -1 ids
    which correspond to padding tokens.
    """
    if len(sequence) > block_size:
        return sequence[:block_size]
    else:
        sequence.extend([-1] * (block_size - len(sequence)))
        return sequence
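
# Example (illustrative):
#     >>> _fit_to_block_size([5, 6], 4)
#     [5, 6, -1, -1]
#     >>> _fit_to_block_size([1, 2, 3, 4, 5], 4)
#     [1, 2, 3, 4]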


def load_and_cache_examples(args, tokenizer_src, tokenizer_tgt):
    dataset = TextDataset(tokenizer_src, tokenizer_tgt, data_dir=args.data_dir)
    return dataset


# ------------
# Train
# ------------


def train(args, train_dataset, model, tokenizer):
    """ Fine-tune the pretrained model on the corpus. """
    raise NotImplementedError
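
# A minimal sketch of what the training loop might eventually look like (illustrative
# only; it is not part of this script). It assumes hypothetical hyper-parameter
# arguments (args.train_batch_size, args.learning_rate, args.num_train_epochs) and
# assumes the Model2Model forward pass accepts the encoder/decoder input ids together
# with `decoder_lm_labels`; check the installed transformers version before relying
# on the exact signature.
#
#     from torch.utils.data import DataLoader
#     train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
#     optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
#     model.train()
#     for _ in range(args.num_train_epochs):
#         for batch in train_dataloader:
#             source, target = batch[:, 0, :], batch[:, 1, :]  # (story ids, summary ids)
#             loss = model(source, target, decoder_lm_labels=target)[0]
#             loss.backward()
#             optimizer.step()
#             optimizer.zero_grad()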


def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data directory containing the CNN/DailyMail stories.",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )

    # Optional parameters
    parser.add_argument(
        "--decoder_name_or_path",
        default="bert-base-cased",
        type=str,
        help="The model checkpoint to initialize the decoder's weights with.",
    )
    parser.add_argument(
        "--decoder_type",
        default="bert",
        type=str,
        help="The decoder architecture to be fine-tuned.",
    )
    parser.add_argument(
        "--encoder_name_or_path",
        default="bert-base-cased",
        type=str,
        help="The model checkpoint to initialize the encoder's weights with.",
    )
    parser.add_argument(
        "--encoder_type",
        default="bert",
        type=str,
        help="The encoder architecture to be fine-tuned.",
    )
    parser.add_argument("--seed", default=42, type=int)
    args = parser.parse_args()

    if args.encoder_type != "bert" or args.decoder_type != "bert":
        raise ValueError("Only the BERT architecture is currently supported for seq2seq.")

    # Set up training device
    # device = torch.device("cpu")

    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    encoder_tokenizer = AutoTokenizer.from_pretrained(args.encoder_name_or_path)
    decoder_tokenizer = AutoTokenizer.from_pretrained(args.decoder_name_or_path)
    model = Model2Model.from_pretrained(args.encoder_name_or_path, args.decoder_name_or_path)
    # model.to(device)

    logger.info("Training/evaluation parameters %s", args)

    # Training
    train_dataset = load_and_cache_examples(args, encoder_tokenizer, decoder_tokenizer)
    # global_step, tr_loss = train(args, train_dataset, model, tokenizer)
    # logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)


if __name__ == "__main__":
    main()