run_glue.py 26.3 KB
Newer Older
1
#!/usr/bin/env python
thomwolf's avatar
thomwolf committed
2
# coding=utf-8
Sylvain Gugger's avatar
Sylvain Gugger committed
3
# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
thomwolf's avatar
thomwolf committed
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Lysandre's avatar
Lysandre committed
16
""" Finetuning the library models for sequence classification on GLUE."""
Sylvain Gugger's avatar
Sylvain Gugger committed
17
# You can also adapt this script to your own text classification task. Pointers for this are left as comments.
thomwolf's avatar
thomwolf committed
18
19
20

import logging
import os
Sylvain Gugger's avatar
Sylvain Gugger committed
21
import random
22
import sys
23
from dataclasses import dataclass, field
Sylvain Gugger's avatar
Sylvain Gugger committed
24
from typing import Optional
thomwolf's avatar
thomwolf committed
25

26
import datasets
thomwolf's avatar
thomwolf committed
27
import numpy as np
Sylvain Gugger's avatar
Sylvain Gugger committed
28
from datasets import load_dataset, load_metric
thomwolf's avatar
thomwolf committed
29

Sylvain Gugger's avatar
Sylvain Gugger committed
30
import transformers
31
from transformers import (
Sylvain Gugger's avatar
Sylvain Gugger committed
32
33
34
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
35
    DataCollatorWithPadding,
Sylvain Gugger's avatar
Sylvain Gugger committed
36
    EvalPrediction,
37
    HfArgumentParser,
Sylvain Gugger's avatar
Sylvain Gugger committed
38
    PretrainedConfig,
Julien Chaumond's avatar
Julien Chaumond committed
39
    Trainer,
40
    TrainingArguments,
Sylvain Gugger's avatar
Sylvain Gugger committed
41
    default_data_collator,
Julien Chaumond's avatar
Julien Chaumond committed
42
    set_seed,
43
)
44
from transformers.trainer_utils import get_last_checkpoint
45
from transformers.utils import check_min_version, send_example_telemetry
46
from transformers.utils.versions import require_version
Sylvain Gugger's avatar
Sylvain Gugger committed
47

Aymeric Augustin's avatar
Aymeric Augustin committed
48

49
# Fail fast when the installed Transformers is older than this example requires. Remove the check at your own risk.
check_min_version("4.20.0.dev0")

require_version(
    "datasets>=1.8.0",
    "To fix: pip install -r examples/pytorch/text-classification/requirements.txt",
)

# GLUE task name -> (first text column, second text column or None for single-sentence tasks).
# Insertion order is preserved, so it also determines the order of task names in the `task_name` help text.
task_to_keys = dict(
    cola=("sentence", None),
    mnli=("premise", "hypothesis"),
    mrpc=("sentence1", "sentence2"),
    qnli=("question", "sentence"),
    qqp=("question1", "question2"),
    rte=("sentence1", "sentence2"),
    sst2=("sentence", None),
    stsb=("sentence1", "sentence2"),
    wnli=("sentence1", "sentence2"),
)

logger = logging.getLogger(__name__)

thomwolf's avatar
thomwolf committed
68

