#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for sequence to sequence summarization.
"""
# You can also adapt this script to your own sequence to sequence task. Pointers for this are left as comments.

import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Optional

27
import datasets
28
import evaluate
29
import nltk  # Here to have a nice missing dependency error message early on
30
import numpy as np
31
from datasets import load_dataset
32
from filelock import FileLock
33
34
35
36
37
38
39
40

import transformers
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
41
42
43
44
    MBart50Tokenizer,
    MBart50TokenizerFast,
    MBartTokenizer,
    MBartTokenizerFast,
45
46
47
48
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    set_seed,
)
49
from transformers.trainer_utils import get_last_checkpoint
50
from transformers.utils import check_min_version, is_offline_mode, send_example_telemetry
51
from transformers.utils.versions import require_version
52
53


54
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.42.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

logger = logging.getLogger(__name__)

# `nltk.sent_tokenize` is used when post-processing generated and reference summaries (rougeLSum expects one
# sentence per line), and it needs the "punkt" tokenizer data. Fetch that data once if it is not available locally.
try:
    nltk.data.find("tokenizers/punkt")
except (LookupError, OSError):
    if is_offline_mode():
        raise LookupError(
            "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
        )
    # Guard the download with a file lock so that processes started at the same time do not download concurrently.
    # The lock object itself is not needed, so it is not bound to a name.
    with FileLock(".lock"):
        nltk.download("punkt", quiet=True)

# All multilingual tokenizers which require the `lang` attribute (set from `--lang` in main()).
MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast]


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.

    Parsed by `HfArgumentParser` in main() together with `DataTrainingArguments` and `Seq2SeqTrainingArguments`.
    Only `model_name_or_path` is required; every other field has a default.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    # Forwarded as `token=` to `load_dataset` and to the `from_pretrained` calls in main(). The default is None,
    # so the annotation must be Optional[str] (it was previously declared as plain `str`).
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to trust the execution of code from datasets/models defined on the Hub."
                " This option should only be set to `True` for repositories you trust and in which you have read the"
                " code, as it will execute code present on the Hub on your local machine."
            )
        },
    )
    # Tri-state, consulted in main() only when `max_source_length` exceeds `max_position_embeddings`:
    # None -> resize with a warning, True -> resize, False -> raise a ValueError.
    resize_position_embeddings: Optional[bool] = field(
        default=None,
        metadata={
            "help": (
                "Whether to automatically resize the position embeddings if `max_source_length` exceeds "
                "the model's position embeddings."
            )
        },
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    At least one of `dataset_name`, `train_file`, `validation_file` or `test_file` must be given, and any local
    data file must have a `.csv` or `.json` extension. `val_max_target_length` falls back to `max_target_length`.
    """

    # Used by main() for MBart / MBart-50 tokenizers: sets the source/target language and, for MBart, the
    # decoder start token.
    lang: Optional[str] = field(default=None, metadata={"help": "Language id for summarization."})

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    text_column: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
    )
    summary_column: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
    )
    train_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
            )
        },
    )
    test_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_source_length: Optional[int] = field(
        default=1024,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    max_target_length: Optional[int] = field(
        default=128,
        metadata={
            "help": (
                "The maximum total sequence length for target text after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    val_max_target_length: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "The maximum total sequence length for validation target text after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. "
                "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
                "during ``evaluate`` and ``predict``."
            )
        },
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to pad all samples to model maximum sentence length. "
                "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
                "efficient on GPU but very bad for TPU."
            )
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of prediction examples to this "
                "value if set."
            )
        },
    )
    num_beams: Optional[int] = field(
        default=1,
        metadata={
            "help": (
                "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
                "which is used during ``evaluate`` and ``predict``."
            )
        },
    )
    ignore_pad_token_for_loss: bool = field(
        default=True,
        metadata={
            "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
        },
    )
    source_prefix: Optional[str] = field(
        default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
    )

    forced_bos_token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to force as the first generated token after the decoder_start_token_id. "
                # Trailing space restored: the concatenated help text previously read "...tokenneeds to be...".
                "Useful for multilingual models like mBART where the first generated token "
                "needs to be the target language token (Usually it is the target language token)"
            )
        },
    )

    def __post_init__(self):
        """Validate the data sources and fill in derived defaults.

        Raises:
            ValueError: if neither a dataset name nor any data file is given, or if a given data file does not
                end in a ``csv`` or ``json`` extension.
        """
        if (
            self.dataset_name is None
            and self.train_file is None
            and self.validation_file is None
            and self.test_file is None
        ):
            raise ValueError("Need either a dataset name or a training, validation, or test file.")

        # main() passes the file extension of local files to `load_dataset` as the loader name, so only the
        # supported formats are accepted. Raise (instead of `assert`) so the check also runs under `python -O`.
        for arg_name in ("train_file", "validation_file", "test_file"):
            path = getattr(self, arg_name)
            if path is not None and path.split(".")[-1] not in ["csv", "json"]:
                raise ValueError(f"`{arg_name}` should be a csv or a json file.")

        if self.val_max_target_length is None:
            self.val_max_target_length = self.max_target_length


