#!/usr/bin/env python
# coding=utf-8
# Copyright 2020 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) on a text file or a dataset.

Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
20
https://huggingface.co/models?filter=fill-mask
21
22
23
24
25
26
27
"""
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.

import logging
import math
import os
import sys
28
import warnings
29
from dataclasses import dataclass, field
30
from itertools import chain
31
32
from typing import Optional

33
import datasets
34
import evaluate
35
from datasets import load_dataset
36
37
38
39
40
41
42
43
44
45
46
47

import transformers
from transformers import (
    CONFIG_MAPPING,
    MODEL_FOR_MASKED_LM_MAPPING,
    AutoConfig,
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
48
    is_torch_tpu_available,
49
50
    set_seed,
)
51
from transformers.trainer_utils import get_last_checkpoint
52
from transformers.utils import check_min_version, send_example_telemetry
53
from transformers.utils.versions import require_version
54
55


56
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.36.0.dev0")

# The dataset-loading code in this script relies on at least this `datasets` release.
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

# Module-level logger; its level is configured inside `main()` once the training arguments are parsed.
logger = logging.getLogger(__name__)

# Config classes registered for masked language modeling, and their `model_type` strings.
# `MODEL_TYPES` is listed in the `--model_type` help text of `ModelArguments`.
MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.

    Every field carries its command-line help text in `metadata["help"]`. `__post_init__` rejects
    `config_overrides` when it is combined with `config_name` or `model_name_or_path`.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_overrides: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override some existing default config settings when a model is trained from scratch. Example: "
                "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
            )
        },
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    # The default is None, so the annotation is Optional like the other string options.
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    # NOTE(review): the annotation stays `bool` even though the default is None. The argument parser
    # may handle `bool` and `Optional[bool]` fields differently on the command line (bare flag vs.
    # explicit value) -- confirm against HfArgumentParser before changing it.
    use_auth_token: bool = field(
        default=None,
        metadata={
            "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
                "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
                "execute code present on the Hub on your local machine."
            )
        },
    )
    low_cpu_mem_usage: bool = field(
        default=False,
        metadata={
            "help": (
                "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded. "
                "set True will benefit LLM loading time and RAM consumption."
            )
        },
    )

    def __post_init__(self):
        """Validate option combinations.

        Raises:
            ValueError: if `config_overrides` is given together with `config_name` or `model_name_or_path`.
        """
        if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
            raise ValueError(
                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
            )

152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172

@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Every field carries its command-line help text in `metadata["help"]`. `__post_init__` checks that
    a data source was provided and that local files have a supported extension.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    max_seq_length: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated."
            )
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    mlm_probability: float = field(
        default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
    )
    line_by_line: bool = field(
        default=False,
        metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to pad all samples to `max_seq_length`. "
                "If False, will pad the samples dynamically when batching to the maximum length in the batch."
            )
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})

    def __post_init__(self):
        """Validate the data source options.

        Raises:
            ValueError: if none of `dataset_name`, `train_file` and `validation_file` is set, or if a
                provided file does not end in `.csv`, `.json` or `.txt` (the check is case-sensitive).
        """
        if self.streaming:
            require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`")

        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")

        # Both file options obey the same rule; the training file is checked first.
        for option in ("train_file", "validation_file"):
            path = getattr(self, option)
            if path is not None and path.split(".")[-1] not in ("csv", "json", "txt"):
                raise ValueError(f"`{option}` should be a csv, a json or a txt file.")
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257


def main():
    """Run masked language modeling fine-tuning (or training from scratch) end to end.

    Steps, in order: parse the three argument dataclasses (from the command line or from a single
    JSON file), configure logging, look for a checkpoint to resume from, load the raw datasets,
    build config/tokenizer/model, tokenize (either line by line or by concatenating and chunking),
    build the `Trainer`, then optionally train, evaluate, and finally push to the hub or write a
    model card.

    Raises:
        ValueError: if both `token` and `use_auth_token` are set, if the output directory is
            non-empty without a checkpoint and `--overwrite_output_dir` was not given, if no
            tokenizer name or model path is available, or if `--do_train`/`--do_eval` is requested
            without the matching dataset split.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Backward compatibility: fold the deprecated `use_auth_token` into `token`.
    if model_args.use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.",
            FutureWarning,
        )
        if model_args.token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        model_args.token = model_args.use_auth_token

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_mlm", model_args, data_args)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )
    # Full parameter dump, emitted at info level.
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column. You can easily tweak this
    # behavior (see below)
    #
    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
            streaming=data_args.streaming,
        )
        if "validation" not in raw_datasets:
            # No validation split available: carve one out of the head of the train split.
            raw_datasets["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
                streaming=data_args.streaming,
            )
            raw_datasets["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
                streaming=data_args.streaming,
            )
    else:
        data_files = {}
        # At least one of the two files is set (enforced by DataTrainingArguments.__post_init__), so
        # `extension` is always bound below. When both are given, the validation file's extension wins.
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
            extension = data_args.train_file.split(".")[-1]
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
            extension = data_args.validation_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
        raw_datasets = load_dataset(
            extension,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
        )

        # If no validation data is there, validation_split_percentage will be used to divide the dataset.
        if "validation" not in raw_datasets:
            raw_datasets["validation"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
            )
            raw_datasets["train"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
            )

    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config_kwargs = {
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "token": model_args.token,
        "trust_remote_code": model_args.trust_remote_code,
    }
    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")
        if model_args.config_overrides is not None:
            logger.info(f"Overriding config: {model_args.config_overrides}")
            config.update_from_string(model_args.config_overrides)
            logger.info(f"New config: {config}")

    tokenizer_kwargs = {
        "cache_dir": model_args.cache_dir,
        "use_fast": model_args.use_fast_tokenizer,
        "revision": model_args.model_revision,
        "token": model_args.token,
        "trust_remote_code": model_args.trust_remote_code,
    }
    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.model_name_or_path:
        model = AutoModelForMaskedLM.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            token=model_args.token,
            trust_remote_code=model_args.trust_remote_code,
            low_cpu_mem_usage=model_args.low_cpu_mem_usage,
        )
    else:
        logger.info("Training new model from scratch")
        model = AutoModelForMaskedLM.from_config(config, trust_remote_code=model_args.trust_remote_code)

    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
    # on a small vocab and want a smaller embedding size, remove this test.
    embedding_size = model.get_input_embeddings().weight.shape[0]
    if len(tokenizer) > embedding_size:
        model.resize_token_embeddings(len(tokenizer))

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    if training_args.do_train:
        column_names = list(raw_datasets["train"].features)
    else:
        column_names = list(raw_datasets["validation"].features)
    text_column_name = "text" if "text" in column_names else column_names[0]

    if data_args.max_seq_length is None:
        max_seq_length = tokenizer.model_max_length
        if max_seq_length > 1024:
            # The option this script exposes is `--max_seq_length`; the message names it so the
            # advice is actionable (this script defines no `--block_size` argument).
            logger.warning(
                "The chosen tokenizer supports a `model_max_length` that is longer than the default `max_seq_length` value"
                " of 1024. If you would like to use a longer `max_seq_length` up to `tokenizer.model_max_length` you can"
                " override this default with `--max_seq_length xxx`."
            )
            max_seq_length = 1024
    else:
        if data_args.max_seq_length > tokenizer.model_max_length:
            logger.warning(
                f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
                f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
            )
        max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    if data_args.line_by_line:
        # When using line_by_line, we just tokenize each nonempty line.
        padding = "max_length" if data_args.pad_to_max_length else False

        def tokenize_function(examples):
            """Tokenize one batch, treating every non-blank line as its own truncated sequence."""
            # Remove empty lines
            examples[text_column_name] = [
                line for line in examples[text_column_name] if len(line) > 0 and not line.isspace()
            ]
            return tokenizer(
                examples[text_column_name],
                padding=padding,
                truncation=True,
                max_length=max_seq_length,
                # We use this option because DataCollatorForLanguageModeling (see below) is more efficient when it
                # receives the `special_tokens_mask`.
                return_special_tokens_mask=True,
            )

        with training_args.main_process_first(desc="dataset map tokenization"):
            if not data_args.streaming:
                tokenized_datasets = raw_datasets.map(
                    tokenize_function,
                    batched=True,
                    num_proc=data_args.preprocessing_num_workers,
                    remove_columns=[text_column_name],
                    load_from_cache_file=not data_args.overwrite_cache,
                    desc="Running tokenizer on dataset line_by_line",
                )
            else:
                tokenized_datasets = raw_datasets.map(
                    tokenize_function,
                    batched=True,
                    remove_columns=[text_column_name],
                )
    else:
        # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
        # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
        # efficient when it receives the `special_tokens_mask`.
        def tokenize_function(examples):
            """Tokenize one batch without truncation; chunking happens later in `group_texts`."""
            return tokenizer(examples[text_column_name], return_special_tokens_mask=True)

        with training_args.main_process_first(desc="dataset map tokenization"):
            if not data_args.streaming:
                tokenized_datasets = raw_datasets.map(
                    tokenize_function,
                    batched=True,
                    num_proc=data_args.preprocessing_num_workers,
                    remove_columns=column_names,
                    load_from_cache_file=not data_args.overwrite_cache,
                    desc="Running tokenizer on every text in dataset",
                )
            else:
                tokenized_datasets = raw_datasets.map(
                    tokenize_function,
                    batched=True,
                    remove_columns=column_names,
                )

        # Main data processing function that will concatenate all texts from our dataset and generate chunks of
        # max_seq_length.
        def group_texts(examples):
            """Concatenate every column of the batch and re-split it into `max_seq_length`-sized chunks."""
            # Concatenate all texts.
            concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
            total_length = len(concatenated_examples[list(examples.keys())[0]])
            # We drop the small remainder, and if the total_length < max_seq_length  we exclude this batch and return an empty dict.
            # We could add padding if the model supported it instead of this drop, you can customize this part to your needs.
            total_length = (total_length // max_seq_length) * max_seq_length
            # Split by chunks of max_len.
            result = {
                k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
                for k, t in concatenated_examples.items()
            }
            return result

        # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
        # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
        # might be slower to preprocess.
        #
        # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
        # https://huggingface.co/docs/datasets/process#map

        with training_args.main_process_first(desc="grouping texts together"):
            if not data_args.streaming:
                tokenized_datasets = tokenized_datasets.map(
                    group_texts,
                    batched=True,
                    num_proc=data_args.preprocessing_num_workers,
                    load_from_cache_file=not data_args.overwrite_cache,
                    desc=f"Grouping texts in chunks of {max_seq_length}",
                )
            else:
                tokenized_datasets = tokenized_datasets.map(
                    group_texts,
                    batched=True,
                )

    if training_args.do_train:
        if "train" not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = tokenized_datasets["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))

    if training_args.do_eval:
        if "validation" not in tokenized_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = tokenized_datasets["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))

        def preprocess_logits_for_metrics(logits, labels):
            """Reduce logits to predicted token ids (argmax over the last axis); `labels` is unused."""
            if isinstance(logits, tuple):
                # Depending on the model and config, logits may contain extra tensors,
                # like past_key_values, but logits always come first
                logits = logits[0]
            return logits.argmax(dim=-1)

        metric = evaluate.load("accuracy")

        def compute_metrics(eval_preds):
            """Compute accuracy over the positions whose label is not -100."""
            preds, labels = eval_preds
            # preds have the same shape as the labels, after the argmax(-1) has been calculated
            # by preprocess_logits_for_metrics
            labels = labels.reshape(-1)
            preds = preds.reshape(-1)
            mask = labels != -100
            labels = labels[mask]
            preds = preds[mask]
            return metric.compute(predictions=preds, references=labels)

    # Data collator
    # This one will take care of randomly masking the tokens.
    pad_to_multiple_of_8 = data_args.line_by_line and training_args.fp16 and not data_args.pad_to_max_length
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm_probability=data_args.mlm_probability,
        pad_to_multiple_of=8 if pad_to_multiple_of_8 else None,
    )

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics
        if training_args.do_eval and not is_torch_tpu_available()
        else None,
    )

    # Training
    if training_args.do_train:
        # An explicit --resume_from_checkpoint takes precedence over an auto-detected checkpoint.
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()  # Saves the tokenizer too for easy upload
        metrics = train_result.metrics

        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        metrics = trainer.evaluate()

        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
        # Perplexity is exp(eval loss); report infinity when the exponential overflows.
        try:
            perplexity = math.exp(metrics["eval_loss"])
        except OverflowError:
            perplexity = float("inf")
        metrics["perplexity"] = perplexity

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Metadata passed to the hub upload / model card.
    kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "fill-mask"}
    if data_args.dataset_name is not None:
        kwargs["dataset_tags"] = data_args.dataset_name
        if data_args.dataset_config_name is not None:
            kwargs["dataset_args"] = data_args.dataset_config_name
            kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
        else:
            kwargs["dataset"] = data_args.dataset_name

    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)
Sylvain Gugger's avatar
Sylvain Gugger committed
679

680
681
682
683
684
685
686
687

def _mp_fn(index):
    """Entry point for multi-process launchers; simply runs `main()`.

    `index` is accepted but not used here -- presumably the process index supplied by the
    spawning launcher (TODO confirm against xla_spawn).
    """
    # For xla_spawn (TPUs)
    main()


# Run the training/evaluation pipeline only when the file is executed as a script, not when imported.
if __name__ == "__main__":
    main()