#!/usr/bin/env python
# coding=utf-8
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for sequence to sequence.
"""
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.

import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Optional

import datasets
import numpy as np
from datasets import load_dataset, load_metric

import transformers
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    M2M100Tokenizer,
    MBart50Tokenizer,
    MBart50TokenizerFast,
    MBartTokenizer,
    MBartTokenizerFast,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    default_data_collator,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version


# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.21.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

# Module-level logger; its level is set inside `main()` from the training arguments.
logger = logging.getLogger(__name__)

# A list of all multilingual tokenizers which require the src_lang and tgt_lang attributes.
MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer]


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.

    `model_name_or_path` is the only field without a default and therefore must always be provided.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": (
                "Will use the token generated when running `transformers-cli login` (necessary to use this script "
                "with private models)."
            )
        },
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Construction fails with `ValueError` (see `__post_init__`) unless a dataset name or a
    training/validation file is given together with both language ids.
    """

    source_lang: str = field(default=None, metadata={"help": "Source language id for translation."})
    target_lang: str = field(default=None, metadata={"help": "Target language id for translation."})

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a jsonlines)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "An optional input evaluation data file to evaluate the metrics (sacrebleu) on a jsonlines file."
        },
    )
    test_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input test data file to evaluate the metrics (sacrebleu) on a jsonlines file."},
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_source_length: Optional[int] = field(
        default=1024,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    max_target_length: Optional[int] = field(
        default=128,
        metadata={
            "help": (
                "The maximum total sequence length for target text after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    val_max_target_length: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "The maximum total sequence length for validation target text after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. "
                "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
                "during ``evaluate`` and ``predict``."
            )
        },
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to pad all samples to model maximum sentence length. "
                "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
                "efficient on GPU but very bad for TPU."
            )
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of prediction examples to this "
                "value if set."
            )
        },
    )
    num_beams: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
                "which is used during ``evaluate`` and ``predict``."
            )
        },
    )
    ignore_pad_token_for_loss: bool = field(
        default=True,
        metadata={
            "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
        },
    )
    source_prefix: Optional[str] = field(
        default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
    )
    forced_bos_token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to force as the first generated token after the :obj:`decoder_start_token_id`. Useful for"
                " multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token needs to"
                " be the target language token. (Usually it is the target language token)"
            )
        },
    )

    def __post_init__(self):
        """Validate the argument combination and fill in the derived default.

        Sets `val_max_target_length` to `max_target_length` when it was left unset.

        Raises:
            ValueError: if neither a dataset name nor a training/validation file is given, if the source
                or target language is missing, or if a provided data file does not end in `.json`/`.jsonl`.
        """
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        elif self.source_lang is None or self.target_lang is None:
            raise ValueError("Need to specify the source language and the target language.")

        # accepting both json and jsonl file extensions, as
        # many jsonlines files actually have a .json extension
        valid_extensions = ["json", "jsonl"]

        # Raise instead of `assert` so the validation still runs under `python -O`.
        for arg_name in ("train_file", "validation_file", "test_file"):
            path = getattr(self, arg_name)
            if path is not None and path.split(".")[-1] not in valid_extensions:
                raise ValueError(f"`{arg_name}` should be a jsonlines file.")

        if self.val_max_target_length is None:
            self.val_max_target_length = self.max_target_length