# Known hub datasets -> (text column, summary column). main() falls back to these when `--text_column` /
# `--summary_column` are not given and the dataset name is listed here.
summarization_name_mapping = dict(
    amazon_reviews_multi=("review_body", "review_title"),
    big_patent=("description", "abstract"),
    cnn_dailymail=("article", "highlights"),
    orange_sum=("text", "summary"),
    pn_summary=("article", "summary"),
    psc=("extract_text", "summary_text"),
    samsum=("dialogue", "summary"),
    thaisum=("body", "summary"),
    xglue=("news_body", "news_title"),
    xsum=("document", "summary"),
    wiki_summary=("article", "highlights"),
    multi_news=("document", "summary"),
)


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

325
326
327
328
    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_summarization", model_args, data_args)

329
330
    # Setup logging
    logging.basicConfig(
331
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
332
333
334
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
335
336
337
338
339

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

340
341
342
343
344
345
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()
346
347
348

    # Log on each process the small summary:
    logger.warning(
349
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
350
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
351
352
353
    )
    logger.info(f"Training/evaluation parameters {training_args}")

354
    if data_args.source_prefix is None and model_args.model_name_or_path in [
355
356
357
358
359
        "google-t5/t5-small",
        "google-t5/t5-base",
        "google-t5/t5-large",
        "google-t5/t5-3b",
        "google-t5/t5-11b",
360
361
362
363
364
365
    ]:
        logger.warning(
            "You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with "
            "`--source_prefix 'summarize: ' `"
        )

366
367
368
369
370
371
372
373
374
    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
375
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
376
377
378
379
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )
380
381
382
383
384
385
386
387

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
388
389
    # For CSV/JSON files this script will use the first column for the full texts and the second column for the
    # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments).
390
391
392
393
394
    #
    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
395
        raw_datasets = load_dataset(
396
397
398
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
399
            token=model_args.token,
400
            trust_remote_code=model_args.trust_remote_code,
401
        )
402
403
404
405
406
407
408
409
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
            extension = data_args.train_file.split(".")[-1]
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
            extension = data_args.validation_file.split(".")[-1]
410
411
412
        if data_args.test_file is not None:
            data_files["test"] = data_args.test_file
            extension = data_args.test_file.split(".")[-1]
413
414
415
416
        raw_datasets = load_dataset(
            extension,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
417
            token=model_args.token,
418
        )
419
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
420
    # https://huggingface.co/docs/datasets/loading_datasets.
421
422
423
424
425
426
427
428
429
430

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
431
        token=model_args.token,
432
        trust_remote_code=model_args.trust_remote_code,
433
434
435
436
437
438
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
439
        token=model_args.token,
440
        trust_remote_code=model_args.trust_remote_code,
441
442
443
444
445
446
447
    )
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
448
        token=model_args.token,
449
        trust_remote_code=model_args.trust_remote_code,
450
451
    )

452
453
454
455
456
    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
    # on a small vocab and want a smaller embedding size, remove this test.
    embedding_size = model.get_input_embeddings().weight.shape[0]
    if len(tokenizer) > embedding_size:
        model.resize_token_embeddings(len(tokenizer))
Suraj Patil's avatar
Suraj Patil committed
457

458
459
460
461
462
463
    if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
        if isinstance(tokenizer, MBartTokenizer):
            model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.lang]
        else:
            model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.lang)

464
465
466
    if model.config.decoder_start_token_id is None:
        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

467
468
469
470
471
472
    if (
        hasattr(model.config, "max_position_embeddings")
        and model.config.max_position_embeddings < data_args.max_source_length
    ):
        if model_args.resize_position_embeddings is None:
            logger.warning(
Sylvain Gugger's avatar
Sylvain Gugger committed
473
474
                "Increasing the model's number of position embedding vectors from"
                f" {model.config.max_position_embeddings} to {data_args.max_source_length}."
475
476
477
478
479
480
            )
            model.resize_position_embeddings(data_args.max_source_length)
        elif model_args.resize_position_embeddings:
            model.resize_position_embeddings(data_args.max_source_length)
        else:
            raise ValueError(
Sylvain Gugger's avatar
Sylvain Gugger committed
481
482
483
484
                f"`--max_source_length` is set to {data_args.max_source_length}, but the model only has"
                f" {model.config.max_position_embeddings} position encodings. Consider either reducing"
                f" `--max_source_length` to {model.config.max_position_embeddings} or to automatically resize the"
                " model's position encodings by passing `--resize_position_embeddings`."
485
486
            )