Sylvain Gugger's avatar
Sylvain Gugger committed
69
70
71
72
73
74
75
76
77
78
79
80
81
82
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Using `HfArgumentParser` we can turn this class
    into argparse arguments to be able to specify them on
    the command line.

    Exactly one data source is expected: a GLUE `task_name`, a Hub `dataset_name`, or a pair of local
    `train_file` / `validation_file` (csv or json, same extension). `__post_init__` enforces this.
    """

    task_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())},
    )
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
    )
    pad_to_max_length: bool = field(
        default=True,
        metadata={
            "help": (
                "Whether to pad all samples to `max_seq_length`. "
                "If False, will pad the samples dynamically when batching to the maximum length in the batch."
            )
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of prediction examples to this "
                "value if set."
            )
        },
    )
    train_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the training data."}
    )
    validation_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the validation data."}
    )
    test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."})

    def __post_init__(self):
        """Normalize `task_name` and check that a usable data source was given.

        Raises:
            ValueError: if `task_name` is not a known GLUE task, if no data source is given at all, or if
                the local train/validation files are not csv/json with matching extensions.
        """
        if self.task_name is not None:
            # Task names are matched case-insensitively against `task_to_keys`.
            self.task_name = self.task_name.lower()
            if self.task_name not in task_to_keys:
                raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys()))
        elif self.dataset_name is not None:
            # Hub datasets are validated by `load_dataset` itself.
            pass
        elif self.train_file is None or self.validation_file is None:
            raise ValueError("Need either a GLUE task, a training/validation file or a dataset name.")
        else:
            # Explicit raises instead of `assert`, so the checks survive `python -O` and match the
            # ValueError contract of the branches above.
            train_extension = self.train_file.split(".")[-1]
            if train_extension not in ["csv", "json"]:
                raise ValueError("`train_file` should be a csv or a json file.")
            validation_extension = self.validation_file.split(".")[-1]
            if validation_extension != train_extension:
                raise ValueError("`validation_file` should have the same extension (csv or json) as `train_file`.")
Sylvain Gugger's avatar
Sylvain Gugger committed
161
162


163
164
165
166
167
168
169
@dataclass
class ModelArguments:
    """
    Command-line options describing the pretrained checkpoint, config and tokenizer to fine-tune from,
    and how they are fetched from huggingface.co.
    """

    # No default: this field must always be supplied.
    model_name_or_path: str = field(
        metadata=dict(help="Path to pretrained model or model identifier from huggingface.co/models")
    )
    # Optional overrides for where the config / tokenizer come from.
    config_name: Optional[str] = field(
        default=None,
        metadata=dict(help="Pretrained config name or path if not the same as model_name"),
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata=dict(help="Pretrained tokenizer name or path if not the same as model_name"),
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata=dict(help="Where do you want to store the pretrained models downloaded from huggingface.co"),
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata=dict(help="Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."),
    )
    # Branch name, tag name or commit id on the Hub.
    model_revision: str = field(
        default="main",
        metadata=dict(help="The specific model version to use (can be a branch name, tag name or commit id)."),
    )
    use_auth_token: bool = field(
        default=False,
        metadata=dict(
            help=(
                "Will use the token generated when running `transformers-cli login` (necessary to use this script "
                "with private models)."
            )
        ),
    )
    ignore_mismatched_sizes: bool = field(
        default=False,
        metadata=dict(help="Will enable to load a pretrained model whose head dimensions are different."),
    )
203
204


205
def main():
    """Fine-tune a sequence-classification model, then optionally evaluate it and predict with it.

    The three argument groups (``ModelArguments``, ``DataTrainingArguments``, ``TrainingArguments``) are
    parsed from the command line, or from a JSON file when its path is the only argument. Data comes from
    a GLUE task, a Hub dataset, or local CSV/JSON files. Depending on ``--do_train`` / ``--do_eval`` /
    ``--do_predict`` the model is trained, evaluated on the validation split(s), and its test-set
    predictions are written to ``output_dir``. The model is then pushed to the Hub when ``--push_to_hub``
    is set; otherwise a model card is created.

    Raises:
        ValueError: if ``output_dir`` already holds files but no checkpoint (without
            ``--overwrite_output_dir``), if ``--do_predict`` is used on local files without a matching
            ``--test_file``, or if a requested train/validation/test split is missing.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_glue", model_args, data_args)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    # Use the per-process level chosen by TrainingArguments for this module, `datasets` and `transformers`.
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the
    # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named
    # label if at least two columns are provided.
    #
    # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
    # single column. You can easily tweak this behavior (see below)
    #
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if data_args.task_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            "glue",
            data_args.task_name,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    elif data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        # Loading a dataset from your local files.
        # CSV/JSON training and evaluation files are needed.
        data_files = {"train": data_args.train_file, "validation": data_args.validation_file}

        # Get the test dataset: you can provide your own CSV/JSON test file (see below)
        # when you use `do_predict` without specifying a GLUE benchmark task.
        if training_args.do_predict:
            if data_args.test_file is not None:
                train_extension = data_args.train_file.split(".")[-1]
                test_extension = data_args.test_file.split(".")[-1]
                # Explicit raise (not `assert`) so the check survives `python -O`.
                if test_extension != train_extension:
                    raise ValueError("`test_file` should have the same extension (csv or json) as `train_file`.")
                data_files["test"] = data_args.test_file
            else:
                raise ValueError("Need either a GLUE task or a test file for `do_predict`.")

        for key, path in data_files.items():
            logger.info(f"load a local file for {key}: {path}")

        if data_args.train_file.endswith(".csv"):
            # Loading a dataset from local csv files
            raw_datasets = load_dataset(
                "csv",
                data_files=data_files,
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None,
            )
        else:
            # Loading a dataset from local json files
            raw_datasets = load_dataset(
                "json",
                data_files=data_files,
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None,
            )
    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Labels
    if data_args.task_name is not None:
        # STS-B is the only GLUE regression task.
        is_regression = data_args.task_name == "stsb"
        if not is_regression:
            label_list = raw_datasets["train"].features["label"].names
            num_labels = len(label_list)
        else:
            num_labels = 1
    else:
        # Trying to have good defaults here, don't hesitate to tweak to your needs.
        is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"]
        if is_regression:
            num_labels = 1
        else:
            # A useful fast method:
            # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
            label_list = raw_datasets["train"].unique("label")
            label_list.sort()  # Let's sort it for determinism
            num_labels = len(label_list)

    # Load pretrained model and tokenizer
    #
    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
        ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
    )

    # Preprocessing the raw_datasets
    if data_args.task_name is not None:
        sentence1_key, sentence2_key = task_to_keys[data_args.task_name]
    else:
        # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
        non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"]
        if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
            sentence1_key, sentence2_key = "sentence1", "sentence2"
        else:
            if len(non_label_column_names) >= 2:
                sentence1_key, sentence2_key = non_label_column_names[:2]
            else:
                sentence1_key, sentence2_key = non_label_column_names[0], None

    # Padding strategy
    if data_args.pad_to_max_length:
        padding = "max_length"
    else:
        # We will pad later, dynamically at batch creation, to the max sequence length in each batch
        padding = False

    # Some models have set the order of the labels to use, so let's make sure we do use it.
    label_to_id = None
    if (
        model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
        and data_args.task_name is not None
        and not is_regression
    ):
        # Some have all caps in their config, some don't.
        label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
        if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
            label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)}
        else:
            # A single message string: a stray comma here would turn the tail into a %-format argument.
            logger.warning(
                "Your model seems to have been trained with labels, but they don't match the dataset: "
                f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
                "\nIgnoring the model labels as a result."
            )
    elif data_args.task_name is None and not is_regression:
        label_to_id = {v: i for i, v in enumerate(label_list)}

    # Derive id2label from the config that was just updated (`model.config`), not a possibly distinct object.
    if label_to_id is not None:
        model.config.label2id = label_to_id
        model.config.id2label = {label_id: label for label, label_id in model.config.label2id.items()}
    elif data_args.task_name is not None and not is_regression:
        model.config.label2id = {label: i for i, label in enumerate(label_list)}
        model.config.id2label = {label_id: label for label, label_id in model.config.label2id.items()}

    if data_args.max_seq_length > tokenizer.model_max_length:
        logger.warning(
            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
        )
    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    def preprocess_function(examples):
        # Tokenize the texts
        args = (
            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
        )
        result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)

        # Map labels to IDs (not necessary for GLUE tasks); -1 marks unlabeled rows and is kept as-is.
        if label_to_id is not None and "label" in examples:
            result["label"] = [(label_to_id[label] if label != -1 else -1) for label in examples["label"]]
        return result

    with training_args.main_process_first(desc="dataset map pre-processing"):
        raw_datasets = raw_datasets.map(
            preprocess_function,
            batched=True,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on dataset",
        )

    def truncate(dataset, max_samples):
        # Keep at most the first `max_samples` examples; returns the dataset unchanged when no limit is set.
        if max_samples is None:
            return dataset
        return dataset.select(range(min(len(dataset), max_samples)))

    if training_args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = truncate(raw_datasets["train"], data_args.max_train_samples)

    if training_args.do_eval:
        if "validation" not in raw_datasets and "validation_matched" not in raw_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = truncate(
            raw_datasets["validation_matched" if data_args.task_name == "mnli" else "validation"],
            data_args.max_eval_samples,
        )

    if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None:
        if "test" not in raw_datasets and "test_matched" not in raw_datasets:
            raise ValueError("--do_predict requires a test dataset")
        predict_dataset = truncate(
            raw_datasets["test_matched" if data_args.task_name == "mnli" else "test"],
            data_args.max_predict_samples,
        )

    # Log a few random samples from the training set (fewer if the set is tiny, e.g. --max_train_samples 2).
    if training_args.do_train:
        for index in random.sample(range(len(train_dataset)), min(3, len(train_dataset))):
            logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    # Get the metric function
    if data_args.task_name is not None:
        metric = load_metric("glue", data_args.task_name)
    else:
        metric = load_metric("accuracy")

    # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
    # predictions and label_ids field) and has to return a dictionary mapping metric names to floats.
    def compute_metrics(p: EvalPrediction):
        preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)
        if data_args.task_name is not None:
            result = metric.compute(predictions=preds, references=p.label_ids)
            if len(result) > 1:
                result["combined_score"] = np.mean(list(result.values())).item()
            return result
        elif is_regression:
            return {"mse": ((preds - p.label_ids) ** 2).mean().item()}
        else:
            return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}

    # Data collator will default to DataCollatorWithPadding when the tokenizer is passed to Trainer, so we change it if
    # we already did the padding.
    if data_args.pad_to_max_length:
        data_collator = default_data_collator
    elif training_args.fp16:
        data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
    else:
        data_collator = None

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.save_model()  # Saves the tokenizer too for easy upload

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        # Loop to handle MNLI double evaluation (matched, mis-matched)
        tasks = [data_args.task_name]
        eval_datasets = [eval_dataset]
        combined = {}
        if data_args.task_name == "mnli":
            tasks.append("mnli-mm")
            # The mismatched split honours --max_eval_samples just like the matched one.
            eval_datasets.append(truncate(raw_datasets["validation_mismatched"], data_args.max_eval_samples))

        for split_dataset, task in zip(eval_datasets, tasks):
            metrics = trainer.evaluate(eval_dataset=split_dataset)

            max_eval_samples = (
                data_args.max_eval_samples if data_args.max_eval_samples is not None else len(split_dataset)
            )
            metrics["eval_samples"] = min(max_eval_samples, len(split_dataset))

            if task == "mnli-mm":
                metrics = {k + "_mm": v for k, v in metrics.items()}
            if task is not None and "mnli" in task:
                # Matched and mismatched metrics are saved together in one file.
                combined.update(metrics)

            trainer.log_metrics("eval", metrics)
            trainer.save_metrics("eval", combined if task is not None and "mnli" in task else metrics)

    if training_args.do_predict:
        logger.info("*** Predict ***")

        # Loop to handle MNLI double evaluation (matched, mis-matched)
        tasks = [data_args.task_name]
        predict_datasets = [predict_dataset]
        if data_args.task_name == "mnli":
            tasks.append("mnli-mm")
            # The mismatched split honours --max_predict_samples just like the matched one.
            predict_datasets.append(truncate(raw_datasets["test_mismatched"], data_args.max_predict_samples))

        for split_dataset, task in zip(predict_datasets, tasks):
            # Removing the `label` column because it contains -1 and Trainer won't like that.
            # Local test files may not have a label column at all, so only drop it when present.
            if "label" in split_dataset.column_names:
                split_dataset = split_dataset.remove_columns("label")
            predictions = trainer.predict(split_dataset, metric_key_prefix="predict").predictions
            predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)

            output_predict_file = os.path.join(training_args.output_dir, f"predict_results_{task}.txt")
            if trainer.is_world_process_zero():
                with open(output_predict_file, "w") as writer:
                    logger.info(f"***** Predict results {task} *****")
                    writer.write("index\tprediction\n")
                    for index, item in enumerate(predictions):
                        if is_regression:
                            writer.write(f"{index}\t{item:3.3f}\n")
                        else:
                            item = label_list[item]
                            writer.write(f"{index}\t{item}\n")

    kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-classification"}
    if data_args.task_name is not None:
        kwargs["language"] = "en"
        kwargs["dataset_tags"] = "glue"
        kwargs["dataset_args"] = data_args.task_name
        kwargs["dataset"] = f"GLUE {data_args.task_name.upper()}"

    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)
Sylvain Gugger's avatar
Sylvain Gugger committed
609

thomwolf's avatar
thomwolf committed
610

Lysandre Debut's avatar
Lysandre Debut committed
611
612
613
614
615
def _mp_fn(index):
    """Per-process entry point used by `xla_spawn` on TPUs.

    The process ``index`` handed over by the launcher is not needed here; ``main()`` parses its own
    arguments from the command line.
    """
    del index  # unused
    main()


thomwolf's avatar
thomwolf committed
616
617
if __name__ == "__main__":
    # Script entry point: run the full parse / train / evaluate / predict pipeline.
    main()