#!/usr/bin/env python
# coding=utf-8
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for sequence to sequence.
"""
# You can also adapt this script to your own sequence-to-sequence task. Pointers for this are left as comments.

import logging
import os
import sys
import warnings
from dataclasses import dataclass, field
from typing import Optional

import datasets
import evaluate
import numpy as np
from datasets import load_dataset

import transformers
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    M2M100Tokenizer,
    MBart50Tokenizer,
    MBart50TokenizerFast,
    MBartTokenizer,
    MBartTokenizerFast,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    default_data_collator,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version


# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.36.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

logger = logging.getLogger(__name__)

# All multilingual tokenizers which require the `src_lang` and `tgt_lang` attributes to be set.
# Kept as a list (not a tuple) so external code that extends it keeps working; `main` converts it with
# `tuple(...)` before passing it to `isinstance`.
MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer]


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.

    Parsed by ``HfArgumentParser``; each field's ``metadata["help"]`` becomes its ``--help`` text.
    """

    # Required: no default is given, so the parser makes it a mandatory argument.
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    # Defaults to None, hence Optional (the previous `str` annotation did not admit the default).
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    # Deprecated alias; `main` warns and copies it into `token` when it is set.
    use_auth_token: Optional[bool] = field(
        default=None,
        metadata={
            "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                # Trailing space added: the concatenated help text used to read "This optionshould".
                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
                "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
                "execute code present on the Hub on your local machine."
            )
        },
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    ``__post_init__`` validates the combination of data sources and fills ``val_max_target_length``
    from ``max_target_length`` when it is not given.
    """

    # Both are required in practice: `__post_init__` raises if either is missing.
    source_lang: Optional[str] = field(default=None, metadata={"help": "Source language id for translation."})
    target_lang: Optional[str] = field(default=None, metadata={"help": "Target language id for translation."})

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a jsonlines)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "An optional input evaluation data file to evaluate the metrics (sacrebleu) on a jsonlines file."
        },
    )
    test_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input test data file to evaluate the metrics (sacrebleu) on a jsonlines file."},
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_source_length: Optional[int] = field(
        default=1024,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    max_target_length: Optional[int] = field(
        default=128,
        metadata={
            "help": (
                "The maximum total sequence length for target text after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    # Filled from `max_target_length` in `__post_init__` when left as None.
    val_max_target_length: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "The maximum total sequence length for validation target text after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. "
                "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
                "during ``evaluate`` and ``predict``."
            )
        },
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to pad all samples to model maximum sentence length. "
                "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
                "efficient on GPU but very bad for TPU."
            )
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of prediction examples to this "
                "value if set."
            )
        },
    )
    num_beams: Optional[int] = field(
        default=1,
        metadata={
            "help": (
                "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
                "which is used during ``evaluate`` and ``predict``."
            )
        },
    )
    ignore_pad_token_for_loss: bool = field(
        default=True,
        metadata={
            "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
        },
    )
    source_prefix: Optional[str] = field(
        default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
    )
    forced_bos_token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to force as the first generated token after the :obj:`decoder_start_token_id`. Useful for"
                " multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token needs to"
                " be the target language token.(Usually it is the target language token)"
            )
        },
    )

    def __post_init__(self):
        """Validate data sources, languages and file extensions; default ``val_max_target_length``.

        Raises:
            ValueError: if no dataset name and no train/validation file is given, if either language is
                missing, or if any given data file does not have a ``json``/``jsonl`` extension.
        """
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        if self.source_lang is None or self.target_lang is None:
            raise ValueError("Need to specify the source language and the target language.")

        # accepting both json and jsonl file extensions, as
        # many jsonlines files actually have a .json extension
        valid_extensions = ["json", "jsonl"]

        # Raise instead of `assert`: assertions are stripped under `python -O`, which would silently skip
        # validation. `test_file` is checked too, since `main` uses its extension to pick the loader.
        for arg_name, path in (
            ("train_file", self.train_file),
            ("validation_file", self.validation_file),
            ("test_file", self.test_file),
        ):
            if path is not None and path.split(".")[-1] not in valid_extensions:
                raise ValueError(f"`{arg_name}` should be a jsonlines file.")

        if self.val_max_target_length is None:
            self.val_max_target_length = self.max_target_length


def main():
    """Fine-tune, evaluate and/or run prediction with a sequence-to-sequence translation model.

    Arguments are read from a single ``.json`` file when that is the only CLI argument, otherwise from the
    command line, into ``ModelArguments``, ``DataTrainingArguments`` and ``Seq2SeqTrainingArguments``.

    Returns:
        The ``results`` dict (never populated here; metrics are logged and saved through the trainer), or
        ``None`` when none of ``do_train`` / ``do_eval`` / ``do_predict`` is set.

    Raises:
        ValueError: if both ``token`` and ``use_auth_token`` are given, if the output directory is non-empty
            without a checkpoint, if a requested split is missing from the loaded data, if a multilingual
            tokenizer lacks source/target languages, or if ``decoder_start_token_id`` cannot be determined.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if model_args.use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.",
            FutureWarning,
        )
        if model_args.token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        model_args.token = model_args.use_auth_token

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_translation", model_args, data_args)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    if data_args.source_prefix is None and model_args.model_name_or_path in [
        "t5-small",
        "t5-base",
        "t5-large",
        "t5-3b",
        "t5-11b",
    ]:
        logger.warning(
            "You're running a t5 model but didn't provide a source prefix, which is expected, e.g. with "
            "`--source_prefix 'translate English to German: ' `"
        )

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: you can either provide your own JSON training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For translation, only JSON files are supported, with one field named "translation" containing two keys for the
    # source and target languages (unless you adapt what follows).
    #
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
        )
    else:
        data_files = {}
        # `DataTrainingArguments.__post_init__` guarantees at least a train or validation file here, so
        # `extension` is always bound; the last file given decides which loader is used.
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
            extension = data_args.train_file.split(".")[-1]
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
            extension = data_args.validation_file.split(".")[-1]
        if data_args.test_file is not None:
            data_files["test"] = data_args.test_file
            extension = data_args.test_file.split(".")[-1]
        raw_datasets = load_dataset(
            extension,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
        )
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading.

    # Fail fast, before any model download, when a requested split is missing. Previously the column-name
    # lookup below hit a bare KeyError before the explicit message could be raised.
    for enabled, split, flag in (
        (training_args.do_train, "train", "--do_train"),
        (training_args.do_eval, "validation", "--do_eval"),
        (training_args.do_predict, "test", "--do_predict"),
    ):
        if enabled and split not in raw_datasets:
            raise ValueError(f"{flag} requires a {split} dataset")

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=model_args.token,
        trust_remote_code=model_args.trust_remote_code,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        token=model_args.token,
        trust_remote_code=model_args.trust_remote_code,
    )
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=model_args.token,
        trust_remote_code=model_args.trust_remote_code,
    )

    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
    # on a small vocab and want a smaller embedding size, remove this test.
    embedding_size = model.get_input_embeddings().weight.shape[0]
    if len(tokenizer) > embedding_size:
        model.resize_token_embeddings(len(tokenizer))

    # Set decoder_start_token_id
    if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
        if isinstance(tokenizer, MBartTokenizer):
            model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
        else:
            model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang)

    if model.config.decoder_start_token_id is None:
        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

    prefix = data_args.source_prefix if data_args.source_prefix is not None else ""

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets. Split presence was already checked above.
    if training_args.do_train:
        column_names = raw_datasets["train"].column_names
    elif training_args.do_eval:
        column_names = raw_datasets["validation"].column_names
    elif training_args.do_predict:
        column_names = raw_datasets["test"].column_names
    else:
        logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
        return

    # For translation we set the codes of our source and target languages (only useful for mBART, the others will
    # ignore those attributes).
    if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)):
        # Raise rather than assert so the check survives `python -O`.
        if data_args.target_lang is None or data_args.source_lang is None:
            raise ValueError(
                f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --source_lang and "
                "--target_lang arguments."
            )

        tokenizer.src_lang = data_args.source_lang
        tokenizer.tgt_lang = data_args.target_lang

        # For multilingual translation models like mBART-50 and M2M100 we need to force the target language token
        # as the first generated token. We ask the user to explicitly provide this as --forced_bos_token argument.
        forced_bos_token_id = (
            tokenizer.lang_code_to_id[data_args.forced_bos_token] if data_args.forced_bos_token is not None else None
        )
        model.config.forced_bos_token_id = forced_bos_token_id

    # Get the language codes for input/target: the key inside each "translation" record is the part of the
    # language id before the first "_" (e.g. "en_XX" -> "en").
    source_lang = data_args.source_lang.split("_")[0]
    target_lang = data_args.target_lang.split("_")[0]

    # Temporarily set max_target_length for training. `preprocess_function` reads this local through its
    # closure, so it is reassigned to `val_max_target_length` before the eval/predict splits are mapped.
    max_target_length = data_args.max_target_length
    padding = "max_length" if data_args.pad_to_max_length else False

    if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
        logger.warning(
            "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for "
            f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
        )

    def preprocess_function(examples):
        """Tokenize a batch of ``translation`` records into model inputs with a ``labels`` column."""
        inputs = [ex[source_lang] for ex in examples["translation"]]
        targets = [ex[target_lang] for ex in examples["translation"]]
        inputs = [prefix + inp for inp in inputs]
        model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)

        # Tokenize targets with the `text_target` keyword argument
        labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True)

        # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
        # padding in the loss.
        if padding == "max_length" and data_args.ignore_pad_token_for_loss:
            labels["input_ids"] = [
                [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
            ]

        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    if training_args.do_train:
        train_dataset = raw_datasets["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))
        with training_args.main_process_first(desc="train dataset map pre-processing"):
            train_dataset = train_dataset.map(
                preprocess_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on train dataset",
            )

    if training_args.do_eval:
        max_target_length = data_args.val_max_target_length
        eval_dataset = raw_datasets["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))
        with training_args.main_process_first(desc="validation dataset map pre-processing"):
            eval_dataset = eval_dataset.map(
                preprocess_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on validation dataset",
            )

    if training_args.do_predict:
        max_target_length = data_args.val_max_target_length
        predict_dataset = raw_datasets["test"]
        if data_args.max_predict_samples is not None:
            max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
            predict_dataset = predict_dataset.select(range(max_predict_samples))
        with training_args.main_process_first(desc="prediction dataset map pre-processing"):
            predict_dataset = predict_dataset.map(
                preprocess_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on prediction dataset",
            )

    # Data collator
    label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    if data_args.pad_to_max_length:
        data_collator = default_data_collator
    else:
        data_collator = DataCollatorForSeq2Seq(
            tokenizer,
            model=model,
            label_pad_token_id=label_pad_token_id,
            pad_to_multiple_of=8 if training_args.fp16 else None,
        )

    # Metric
    metric = evaluate.load("sacrebleu")

    def postprocess_text(preds, labels):
        """Strip whitespace; wrap each label in a list (one reference per prediction) for sacrebleu."""
        preds = [pred.strip() for pred in preds]
        labels = [[label.strip()] for label in labels]

        return preds, labels

    def compute_metrics(eval_preds):
        """Decode generated ids and labels; return rounded BLEU and mean generated length."""
        preds, labels = eval_preds
        if isinstance(preds, tuple):
            preds = preds[0]
        # Replace -100s used for padding as we can't decode them
        preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Some simple post-processing
        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

        result = metric.compute(predictions=decoded_preds, references=decoded_labels)
        result = {"bleu": result["score"]}

        prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
        result["gen_len"] = np.mean(prediction_lens)
        result = {k: round(v, 4) for k, v in result.items()}
        return result

    # Initialize our Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics if training_args.predict_with_generate else None,
    )

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()  # Saves the tokenizer too for easy upload

        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    results = {}
    max_length = (
        training_args.generation_max_length
        if training_args.generation_max_length is not None
        else data_args.val_max_target_length
    )
    num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval")
        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    if training_args.do_predict:
        logger.info("*** Predict ***")

        predict_results = trainer.predict(
            predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams
        )
        metrics = predict_results.metrics
        max_predict_samples = (
            data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
        )
        metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))

        trainer.log_metrics("predict", metrics)
        trainer.save_metrics("predict", metrics)

        if trainer.is_world_process_zero():
            if training_args.predict_with_generate:
                predictions = predict_results.predictions
                predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
                predictions = tokenizer.batch_decode(
                    predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
                )
                predictions = [pred.strip() for pred in predictions]
                output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
                with open(output_prediction_file, "w", encoding="utf-8") as writer:
                    writer.write("\n".join(predictions))

    kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "translation"}
    if data_args.dataset_name is not None:
        kwargs["dataset_tags"] = data_args.dataset_name
        if data_args.dataset_config_name is not None:
            kwargs["dataset_args"] = data_args.dataset_config_name
            kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
        else:
            kwargs["dataset"] = data_args.dataset_name

    languages = [l for l in [data_args.source_lang, data_args.target_lang] if l is not None]
    if len(languages) > 0:
        kwargs["language"] = languages

    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)

    return results


def _mp_fn(index):
    """Per-process entry point for ``xla_spawn`` (TPU training).

    ``index`` is supplied by the spawner (presumably the process ordinal) and is not used:
    every process simply runs ``main()``.
    """
    del index  # intentionally unused
    main()


if __name__ == "__main__":
    # Direct command-line invocation of the script.
    main()