487
    prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
488

489
490
491
    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    if training_args.do_train:
492
493
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
494
        column_names = raw_datasets["train"].column_names
495
    elif training_args.do_eval:
496
497
        if "validation" not in raw_datasets:
            raise ValueError("--do_eval requires a validation dataset")
498
        column_names = raw_datasets["validation"].column_names
499
    elif training_args.do_predict:
500
501
        if "test" not in raw_datasets:
            raise ValueError("--do_predict requires a test dataset")
502
        column_names = raw_datasets["test"].column_names
503
504
505
    else:
        logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
        return
506

507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
    if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)):
        assert (
            data_args.lang is not None
        ), f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --lang argument"

        tokenizer.src_lang = data_args.lang
        tokenizer.tgt_lang = data_args.lang

        # For multilingual translation models like mBART-50 and M2M100 we need to force the target language token
        # as the first generated token. We ask the user to explicitly provide this as --forced_bos_token argument.
        forced_bos_token_id = (
            tokenizer.lang_code_to_id[data_args.forced_bos_token] if data_args.forced_bos_token is not None else None
        )
        model.config.forced_bos_token_id = forced_bos_token_id

522
523
524
525
    # Get the column names for input/target.
    dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
    if data_args.text_column is None:
        text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
526
    else:
527
528
529
530
531
532
533
534
535
536
537
538
539
        text_column = data_args.text_column
        if text_column not in column_names:
            raise ValueError(
                f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
            )
    if data_args.summary_column is None:
        summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
    else:
        summary_column = data_args.summary_column
        if summary_column not in column_names:
            raise ValueError(
                f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
            )
540
541
542
543
544

    # Temporarily set max_target_length for training.
    max_target_length = data_args.max_target_length
    padding = "max_length" if data_args.pad_to_max_length else False

545
    if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
546
        logger.warning(
547
            "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for "
548
549
550
            f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
        )

551
    def preprocess_function(examples):
        """Tokenize a batch of source texts and their reference summaries.

        Used with ``Dataset.map(batched=True)``, so ``examples`` maps column names to lists of values. Rows where
        either the text or the summary is falsy (``None`` or an empty string) are dropped, so the returned batch can
        be shorter than the input batch.

        Returns the tokenizer output for the prefixed source texts, with an added ``"labels"`` entry holding the
        token ids of the targets.
        """
        # Keep only pairs where both sides are present (this drops None as well as empty strings).
        pairs = [
            (text, summary)
            for text, summary in zip(examples[text_column], examples[summary_column])
            if text and summary
        ]
        inputs = [prefix + text for text, _ in pairs]
        targets = [summary for _, summary in pairs]

        model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)

        # Tokenize targets with the `text_target` keyword argument. `max_target_length` comes from the enclosing
        # scope, so it reflects the value assigned before the corresponding `.map(...)` call.
        labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True)

        # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
        # padding in the loss.
        if padding == "max_length" and data_args.ignore_pad_token_for_loss:
            labels["input_ids"] = [
                [(tok if tok != tokenizer.pad_token_id else -100) for tok in label] for label in labels["input_ids"]
            ]

        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    if training_args.do_train:
577
        train_dataset = raw_datasets["train"]
578
        if data_args.max_train_samples is not None:
579
580
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))
581
582
583
584
585
586
587
588
589
        with training_args.main_process_first(desc="train dataset map pre-processing"):
            train_dataset = train_dataset.map(
                preprocess_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on train dataset",
            )
590
591
592

    if training_args.do_eval:
        max_target_length = data_args.val_max_target_length
593
        eval_dataset = raw_datasets["validation"]
594
        if data_args.max_eval_samples is not None:
595
596
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))
597
598
599
600
601
602
603
604
605
        with training_args.main_process_first(desc="validation dataset map pre-processing"):
            eval_dataset = eval_dataset.map(
                preprocess_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on validation dataset",
            )
606

607
608
    if training_args.do_predict:
        max_target_length = data_args.val_max_target_length
609
        predict_dataset = raw_datasets["test"]
610
        if data_args.max_predict_samples is not None:
611
612
            max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
            predict_dataset = predict_dataset.select(range(max_predict_samples))
613
614
615
616
617
618
619
620
621
        with training_args.main_process_first(desc="prediction dataset map pre-processing"):
            predict_dataset = predict_dataset.map(
                preprocess_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on prediction dataset",
            )
622

623
624
    # Data collator
    label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
625
626
627
628
629
630
    data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        label_pad_token_id=label_pad_token_id,
        pad_to_multiple_of=8 if training_args.fp16 else None,
    )
631
632

    # Metric
633
    metric = evaluate.load("rouge", cache_dir=model_args.cache_dir)
634

