#!/usr/bin/env python
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning the library models for sequence classification on GLUE."""
# You can also adapt this script to your own text classification task. Pointers for this are left as comments.

import logging
import os
Sylvain Gugger's avatar
Sylvain Gugger committed
21
import random
22
import sys
23
import warnings
24
from dataclasses import dataclass, field
Sylvain Gugger's avatar
Sylvain Gugger committed
25
from typing import Optional
thomwolf's avatar
thomwolf committed
26

27
import datasets
28
import evaluate
thomwolf's avatar
thomwolf committed
29
import numpy as np
30
from datasets import load_dataset
thomwolf's avatar
thomwolf committed
31

Sylvain Gugger's avatar
Sylvain Gugger committed
32
import transformers
33
from transformers import (
Sylvain Gugger's avatar
Sylvain Gugger committed
34
35
36
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
37
    DataCollatorWithPadding,
Sylvain Gugger's avatar
Sylvain Gugger committed
38
    EvalPrediction,
39
    HfArgumentParser,
Sylvain Gugger's avatar
Sylvain Gugger committed
40
    PretrainedConfig,
Julien Chaumond's avatar
Julien Chaumond committed
41
    Trainer,
42
    TrainingArguments,
Sylvain Gugger's avatar
Sylvain Gugger committed
43
    default_data_collator,
Julien Chaumond's avatar
Julien Chaumond committed
44
    set_seed,
45
)
46
from transformers.trainer_utils import get_last_checkpoint
47
from transformers.utils import check_min_version, send_example_telemetry
48
from transformers.utils.versions import require_version
Sylvain Gugger's avatar
Sylvain Gugger committed
49

Aymeric Augustin's avatar
Aymeric Augustin committed
50

51
# Fail fast if the installed Transformers release is older than the one this example was written for.
# Remove at your own risks.
check_min_version("4.32.0.dev0")

# The `datasets` library must also be recent enough; the hint tells the user how to fix it.
require_version(
    "datasets>=1.8.0",
    "To fix: pip install -r examples/pytorch/text-classification/requirements.txt",
)


# Maps each GLUE task name to the dataset column(s) holding its input text.
# A second key of None means the task is single-sentence classification.
task_to_keys = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
}

logger = logging.getLogger(__name__)


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Using `HfArgumentParser` we can turn this class
    into argparse arguments to be able to specify them on
    the command line.

    Exactly one data source is required:
      - a GLUE ``task_name`` (lower-cased and checked against ``task_to_keys``),
      - a hub ``dataset_name``, or
      - both ``train_file`` and ``validation_file``, which must share a ``csv`` or ``json`` extension.

    Raises:
        ValueError: if no data source is given, the task name is unknown, or the local files have an
            unsupported or mismatched extension.
    """

    task_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())},
    )
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
    )
    pad_to_max_length: bool = field(
        default=True,
        metadata={
            "help": (
                "Whether to pad all samples to `max_seq_length`. "
                "If False, will pad the samples dynamically when batching to the maximum length in the batch."
            )
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of prediction examples to this "
                "value if set."
            )
        },
    )
    train_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the training data."}
    )
    validation_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the validation data."}
    )
    test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."})

    def __post_init__(self):
        """Normalize ``task_name`` and validate that a usable data source was provided."""
        if self.task_name is not None:
            self.task_name = self.task_name.lower()
            if self.task_name not in task_to_keys:
                raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys()))
        elif self.dataset_name is not None:
            # Hub datasets are validated by `load_dataset` itself.
            pass
        elif self.train_file is None or self.validation_file is None:
            raise ValueError("Need either a GLUE task, a training/validation file or a dataset name.")
        else:
            # Explicit raises (not `assert`) so validation still runs under `python -O`.
            train_extension = self.train_file.split(".")[-1]
            if train_extension not in ["csv", "json"]:
                raise ValueError("`train_file` should be a csv or a json file.")
            validation_extension = self.validation_file.split(".")[-1]
            if validation_extension != train_extension:
                raise ValueError(
                    "`validation_file` should have the same extension (csv or json) as `train_file`."
                )


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.

    ``model_name_or_path`` is the only required field; ``config_name`` and ``tokenizer_name`` fall back
    to it when left unset.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    # Defaults to None, so the annotation must allow it.
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    # Deprecated alias for `token`.
    # NOTE(review): kept annotated as plain `bool` (not Optional[bool]) on purpose. HfArgumentParser seems to give
    # plain-bool fields flag-style parsing (bare `--use_auth_token`), and changing the annotation may alter that.
    # TODO confirm before changing.
    use_auth_token: bool = field(
        default=None,
        metadata={
            "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
        },
    )
    ignore_mismatched_sizes: bool = field(
        default=False,
        metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
    )


