run_glue.py 26.6 KB
Newer Older
1
#!/usr/bin/env python
thomwolf's avatar
thomwolf committed
2
# coding=utf-8
Sylvain Gugger's avatar
Sylvain Gugger committed
3
# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
thomwolf's avatar
thomwolf committed
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Lysandre's avatar
Lysandre committed
16
""" Finetuning the library models for sequence classification on GLUE."""
Sylvain Gugger's avatar
Sylvain Gugger committed
17
# You can also adapt this script to your own text classification task. Pointers for this are left as comments.
thomwolf's avatar
thomwolf committed
18
19
20

import logging
import os
Sylvain Gugger's avatar
Sylvain Gugger committed
21
import random
22
import sys
23
from dataclasses import dataclass, field
Sylvain Gugger's avatar
Sylvain Gugger committed
24
from typing import Optional
thomwolf's avatar
thomwolf committed
25

26
import datasets
thomwolf's avatar
thomwolf committed
27
import numpy as np
28
from datasets import load_dataset
thomwolf's avatar
thomwolf committed
29

30
import evaluate
Sylvain Gugger's avatar
Sylvain Gugger committed
31
import transformers
32
from transformers import (
Sylvain Gugger's avatar
Sylvain Gugger committed
33
34
35
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
36
    DataCollatorWithPadding,
Sylvain Gugger's avatar
Sylvain Gugger committed
37
    EvalPrediction,
38
    HfArgumentParser,
Sylvain Gugger's avatar
Sylvain Gugger committed
39
    PretrainedConfig,
Julien Chaumond's avatar
Julien Chaumond committed
40
    Trainer,
41
    TrainingArguments,
Sylvain Gugger's avatar
Sylvain Gugger committed
42
    default_data_collator,
Julien Chaumond's avatar
Julien Chaumond committed
43
    set_seed,
44
)
45
from transformers.trainer_utils import get_last_checkpoint
46
from transformers.utils import check_min_version, send_example_telemetry
47
from transformers.utils.versions import require_version
Sylvain Gugger's avatar
Sylvain Gugger committed
48

Aymeric Augustin's avatar
Aymeric Augustin committed
49

50
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.27.0.dev0")

# Will error if the installed `datasets` package is older than 1.8.0.
require_version(
    "datasets>=1.8.0",
    "To fix: pip install -r examples/pytorch/text-classification/requirements.txt",
)
54

Sylvain Gugger's avatar
Sylvain Gugger committed
55
56
57
58
59
60
61
62
63
64
65
# For every supported GLUE task, the dataset column(s) that hold the input text.
# The second element is None for single-sentence tasks (cola, sst2).
task_to_keys = dict(
    cola=("sentence", None),
    mnli=("premise", "hypothesis"),
    mrpc=("sentence1", "sentence2"),
    qnli=("question", "sentence"),
    qqp=("question1", "question2"),
    rte=("sentence1", "sentence2"),
    sst2=("sentence", None),
    stsb=("sentence1", "sentence2"),
    wnli=("sentence1", "sentence2"),
)

logger = logging.getLogger(__name__)

thomwolf's avatar
thomwolf committed
69