635
636
637
638
639
    def postprocess_text(preds, labels):
        """Normalize decoded predictions and references before scoring them with ROUGE.

        Each string is stripped of surrounding whitespace and re-joined with one sentence per line (sentences
        come from ``nltk.sent_tokenize``), because rougeLSum expects a newline after each sentence.

        Returns a ``(preds, labels)`` tuple of two lists of strings.
        """

        def one_sentence_per_line(text):
            return "\n".join(nltk.sent_tokenize(text.strip()))

        return (
            [one_sentence_per_line(pred) for pred in preds],
            [one_sentence_per_line(label) for label in labels],
        )

645
646
647
648
    def compute_metrics(eval_preds):
        """Score one evaluation pass: ROUGE values scaled to percentages plus the mean generated length.

        ``eval_preds`` is unpacked as ``(predictions, label_ids)``. If the predictions arrive as a tuple, only its
        first element is used. The returned dict holds every ROUGE key rounded to 4 decimals and ``"gen_len"``.
        """
        predictions, label_ids = eval_preds
        if isinstance(predictions, tuple):
            predictions = predictions[0]

        pad_id = tokenizer.pad_token_id
        # -100 marks padded/ignored positions and cannot be decoded, so swap it for the pad id first.
        predictions = np.where(predictions != -100, predictions, pad_id)
        label_ids = np.where(label_ids != -100, label_ids, pad_id)
        decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

        # Some simple post-processing
        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

        scores = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
        result = {name: round(value * 100, 4) for name, value in scores.items()}
        # Generated length = number of non-pad tokens in each prediction row, averaged over the batch.
        result["gen_len"] = np.mean([np.count_nonzero(row != pad_id) for row in predictions])
        return result

664
    # Override the decoding parameters of Seq2SeqTrainer
665
666
667
668
669
670
671
672
    training_args.generation_max_length = (
        training_args.generation_max_length
        if training_args.generation_max_length is not None
        else data_args.val_max_target_length
    )
    training_args.generation_num_beams = (
        data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
    )
673

674
675
676
677
678
679
680
681
682
683
684
685
686
    # Initialize our Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics if training_args.predict_with_generate else None,
    )

    # Training
    if training_args.do_train:
687
688
689
690
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
691
692
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
693
694
        trainer.save_model()  # Saves the tokenizer too for easy upload

695
696
697
698
699
        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))
700

701
702
703
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()
704
705

    # Evaluation
706
    results = {}
707
708
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
709
710
711
712
713
714
715
        if isinstance(eval_dataset, dict):
            metrics = {}
            for eval_ds_name, eval_ds in eval_dataset.items():
                dataset_metrics = trainer.evaluate(eval_dataset=eval_ds, metric_key_prefix=f"eval_{eval_ds_name}")
                metrics.update(dataset_metrics)
        else:
            metrics = trainer.evaluate(metric_key_prefix="eval")
716
717
        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
718

719
720
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)
721

722
    if training_args.do_predict:
723
        logger.info("*** Predict ***")
724

725
        predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict")
726
727
728
729
730
        metrics = predict_results.metrics
        max_predict_samples = (
            data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
        )
        metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
731

732
733
        trainer.log_metrics("predict", metrics)
        trainer.save_metrics("predict", metrics)
734

735
        if trainer.is_world_process_zero():
736
            if training_args.predict_with_generate:
737
738
                predictions = predict_results.predictions
                predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
739
                predictions = tokenizer.batch_decode(
740
                    predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
741
                )
742
743
744
745
                predictions = [pred.strip() for pred in predictions]
                output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
                with open(output_prediction_file, "w") as writer:
                    writer.write("\n".join(predictions))
746

747
748
749
750
751
752
753
754
    kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "summarization"}
    if data_args.dataset_name is not None:
        kwargs["dataset_tags"] = data_args.dataset_name
        if data_args.dataset_config_name is not None:
            kwargs["dataset_args"] = data_args.dataset_config_name
            kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
        else:
            kwargs["dataset"] = data_args.dataset_name
Sylvain Gugger's avatar
Sylvain Gugger committed
755

756
757
758
    if data_args.lang is not None:
        kwargs["language"] = data_args.lang

759
    if training_args.push_to_hub:
Sylvain Gugger's avatar
Sylvain Gugger committed
760
        trainer.push_to_hub(**kwargs)
761
762
    else:
        trainer.create_model_card(**kwargs)
Sylvain Gugger's avatar
Sylvain Gugger committed
763

764
765
    return results

766
767
768
769
770
771
772
773

def _mp_fn(index):
    """Per-process entry point for ``xla_spawn`` (TPUs).

    ``index`` is accepted but not used; each spawned process simply runs ``main()``.
    """
    main()


if __name__ == "__main__":
    # Invoke main() only when this file is executed as a script, not when it is imported.
    main()