def main():
    """Parse arguments, then fine-tune, evaluate and/or predict with a sequence-classification model.

    Arguments come either from a single ``.json`` file passed as the only CLI argument or from the
    command line. They are split into ``ModelArguments``, ``DataTrainingArguments`` and
    ``TrainingArguments``. The ``do_train`` / ``do_eval`` / ``do_predict`` training arguments select
    which stages run. Predictions are written to ``predict_results_<task>.txt`` in ``output_dir``.

    Raises:
        ValueError: on conflicting auth arguments, a non-empty output dir without a checkpoint,
            a missing split required by the selected stages, or a test file whose extension does not
            match the training file.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Fold the deprecated `use_auth_token` into `token`, refusing the ambiguous case where both are set.
    if model_args.use_auth_token is not None:
        warnings.warn("The `use_auth_token` argument is deprecated and will be removed in v4.34.", FutureWarning)
        if model_args.token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        model_args.token = model_args.use_auth_token

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_glue", model_args, data_args)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        f"distributed training: {training_args.parallel_mode.value == 'distributed'}, "
        f"16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the
    # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named
    # label if at least two columns are provided.
    #
    # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
    # single column. You can easily tweak this behavior (see below)
    #
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if data_args.task_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            "glue",
            data_args.task_name,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
        )
    elif data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
        )
    else:
        # Loading a dataset from your local files.
        # CSV/JSON training and evaluation files are needed.
        data_files = {"train": data_args.train_file, "validation": data_args.validation_file}

        # Get the test dataset: you can provide your own CSV/JSON test file (see below)
        # when you use `do_predict` without specifying a GLUE benchmark task.
        if training_args.do_predict:
            if data_args.test_file is not None:
                train_extension = data_args.train_file.split(".")[-1]
                test_extension = data_args.test_file.split(".")[-1]
                # Explicit raise (not `assert`) so the check survives `python -O`.
                if test_extension != train_extension:
                    raise ValueError("`test_file` should have the same extension (csv or json) as `train_file`.")
                data_files["test"] = data_args.test_file
            else:
                raise ValueError("Need either a GLUE task or a test file for `do_predict`.")

        for key in data_files.keys():
            logger.info(f"load a local file for {key}: {data_files[key]}")

        if data_args.train_file.endswith(".csv"):
            # Loading a dataset from local csv files
            raw_datasets = load_dataset(
                "csv",
                data_files=data_files,
                cache_dir=model_args.cache_dir,
                token=model_args.token,
            )
        else:
            # Loading a dataset from local json files
            raw_datasets = load_dataset(
                "json",
                data_files=data_files,
                cache_dir=model_args.cache_dir,
                token=model_args.token,
            )
    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Labels
    if data_args.task_name is not None:
        is_regression = data_args.task_name == "stsb"
        if not is_regression:
            label_list = raw_datasets["train"].features["label"].names
            num_labels = len(label_list)
        else:
            num_labels = 1
    else:
        # Trying to have good defaults here, don't hesitate to tweak to your needs.
        is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"]
        if is_regression:
            num_labels = 1
        else:
            # A useful fast method:
            # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
            label_list = raw_datasets["train"].unique("label")
            label_list.sort()  # Let's sort it for determinism
            num_labels = len(label_list)

    # Load pretrained model and tokenizer
    #
    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=model_args.token,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        token=model_args.token,
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=model_args.token,
        ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
    )

    # Preprocessing the raw_datasets
    if data_args.task_name is not None:
        sentence1_key, sentence2_key = task_to_keys[data_args.task_name]
    else:
        # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
        non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"]
        if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
            sentence1_key, sentence2_key = "sentence1", "sentence2"
        else:
            if len(non_label_column_names) >= 2:
                sentence1_key, sentence2_key = non_label_column_names[:2]
            else:
                sentence1_key, sentence2_key = non_label_column_names[0], None

    # Padding strategy
    if data_args.pad_to_max_length:
        padding = "max_length"
    else:
        # We will pad later, dynamically at batch creation, to the max sequence length in each batch
        padding = False

    # Some models have set the order of the labels to use, so let's make sure we do use it.
    label_to_id = None
    if (
        model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
        and data_args.task_name is not None
        and not is_regression
    ):
        # Some have all caps in their config, some don't.
        label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
        if sorted(label_name_to_id.keys()) == sorted(label_list):
            label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)}
        else:
            # One concatenated message: passing pieces as extra positional args makes logging try to
            # %-format them into a string with no placeholders, which fails instead of warning.
            logger.warning(
                "Your model seems to have been trained with labels, but they don't match the dataset: "
                f"model labels: {sorted(label_name_to_id.keys())}, dataset labels: {sorted(label_list)}."
                "\nIgnoring the model labels as a result."
            )
    elif data_args.task_name is None and not is_regression:
        label_to_id = {v: i for i, v in enumerate(label_list)}

    # Derive id2label from model.config (the object just written) so the two mappings always agree.
    if label_to_id is not None:
        model.config.label2id = label_to_id
        model.config.id2label = {label_id: label for label, label_id in model.config.label2id.items()}
    elif data_args.task_name is not None and not is_regression:
        model.config.label2id = {l: i for i, l in enumerate(label_list)}
        model.config.id2label = {label_id: label for label, label_id in model.config.label2id.items()}

    if data_args.max_seq_length > tokenizer.model_max_length:
        logger.warning(
            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
        )
    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    def preprocess_function(examples):
        # Tokenize the texts
        args = (
            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
        )
        result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)

        # Map labels to IDs (not necessary for GLUE tasks); -1 marks unlabeled rows and is kept as-is.
        if label_to_id is not None and "label" in examples:
            result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
        return result

    with training_args.main_process_first(desc="dataset map pre-processing"):
        raw_datasets = raw_datasets.map(
            preprocess_function,
            batched=True,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on dataset",
        )
    if training_args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = raw_datasets["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))

    if training_args.do_eval:
        if "validation" not in raw_datasets and "validation_matched" not in raw_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = raw_datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))

    # Only require a test split when prediction was actually requested; `predict_dataset` is used
    # nowhere else, and a local `test_file` is only loaded under `do_predict`.
    if training_args.do_predict:
        if "test" not in raw_datasets and "test_matched" not in raw_datasets:
            raise ValueError("--do_predict requires a test dataset")
        predict_dataset = raw_datasets["test_matched" if data_args.task_name == "mnli" else "test"]
        if data_args.max_predict_samples is not None:
            max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
            predict_dataset = predict_dataset.select(range(max_predict_samples))

    # Log a few random samples from the training set (fewer if the set itself is smaller):
    if training_args.do_train:
        for index in random.sample(range(len(train_dataset)), min(3, len(train_dataset))):
            logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    # Get the metric function
    if data_args.task_name is not None:
        metric = evaluate.load("glue", data_args.task_name)
    elif is_regression:
        metric = evaluate.load("mse")
    else:
        metric = evaluate.load("accuracy")

    # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
    # predictions and label_ids field) and has to return a dictionary mapping strings to floats.
    def compute_metrics(p: EvalPrediction):
        preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)
        result = metric.compute(predictions=preds, references=p.label_ids)
        if len(result) > 1:
            result["combined_score"] = np.mean(list(result.values())).item()
        return result

    # Data collator will default to DataCollatorWithPadding when the tokenizer is passed to Trainer, so we change it if
    # we already did the padding.
    if data_args.pad_to_max_length:
        data_collator = default_data_collator
    elif training_args.fp16:
        data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
    else:
        data_collator = None

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.save_model()  # Saves the tokenizer too for easy upload

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        # Loop to handle MNLI double evaluation (matched, mis-matched)
        tasks = [data_args.task_name]
        eval_datasets = [eval_dataset]
        if data_args.task_name == "mnli":
            tasks.append("mnli-mm")
            valid_mm_dataset = raw_datasets["validation_mismatched"]
            if data_args.max_eval_samples is not None:
                max_eval_samples = min(len(valid_mm_dataset), data_args.max_eval_samples)
                valid_mm_dataset = valid_mm_dataset.select(range(max_eval_samples))
            eval_datasets.append(valid_mm_dataset)

        # Accumulates matched and mismatched MNLI metrics so the saved file contains both.
        combined = {}
        for eval_dataset, task in zip(eval_datasets, tasks):
            metrics = trainer.evaluate(eval_dataset=eval_dataset)

            max_eval_samples = (
                data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
            )
            metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

            if task == "mnli-mm":
                metrics = {k + "_mm": v for k, v in metrics.items()}
            is_mnli = task is not None and "mnli" in task
            if is_mnli:
                combined.update(metrics)

            trainer.log_metrics("eval", metrics)
            trainer.save_metrics("eval", combined if is_mnli else metrics)

    if training_args.do_predict:
        logger.info("*** Predict ***")

        # Loop to handle MNLI double evaluation (matched, mis-matched)
        tasks = [data_args.task_name]
        predict_datasets = [predict_dataset]
        if data_args.task_name == "mnli":
            tasks.append("mnli-mm")
            predict_datasets.append(raw_datasets["test_mismatched"])

        for predict_dataset, task in zip(predict_datasets, tasks):
            # Removing the `label` column because it contains -1 and Trainer won't like that.
            # Local test files may not have a label column at all, so only drop it when present.
            if "label" in predict_dataset.column_names:
                predict_dataset = predict_dataset.remove_columns("label")
            predictions = trainer.predict(predict_dataset, metric_key_prefix="predict").predictions
            # Same unwrapping as compute_metrics: some models return a tuple whose first item is the logits.
            if isinstance(predictions, tuple):
                predictions = predictions[0]
            predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)

            output_predict_file = os.path.join(training_args.output_dir, f"predict_results_{task}.txt")
            if trainer.is_world_process_zero():
                with open(output_predict_file, "w") as writer:
                    logger.info(f"***** Predict results {task} *****")
                    writer.write("index\tprediction\n")
                    for index, item in enumerate(predictions):
                        if is_regression:
                            writer.write(f"{index}\t{item:3.3f}\n")
                        else:
                            item = label_list[item]
                            writer.write(f"{index}\t{item}\n")

    kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-classification"}
    if data_args.task_name is not None:
        kwargs["language"] = "en"
        kwargs["dataset_tags"] = "glue"
        kwargs["dataset_args"] = data_args.task_name
        kwargs["dataset"] = f"GLUE {data_args.task_name.upper()}"

    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)


def _mp_fn(index):
    """Per-process entry point used by ``xla_spawn`` (TPUs).

    ``index`` is whatever the spawner passes in; it is not used here, and every process simply runs ``main()``.
    """
    del index  # Unused: main() reads all of its configuration from sys.argv.
    main()


if __name__ == "__main__":
    # Script entry point when run directly.
    main()