# coding=utf-8
# Copyright 2018 The Microsoft Research team and The HuggingFace Inc. team.
# Copyright (c) 2018 Microsoft and The HuggingFace Inc.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning seq2seq models for sequence generation."""

import argparse
from collections import deque
import logging
import pickle
import random
import os

import numpy as np
from tqdm import tqdm, trange
import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler

from transformers import AdamW, AutoTokenizer, Model2Model, WarmupLinearSchedule

logger = logging.getLogger(__name__)


def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)


# ------------
# Load dataset
# ------------

class TextDataset(Dataset):
    """ Abstracts the dataset used to train seq2seq models.

    CNN/DailyMail:

    The CNN/DailyMail raw datasets are downloaded from [1]. The stories are
    stored in different files; the summary appears at the end of the story as
    sentences that are prefixed by the special `@highlight` line. To process
    the data, untar both datasets in the same folder, and pass the path to this
    folder as the "data_dir" argument. The formatting code was inspired by [2].

    [1] https://cs.nyu.edu/~kcho/
    [2] https://github.com/abisee/cnn-dailymail/
    """

    def __init__(self, tokenizer, prefix="train", data_dir="", block_size=512):
        assert os.path.isdir(data_dir)

        # Load features that have already been computed if present
        cached_features_file = os.path.join(
            data_dir, "cached_lm_{}_{}".format(block_size, prefix)
        )
        if os.path.exists(cached_features_file):
            logger.info("Loading features from cached file %s", cached_features_file)
            with open(cached_features_file, "rb") as source:
                self.examples = pickle.load(source)
                return

        logger.info("Creating features from dataset at %s", data_dir)
        self.examples = []
        datasets = ["cnn", "dailymail"]
        for dataset in datasets:
            path_to_stories = os.path.join(data_dir, dataset, "stories")
            assert os.path.isdir(path_to_stories)

            story_filenames_list = os.listdir(path_to_stories)
            for story_filename in story_filenames_list:
                path_to_story = os.path.join(path_to_stories, story_filename)
                if not os.path.isfile(path_to_story):
                    continue

                with open(path_to_story, encoding="utf-8") as source:
                    try:
                        raw_story = source.read()
                        story, summary = process_story(raw_story)
                    except IndexError:  # skip ill-formed stories
                        continue

                story = tokenizer.encode(story)
                story_seq = _fit_to_block_size(story, block_size)

                summary = tokenizer.encode(summary)
                summary_seq = _fit_to_block_size(summary, block_size)

                self.examples.append((story_seq, summary_seq))

        logger.info("Saving features into cache file %s", cached_features_file)
        with open(cached_features_file, "wb") as sink:
            pickle.dump(self.examples, sink, protocol=pickle.HIGHEST_PROTOCOL)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, items):
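        # Each item is a tensor of shape (2, block_size) that stacks the
        # tokenized story with its tokenized summary.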
        return torch.tensor(self.examples[items])


def process_story(raw_story):
    """ Extract the story and summary from a story file.

    Args:
        raw_story (str): content of the story file as an utf-8 encoded string.

    Raises:
        IndexError: If the story is empty or contains no highlights.
    """
    file_lines = list(
        filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")])
    )

    # some lines in the raw data are missing a final period, add it
    file_lines = [_add_missing_period(line) for line in file_lines]

    # gather article lines
    story_lines = []
    lines = deque(file_lines)
    while True:
        try:
            element = lines.popleft()
            if element.startswith("@highlight"):
                break
            story_lines.append(element)
        except IndexError as ie:  # if "@highlight" absent from file
            raise ie

    # gather summary lines
    highlights_lines = list(filter(lambda t: not t.startswith("@highlight"), lines))

    # join the lines
    story = " ".join(story_lines)
    summary = " ".join(highlights_lines)

    return story, summary
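
# A toy example of what `process_story` extracts (the story text is invented
# purely for illustration):
#
#   >>> raw = "First sentence\nSecond sentence\n@highlight\nA summary line"
#   >>> story, summary = process_story(raw)
#   >>> story
#   'First sentence. Second sentence.'
#   >>> summary
#   'A summary line.'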


def _add_missing_period(line):
    END_TOKENS = [".", "!", "?", "...", "'", "`", '"', u"\u2019", u"\u201d", ")"]
    if line.startswith("@highlight"):
        return line
    if line[-1] in END_TOKENS:
        return line
    return line + "."


def _fit_to_block_size(sequence, block_size):
    """ Adapt the source and target sequences' lengths to the block size.

    If the sequence is shorter than the block size it is padded with the
    padding token id (0).
    """
    if len(sequence) > block_size:
        return sequence[:block_size]
    else:
        sequence.extend([0] * (block_size - len(sequence)))
        return sequence
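
# A quick illustration of the truncation/padding behaviour (0 is the padding
# token id used above):
#
#   >>> _fit_to_block_size([5, 6, 7], block_size=5)
#   [5, 6, 7, 0, 0]
#   >>> _fit_to_block_size([5, 6, 7, 8, 9, 10], block_size=5)
#   [5, 6, 7, 8, 9]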


def mask_padding_tokens(sequence):
    """ Replace the padding token ids (0) with -1 so they are ignored by the loss. """
    if isinstance(sequence, torch.Tensor):
        return sequence.masked_fill(sequence == 0, -1)
    return [s if s != 0 else -1 for s in sequence]
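
# For example, with BERT-style ids where 0 is the padding token:
#
#   >>> mask_padding_tokens([101, 2023, 102, 0, 0])
#   [101, 2023, 102, -1, -1]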


def load_and_cache_examples(args, tokenizer):
    dataset = TextDataset(tokenizer, data_dir=args.data_dir)
    return dataset
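
# Sketch of standalone usage (the "./data" path is a placeholder for a folder
# that contains the untarred `cnn` and `dailymail` story directories described
# in TextDataset's docstring):
#
#   >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
#   >>> dataset = TextDataset(tokenizer, data_dir="./data", block_size=512)
#   >>> len(dataset)        # number of (story, summary) pairs
#   >>> dataset[0].shape    # torch.Size([2, 512])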


# ------------
# Train
# ------------


def train(args, train_dataset, model, tokenizer):
    """ Fine-tune the pretrained model on the corpus. """

    # Prepare the data loading
    args.train_batch_size = 1
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset, sampler=train_sampler, batch_size=args.train_batch_size
    )

    # Total number of optimization steps, needed by the linear warmup schedule
    if args.max_steps > 0:
        t_total = args.max_steps
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare the optimizer and schedule (linear warmup and decay)
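    # Biases and LayerNorm weights are excluded from weight decay, as is common
    # practice when fine-tuning BERT-style models.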
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(
        optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon
    )
    scheduler = WarmupLinearSchedule(
        optimizer, warmup_steps=args.warmup_steps, t_total=t_total
    )

    # Train
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Batch size = %d", args.train_batch_size)
    logger.info(
        "  Total train batch size (w. accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps,
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(args.num_train_epochs, desc="Epoch", disable=True)
    set_seed(args)
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=True)
        for step, batch in enumerate(epoch_iterator):
            # The default collate function stacks each (story, summary) pair into a
            # tensor of shape (batch_size, 2, block_size); split it back into
            # source and target batches.
            source = torch.stack([s for s, _ in batch]).to(args.device)
            target = torch.stack([t for _, t in batch]).to(args.device)
            model.train()
            outputs = model(source, target, decoder_lm_labels=mask_padding_tokens(target))
            loss = outputs[0]
            loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()
                model.zero_grad()
                global_step += 1

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    return global_step, tr_loss / global_step


def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input training data file (a text file).",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )

    # Optional parameters
    parser.add_argument(
        "--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer."
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        default=1,
        type=int,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--model_name_or_path",
        default="bert-base-cased",
        type=str,
        help="The model checkpoint to initialize the encoder and decoder's weights with.",
    )
    parser.add_argument(
        "--model_type",
        default="bert",
        type=str,
        help="The decoder architecture to be fine-tuned.",
    )
    parser.add_argument(
        "--learning_rate",
        default=5e-5,
        type=float,
        help="The initial learning rate for Adam.",
    )
    parser.add_argument(
        "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
    )
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
    )
    parser.add_argument(
        "--num_train_epochs",
        default=1,
        type=int,
        help="Total number of training epochs to perform.",
    )
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument(
        "--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps."
    )
    parser.add_argument(
        "--weight_decay", default=0.0, type=float, help="Weight decay if we apply some."
    )
    args = parser.parse_args()

    if args.model_type != "bert":
        raise ValueError(
            "Only the BERT architecture is currently supported for seq2seq."
        )

    # Set up training device
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
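    # Model2Model ties an encoder and a decoder that are both initialized from
    # the same pretrained checkpoint.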
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    model = Model2Model.from_pretrained(args.model_name_or_path)
    model.to(args.device)

    logger.info("Training/evaluation parameters %s", args)

    # Training
    train_dataset = load_and_cache_examples(args, tokenizer)
    global_step, tr_loss = train(args, train_dataset, model, tokenizer)
    logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)


if __name__ == "__main__":
    main()