Sylvain Gugger's avatar
Sylvain Gugger committed
70
71
72
73
74
75
76
77
78
79
80
81
82
83
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Using `HfArgumentParser` we can turn this class
    into argparse arguments to be able to specify them on
    the command line.

    The data source is resolved in priority order: a GLUE ``task_name``, then a Hub
    ``dataset_name``, then a local ``train_file``/``validation_file`` pair (csv or json,
    both with the same extension). ``__post_init__`` validates that choice.
    """

    task_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())},
    )
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
    )
    pad_to_max_length: bool = field(
        default=True,
        metadata={
            "help": (
                "Whether to pad all samples to `max_seq_length`. "
                "If False, will pad the samples dynamically when batching to the maximum length in the batch."
            )
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of prediction examples to this "
                "value if set."
            )
        },
    )
    train_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the training data."}
    )
    validation_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the validation data."}
    )
    test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."})

    def __post_init__(self):
        """Normalize ``task_name`` to lower case and validate the selected data source.

        Raises:
            ValueError: if ``task_name`` is not a key of ``task_to_keys``; if neither a task, a
                dataset name, nor both local train/validation files are given; or if the local
                files are not csv/json or do not share the same extension.
        """
        if self.task_name is not None:
            self.task_name = self.task_name.lower()
            if self.task_name not in task_to_keys.keys():
                raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys()))
        elif self.dataset_name is not None:
            # A Hub dataset needs no further validation here.
            pass
        elif self.train_file is None or self.validation_file is None:
            raise ValueError("Need either a GLUE task, a training/validation file or a dataset name.")
        else:
            # Raise explicitly instead of using `assert`, so the checks still run under `python -O`.
            train_extension = self.train_file.split(".")[-1]
            if train_extension not in ["csv", "json"]:
                raise ValueError("`train_file` should be a csv or a json file.")
            validation_extension = self.validation_file.split(".")[-1]
            if validation_extension != train_extension:
                raise ValueError("`validation_file` should have the same extension (csv or json) as `train_file`.")
Sylvain Gugger's avatar
Sylvain Gugger committed
162
163


164
165
166
167
168
169
170
@dataclass
class ModelArguments:
    """
    Command-line options choosing the pretrained checkpoint, config and tokenizer to fine-tune,
    and controlling how they are fetched from the Hugging Face Hub.
    """

    # No default: this field is required and must stay first in the dataclass.
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    # Optional overrides; when left unset the script falls back to `model_name_or_path`.
    config_name: Optional[str] = field(
        metadata={"help": "Pretrained config name or path if not the same as model_name"},
        default=None,
    )
    tokenizer_name: Optional[str] = field(
        metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
        default=None,
    )
    cache_dir: Optional[str] = field(
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
        default=None,
    )
    use_fast_tokenizer: bool = field(
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
        default=True,
    )
    model_revision: str = field(
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
        default="main",
    )
    use_auth_token: bool = field(
        metadata={
            "help": (
                "Will use the token generated when running `huggingface-cli login` "
                "(necessary to use this script with private models)."
            )
        },
        default=False,
    )
    ignore_mismatched_sizes: bool = field(
        metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
        default=False,
    )
204
205


206
def main():
    """Fine-tune and/or evaluate a sequence-classification model on GLUE or a user-supplied dataset.

    Arguments are parsed from the command line, or from a single ``.json`` file when that path is
    the only argument. Depending on ``--do_train`` / ``--do_eval`` / ``--do_predict`` this trains the
    model (resuming from a checkpoint found in ``output_dir`` when applicable), evaluates it (matched
    and mismatched splits for MNLI), and writes per-example predictions to
    ``output_dir/predict_results_<task>.txt``. Finally it pushes to the Hub when ``--push_to_hub`` is
    set, otherwise it writes a model card locally.

    Raises:
        ValueError: if ``output_dir`` is non-empty, has no checkpoint and is not overwritten; if a
            split required by the requested mode is missing; or if a local test file's extension
            differs from the training file's.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_glue", model_args, data_args)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the
    # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named
    # label if at least two columns are provided.
    #
    # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
    # single column. You can easily tweak this behavior (see below)
    #
    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    if data_args.task_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            "glue",
            data_args.task_name,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    elif data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        # Loading a dataset from your local files.
        # CSV/JSON training and evaluation files are needed.
        data_files = {"train": data_args.train_file, "validation": data_args.validation_file}

        # Get the test dataset: you can provide your own CSV/JSON test file (see below)
        # when you use `do_predict` without specifying a GLUE benchmark task.
        if training_args.do_predict:
            if data_args.test_file is not None:
                train_extension = data_args.train_file.split(".")[-1]
                test_extension = data_args.test_file.split(".")[-1]
                # Raise rather than `assert` so the check survives `python -O`.
                if test_extension != train_extension:
                    raise ValueError("`test_file` should have the same extension (csv or json) as `train_file`.")
                data_files["test"] = data_args.test_file
            else:
                raise ValueError("Need either a GLUE task or a test file for `do_predict`.")

        for key, path in data_files.items():
            logger.info(f"load a local file for {key}: {path}")

        # csv and json files go through the same loader; only the builder name differs.
        builder = "csv" if data_args.train_file.endswith(".csv") else "json"
        raw_datasets = load_dataset(
            builder,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Labels
    if data_args.task_name is not None:
        is_regression = data_args.task_name == "stsb"
        if not is_regression:
            label_list = raw_datasets["train"].features["label"].names
            num_labels = len(label_list)
        else:
            num_labels = 1
    else:
        # Trying to have good defaults here, don't hesitate to tweak to your needs.
        is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"]
        if is_regression:
            num_labels = 1
        else:
            # A useful fast method:
            # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
            label_list = raw_datasets["train"].unique("label")
            label_list.sort()  # Let's sort it for determinism
            num_labels = len(label_list)

    # Load pretrained model and tokenizer
    #
    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
        ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
    )

    # Preprocessing the raw_datasets
    if data_args.task_name is not None:
        sentence1_key, sentence2_key = task_to_keys[data_args.task_name]
    else:
        # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
        non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"]
        if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
            sentence1_key, sentence2_key = "sentence1", "sentence2"
        else:
            if len(non_label_column_names) >= 2:
                sentence1_key, sentence2_key = non_label_column_names[:2]
            else:
                sentence1_key, sentence2_key = non_label_column_names[0], None

    # Padding strategy
    if data_args.pad_to_max_length:
        padding = "max_length"
    else:
        # We will pad later, dynamically at batch creation, to the max sequence length in each batch
        padding = False

    # Some models have set the order of the labels to use, so let's make sure we do use it.
    label_to_id = None
    if (
        model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
        and data_args.task_name is not None
        and not is_regression
    ):
        # Some have all caps in their config, some don't.
        label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
        if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
            label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)}
        else:
            # A single message string: a comma here would turn the f-string into a %-format argument
            # and make logging fail with "not all arguments converted".
            logger.warning(
                "Your model seems to have been trained with labels, but they don't match the dataset: "
                f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
                "\nIgnoring the model labels as a result."
            )
    elif data_args.task_name is None and not is_regression:
        label_to_id = {v: i for i, v in enumerate(label_list)}

    # Build id2label from the model's own (just updated) config, not the `config` object passed in,
    # so the two mappings always stay in sync.
    if label_to_id is not None:
        model.config.label2id = label_to_id
        model.config.id2label = {idx: label for label, idx in model.config.label2id.items()}
    elif data_args.task_name is not None and not is_regression:
        model.config.label2id = {label: i for i, label in enumerate(label_list)}
        model.config.id2label = {idx: label for label, idx in model.config.label2id.items()}

    if data_args.max_seq_length > tokenizer.model_max_length:
        logger.warning(
            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
        )
    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    def preprocess_function(examples):
        # Tokenize the texts
        args = (
            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
        )
        result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)

        # Map labels to IDs (not necessary for GLUE tasks); -1 marks unlabeled rows and is kept as-is.
        if label_to_id is not None and "label" in examples:
            result["label"] = [(label_to_id[lbl] if lbl != -1 else -1) for lbl in examples["label"]]
        return result

    with training_args.main_process_first(desc="dataset map pre-processing"):
        raw_datasets = raw_datasets.map(
            preprocess_function,
            batched=True,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on dataset",
        )
    if training_args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = raw_datasets["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))

    if training_args.do_eval:
        if "validation" not in raw_datasets and "validation_matched" not in raw_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = raw_datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))

    # Only require a test split when predictions are actually requested: local test files are
    # only loaded under --do_predict, so passing --test_file alone must not raise here.
    if training_args.do_predict:
        if "test" not in raw_datasets and "test_matched" not in raw_datasets:
            raise ValueError("--do_predict requires a test dataset")
        predict_dataset = raw_datasets["test_matched" if data_args.task_name == "mnli" else "test"]
        if data_args.max_predict_samples is not None:
            max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
            predict_dataset = predict_dataset.select(range(max_predict_samples))

    # Log a few random samples from the training set (fewer if the set itself is smaller):
    if training_args.do_train:
        for index in random.sample(range(len(train_dataset)), min(3, len(train_dataset))):
            logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    # Get the metric function
    if data_args.task_name is not None:
        metric = evaluate.load("glue", data_args.task_name)
    else:
        metric = evaluate.load("accuracy")

    # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
    # predictions and label_ids field) and has to return a dictionary string to float.
    def compute_metrics(p: EvalPrediction):
        preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)
        if data_args.task_name is not None:
            result = metric.compute(predictions=preds, references=p.label_ids)
            if len(result) > 1:
                result["combined_score"] = np.mean(list(result.values())).item()
            return result
        elif is_regression:
            return {"mse": ((preds - p.label_ids) ** 2).mean().item()}
        else:
            return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}

    # Data collator will default to DataCollatorWithPadding when the tokenizer is passed to Trainer, so we change it if
    # we already did the padding.
    if data_args.pad_to_max_length:
        data_collator = default_data_collator
    elif training_args.fp16:
        data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
    else:
        data_collator = None

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.save_model()  # Saves the tokenizer too for easy upload

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        # Loop to handle MNLI double evaluation (matched, mis-matched)
        tasks = [data_args.task_name]
        eval_datasets = [eval_dataset]
        if data_args.task_name == "mnli":
            tasks.append("mnli-mm")
            valid_mm_dataset = raw_datasets["validation_mismatched"]
            if data_args.max_eval_samples is not None:
                max_eval_samples = min(len(valid_mm_dataset), data_args.max_eval_samples)
                valid_mm_dataset = valid_mm_dataset.select(range(max_eval_samples))
            eval_datasets.append(valid_mm_dataset)
            combined = {}

        for eval_dataset, task in zip(eval_datasets, tasks):
            metrics = trainer.evaluate(eval_dataset=eval_dataset)

            max_eval_samples = (
                data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
            )
            metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

            if task == "mnli-mm":
                metrics = {k + "_mm": v for k, v in metrics.items()}
            if task is not None and "mnli" in task:
                combined.update(metrics)

            trainer.log_metrics("eval", metrics)
            trainer.save_metrics("eval", combined if task is not None and "mnli" in task else metrics)

    if training_args.do_predict:
        logger.info("*** Predict ***")

        # Loop to handle MNLI double evaluation (matched, mis-matched)
        tasks = [data_args.task_name]
        predict_datasets = [predict_dataset]
        if data_args.task_name == "mnli":
            tasks.append("mnli-mm")
            # Honor --max_predict_samples for the mismatched split too, mirroring the evaluation path.
            predict_mm_dataset = raw_datasets["test_mismatched"]
            if data_args.max_predict_samples is not None:
                max_predict_samples = min(len(predict_mm_dataset), data_args.max_predict_samples)
                predict_mm_dataset = predict_mm_dataset.select(range(max_predict_samples))
            predict_datasets.append(predict_mm_dataset)

        for predict_dataset, task in zip(predict_datasets, tasks):
            # Removing the `label` columns because it contains -1 and Trainer won't like that.
            # Local test files may not have a label column at all, so only drop it when present.
            if "label" in predict_dataset.column_names:
                predict_dataset = predict_dataset.remove_columns("label")
            predictions = trainer.predict(predict_dataset, metric_key_prefix="predict").predictions
            predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)

            output_predict_file = os.path.join(training_args.output_dir, f"predict_results_{task}.txt")
            if trainer.is_world_process_zero():
                # Explicit encoding: label names from custom datasets may be non-ASCII.
                with open(output_predict_file, "w", encoding="utf-8") as writer:
                    logger.info(f"***** Predict results {task} *****")
                    writer.write("index\tprediction\n")
                    for index, item in enumerate(predictions):
                        if is_regression:
                            writer.write(f"{index}\t{item:3.3f}\n")
                        else:
                            item = label_list[item]
                            writer.write(f"{index}\t{item}\n")

    kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-classification"}
    if data_args.task_name is not None:
        kwargs["language"] = "en"
        kwargs["dataset_tags"] = "glue"
        kwargs["dataset_args"] = data_args.task_name
        kwargs["dataset"] = f"GLUE {data_args.task_name.upper()}"

    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)
Sylvain Gugger's avatar
Sylvain Gugger committed
614

thomwolf's avatar
thomwolf committed
615

Lysandre Debut's avatar
Lysandre Debut committed
616
617
618
619
620
def _mp_fn(index):
    """TPU entry point used by ``xla_spawn``; ``index`` is supplied by the spawner and ignored."""
    return main()


if __name__ == "__main__":
    main()