def main():
    """Run the translation fine-tuning pipeline end to end.

    Parses the model/data/training arguments (from a single JSON file or from the command line), loads the
    dataset, config, tokenizer and model, tokenizes the requested splits, then trains, evaluates and/or
    predicts depending on `--do_train`, `--do_eval` and `--do_predict`. Finally it pushes to the hub or
    writes a model card.

    Returns:
        The `results` dict, which is never populated in this function and is therefore empty, or `None`
        when none of `--do_train`, `--do_eval`, `--do_predict` was requested.

    Raises:
        ValueError: on a non-empty output directory without a checkpoint (and no `--overwrite_output_dir`),
            a missing `decoder_start_token_id`, missing language arguments for a multilingual tokenizer,
            or a requested split that is absent from the loaded dataset.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_translation", model_args, data_args)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f", distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    if data_args.source_prefix is None and model_args.model_name_or_path in [
        "t5-small",
        "t5-base",
        "t5-large",
        "t5-3b",
        "t5-11b",
    ]:
        logger.warning(
            "You're running a t5 model but didn't provide a source prefix, which is expected, e.g. with "
            "`--source_prefix 'translate English to German: ' `"
        )

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: you can either provide your own JSON training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For translation, only JSON files are supported, with one field named "translation" containing two keys for the
    # source and target languages (unless you adapt what follows).
    #
    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        # The extension of the last provided file decides which loader is used for all files.
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
            extension = data_args.train_file.split(".")[-1]
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
            extension = data_args.validation_file.split(".")[-1]
        if data_args.test_file is not None:
            data_files["test"] = data_args.test_file
            extension = data_args.test_file.split(".")[-1]
        # `.jsonl` files are accepted by the argument validation, but the loader for them is named "json".
        builder_name = "json" if extension == "jsonl" else extension
        raw_datasets = load_dataset(
            builder_name,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    model.resize_token_embeddings(len(tokenizer))

    # Set decoder_start_token_id
    if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
        if isinstance(tokenizer, MBartTokenizer):
            model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
        else:
            model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang)

    if model.config.decoder_start_token_id is None:
        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

    prefix = data_args.source_prefix if data_args.source_prefix is not None else ""

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    # The columns of the first requested split are the ones removed from every split after tokenization.
    if training_args.do_train:
        column_names = raw_datasets["train"].column_names
    elif training_args.do_eval:
        column_names = raw_datasets["validation"].column_names
    elif training_args.do_predict:
        column_names = raw_datasets["test"].column_names
    else:
        logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
        return

    # For translation we set the codes of our source and target languages (only useful for mBART, the others will
    # ignore those attributes).
    if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)):
        # Raise instead of `assert` so the check still runs under `python -O`.
        if data_args.target_lang is None or data_args.source_lang is None:
            raise ValueError(
                f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --source_lang and "
                "--target_lang arguments."
            )

        tokenizer.src_lang = data_args.source_lang
        tokenizer.tgt_lang = data_args.target_lang

        # For multilingual translation models like mBART-50 and M2M100 we need to force the target language token
        # as the first generated token. We ask the user to explicitly provide this as --forced_bos_token argument.
        forced_bos_token_id = (
            tokenizer.lang_code_to_id[data_args.forced_bos_token] if data_args.forced_bos_token is not None else None
        )
        model.config.forced_bos_token_id = forced_bos_token_id

    # Get the language codes for input/target.
    source_lang = data_args.source_lang.split("_")[0]
    target_lang = data_args.target_lang.split("_")[0]

    # Temporarily set max_target_length for training.
    # `preprocess_function` reads this variable from the enclosing scope, so the reassignments made in the
    # eval/predict sections below change the target truncation length used for those splits.
    max_target_length = data_args.max_target_length
    padding = "max_length" if data_args.pad_to_max_length else False

    if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
        logger.warning(
            "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for "
            f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
        )

    def preprocess_function(examples):
        """Tokenize a batch of `translation` records into model inputs with a `labels` entry."""
        inputs = [ex[source_lang] for ex in examples["translation"]]
        targets = [ex[target_lang] for ex in examples["translation"]]
        inputs = [prefix + inp for inp in inputs]
        model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)

        # Setup the tokenizer for targets
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)

        # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
        # padding in the loss.
        if padding == "max_length" and data_args.ignore_pad_token_for_loss:
            labels["input_ids"] = [
                [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
            ]

        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    if training_args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = raw_datasets["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))
        with training_args.main_process_first(desc="train dataset map pre-processing"):
            train_dataset = train_dataset.map(
                preprocess_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on train dataset",
            )

    if training_args.do_eval:
        max_target_length = data_args.val_max_target_length
        if "validation" not in raw_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = raw_datasets["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))
        with training_args.main_process_first(desc="validation dataset map pre-processing"):
            eval_dataset = eval_dataset.map(
                preprocess_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on validation dataset",
            )

    if training_args.do_predict:
        max_target_length = data_args.val_max_target_length
        if "test" not in raw_datasets:
            raise ValueError("--do_predict requires a test dataset")
        predict_dataset = raw_datasets["test"]
        if data_args.max_predict_samples is not None:
            max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
            predict_dataset = predict_dataset.select(range(max_predict_samples))
        with training_args.main_process_first(desc="prediction dataset map pre-processing"):
            predict_dataset = predict_dataset.map(
                preprocess_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on prediction dataset",
            )

    # Data collator
    label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    if data_args.pad_to_max_length:
        data_collator = default_data_collator
    else:
        data_collator = DataCollatorForSeq2Seq(
            tokenizer,
            model=model,
            label_pad_token_id=label_pad_token_id,
            pad_to_multiple_of=8 if training_args.fp16 else None,
        )

    # Metric
    metric = load_metric("sacrebleu")

    def postprocess_text(preds, labels):
        """Strip whitespace; each label is wrapped in its own single-element list of references."""
        preds = [pred.strip() for pred in preds]
        labels = [[label.strip()] for label in labels]

        return preds, labels

    def compute_metrics(eval_preds):
        """Decode predictions and labels and return rounded `bleu` and `gen_len` values."""
        preds, labels = eval_preds
        if isinstance(preds, tuple):
            preds = preds[0]
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        if data_args.ignore_pad_token_for_loss:
            # Replace -100 in the labels as we can't decode them.
            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Some simple post-processing
        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

        result = metric.compute(predictions=decoded_preds, references=decoded_labels)
        result = {"bleu": result["score"]}

        # Generated length is the number of non-pad token ids in each prediction.
        prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
        result["gen_len"] = np.mean(prediction_lens)
        result = {k: round(v, 4) for k, v in result.items()}
        return result

    # Initialize our Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics if training_args.predict_with_generate else None,
    )

    # Training
    if training_args.do_train:
        # An explicit --resume_from_checkpoint takes precedence over a checkpoint detected in the output dir.
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()  # Saves the tokenizer too for easy upload

        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    results = {}
    max_length = (
        training_args.generation_max_length
        if training_args.generation_max_length is not None
        else data_args.val_max_target_length
    )
    num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval")
        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    if training_args.do_predict:
        logger.info("*** Predict ***")

        predict_results = trainer.predict(
            predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams
        )
        metrics = predict_results.metrics
        max_predict_samples = (
            data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
        )
        metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))

        trainer.log_metrics("predict", metrics)
        trainer.save_metrics("predict", metrics)

        if trainer.is_world_process_zero():
            if training_args.predict_with_generate:
                predictions = tokenizer.batch_decode(
                    predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
                )
                predictions = [pred.strip() for pred in predictions]
                output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
                with open(output_prediction_file, "w", encoding="utf-8") as writer:
                    writer.write("\n".join(predictions))

    # Metadata passed to the model card / hub push below.
    kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "translation"}
    if data_args.dataset_name is not None:
        kwargs["dataset_tags"] = data_args.dataset_name
        if data_args.dataset_config_name is not None:
            kwargs["dataset_args"] = data_args.dataset_config_name
            kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
        else:
            kwargs["dataset"] = data_args.dataset_name

    languages = [l for l in [data_args.source_lang, data_args.target_lang] if l is not None]
    if len(languages) > 0:
        kwargs["language"] = languages

    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)

    return results


def _mp_fn(index):
    """Entry point for xla_spawn (TPUs); `index` is accepted but not used here."""
    main()


# Run the pipeline only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()