run_glue.py 20 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Lysandre's avatar
Lysandre committed
15
""" Finetuning the library models for sequence classification on GLUE."""
Sylvain Gugger's avatar
Sylvain Gugger committed
16
# You can also adapt this script on your own text classification task. Pointers for this are left as comments.
thomwolf's avatar
thomwolf committed
17
18
19

import logging
import os
Sylvain Gugger's avatar
Sylvain Gugger committed
20
import random
21
import sys
22
from dataclasses import dataclass, field
Sylvain Gugger's avatar
Sylvain Gugger committed
23
from typing import Optional
thomwolf's avatar
thomwolf committed
24
25

import numpy as np
Sylvain Gugger's avatar
Sylvain Gugger committed
26
from datasets import load_dataset, load_metric
thomwolf's avatar
thomwolf committed
27

Sylvain Gugger's avatar
Sylvain Gugger committed
28
import transformers
29
from transformers import (
Sylvain Gugger's avatar
Sylvain Gugger committed
30
31
32
33
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    EvalPrediction,
34
    HfArgumentParser,
Sylvain Gugger's avatar
Sylvain Gugger committed
35
    PretrainedConfig,
Julien Chaumond's avatar
Julien Chaumond committed
36
    Trainer,
37
    TrainingArguments,
Sylvain Gugger's avatar
Sylvain Gugger committed
38
    default_data_collator,
Julien Chaumond's avatar
Julien Chaumond committed
39
    set_seed,
40
)
Sylvain Gugger's avatar
Sylvain Gugger committed
41
42
from transformers.trainer_utils import is_main_process

Aymeric Augustin's avatar
Aymeric Augustin committed
43

Sylvain Gugger's avatar
Sylvain Gugger committed
44
45
46
47
48
49
50
51
52
53
54
# Input text column(s) of every supported GLUE task, as (first_column, second_column). A second entry of None
# means the task is tokenized as a single sequence. The insertion order is also the order in which task names
# are listed in the `--task_name` help text.
task_to_keys = {
    "cola": ("sentence", None),  # single sentence
    "mnli": ("premise", "hypothesis"),  # sentence pair
    "mrpc": ("sentence1", "sentence2"),  # sentence pair
    "qnli": ("question", "sentence"),  # sentence pair
    "qqp": ("question1", "question2"),  # sentence pair
    "rte": ("sentence1", "sentence2"),  # sentence pair
    "sst2": ("sentence", None),  # single sentence
    "stsb": ("sentence1", "sentence2"),  # sentence pair
    "wnli": ("sentence1", "sentence2"),  # sentence pair
}

# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)

thomwolf's avatar
thomwolf committed
58

Sylvain Gugger's avatar
Sylvain Gugger committed
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Using `HfArgumentParser` we can turn this class
    into argparse arguments to be able to specify them on
    the command line.

    Either ``task_name`` (one of the keys of ``task_to_keys``) or both ``train_file`` and ``validation_file``
    must be provided; ``__post_init__`` enforces this and raises ``ValueError`` otherwise.
    """

    # GLUE task to fine-tune on; lower-cased and validated in __post_init__.
    task_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())},
    )
    # Passed as `max_length` to the tokenizer (with truncation enabled).
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
    )
    pad_to_max_length: bool = field(
        default=True,
        metadata={
            "help": "Whether to pad all samples to `max_seq_length`. "
            "If False, will pad the samples dynamically when batching to the maximum length in the batch."
        },
    )
    # Local data files; only consulted when task_name is None. The format is inferred from the extension.
    train_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the training data."}
    )
    validation_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the validation data."}
    )
    test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."})

    def __post_init__(self):
        """Normalize ``task_name`` and check that a usable data source was configured.

        Uses explicit ``ValueError`` rather than ``assert`` so the checks still run under ``python -O``.

        Raises:
            ValueError: if ``task_name`` is unknown, if neither a task nor both train/validation files are
                given, or if the local files are not csv/json with matching extensions.
        """
        if self.task_name is not None:
            self.task_name = self.task_name.lower()
            if self.task_name not in task_to_keys:
                raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys()))
        elif self.train_file is None or self.validation_file is None:
            raise ValueError("Need either a GLUE task or a training/validation file.")
        else:
            train_extension = self.train_file.split(".")[-1]
            if train_extension not in ["csv", "json"]:
                raise ValueError("`train_file` should be a csv or a json file.")
            validation_extension = self.validation_file.split(".")[-1]
            if validation_extension != train_extension:
                raise ValueError(
                    "`validation_file` should have the same extension (csv or json) as `train_file`."
                )
Sylvain Gugger's avatar
Sylvain Gugger committed
112
113


114
115
116
117
118
119
120
@dataclass
class ModelArguments:
    """
    Options selecting the pretrained model, config and tokenizer to fine-tune, and how to fetch them.

    Only ``model_name_or_path`` is mandatory. ``config_name`` and ``tokenizer_name`` default to None, in which
    case the script loads the config/tokenizer from ``model_name_or_path`` instead.
    """

    # No default value, so this field must always be supplied.
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"},
    )
    # Optional overrides of where the config / tokenizer come from.
    config_name: Optional[str] = field(
        metadata={"help": "Pretrained config name or path if not the same as model_name"},
        default=None,
    )
    tokenizer_name: Optional[str] = field(
        metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
        default=None,
    )
    # Forwarded unchanged as `cache_dir=` to the from_pretrained calls.
    cache_dir: Optional[str] = field(
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
        default=None,
    )
    use_fast_tokenizer: bool = field(
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
        default=True,
    )
    # Forwarded as `revision=` to the from_pretrained calls.
    model_revision: str = field(
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
        default="main",
    )
    use_auth_token: bool = field(
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` "
            "(necessary to use this script with private models)."
        },
        default=False,
    )
148
149


150
def main():
    """Fine-tune and/or evaluate a sequence-classification model on a GLUE task or on local CSV/JSON files.

    Arguments come from the command line, or from a JSON file when that is the only argument. Depending on
    ``--do_train`` / ``--do_eval`` / ``--do_predict`` this trains and saves the model, writes
    ``eval_results_<task>.txt`` and writes ``test_results_<task>.txt`` into ``output_dir``.

    Returns:
        dict: metrics merged from every ``Trainer.evaluate`` call (empty when ``--do_eval`` is not set). For MNLI,
        the mismatched evaluation is merged last, so it overwrites any key shared with the matched one.

    Raises:
        ValueError: if ``output_dir`` is non-empty while training without ``--overwrite_output_dir``, or if
            ``--do_predict`` is used on local files without a ``test_file`` of the same extension as ``train_file``.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,
    )

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the
    # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named
    # label if at least two columns are provided.
    #
    # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
    # single column. You can easily tweak this behavior (see below)
    #
    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    if data_args.task_name is not None:
        # Downloading and loading a dataset from the hub.
        datasets = load_dataset("glue", data_args.task_name)
    else:
        # Loading a dataset from your local files.
        # CSV/JSON training and evaluation files are needed.
        data_files = {"train": data_args.train_file, "validation": data_args.validation_file}

        # Get the test dataset: you can provide your own CSV/JSON test file (see below)
        # when you use `do_predict` without specifying a GLUE benchmark task.
        if training_args.do_predict:
            if data_args.test_file is None:
                raise ValueError("Need either a GLUE task or a test file for `do_predict`.")
            train_extension = data_args.train_file.split(".")[-1]
            test_extension = data_args.test_file.split(".")[-1]
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if test_extension != train_extension:
                raise ValueError("`test_file` should have the same extension (csv or json) as `train_file`.")
            data_files["test"] = data_args.test_file

        for key, path in data_files.items():
            logger.info(f"load a local file for {key}: {path}")

        if data_args.train_file.endswith(".csv"):
            # Loading a dataset from local csv files
            datasets = load_dataset("csv", data_files=data_files)
        else:
            # Loading a dataset from local json files
            datasets = load_dataset("json", data_files=data_files)
    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Labels
    if data_args.task_name is not None:
        is_regression = data_args.task_name == "stsb"
        if not is_regression:
            label_list = datasets["train"].features["label"].names
            num_labels = len(label_list)
        else:
            num_labels = 1
    else:
        # Trying to have good defaults here, don't hesitate to tweak to your needs.
        is_regression = datasets["train"].features["label"].dtype in ["float32", "float64"]
        if is_regression:
            num_labels = 1
        else:
            # A useful fast method:
            # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
            label_list = datasets["train"].unique("label")
            label_list.sort()  # Let's sort it for determinism
            num_labels = len(label_list)

    # Load pretrained model and tokenizer
    #
    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    # Preprocessing the datasets
    if data_args.task_name is not None:
        sentence1_key, sentence2_key = task_to_keys[data_args.task_name]
    else:
        # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
        non_label_column_names = [name for name in datasets["train"].column_names if name != "label"]
        if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
            sentence1_key, sentence2_key = "sentence1", "sentence2"
        else:
            if len(non_label_column_names) >= 2:
                sentence1_key, sentence2_key = non_label_column_names[:2]
            else:
                sentence1_key, sentence2_key = non_label_column_names[0], None

    # Padding strategy
    if data_args.pad_to_max_length:
        padding = "max_length"
    else:
        # We will pad later, dynamically at batch creation, to the max sequence length in each batch
        padding = False

    # Some models have set the order of the labels to use, so let's make sure we do use it.
    # Only meaningful for classification: regression tasks have no `label_list` to align with.
    label_to_id = None
    if (
        model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
        and data_args.task_name is not None
        and not is_regression
    ):
        # Some have all caps in their config, some don't.
        label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
        if sorted(label_name_to_id.keys()) == sorted(label_list):
            # Maps a dataset label index to the model's id for the same label name.
            label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)}
        else:
            # Single message string: extra positional args to logger.warning are %-format arguments.
            logger.warning(
                "Your model seems to have been trained with labels, but they don't match the dataset: "
                f"model labels: {sorted(label_name_to_id.keys())}, dataset labels: {sorted(label_list)}."
                "\nIgnoring the model labels as a result."
            )
    elif data_args.task_name is None and not is_regression:
        label_to_id = {v: i for i, v in enumerate(label_list)}

    def preprocess_function(examples):
        # Tokenize the texts
        args = (
            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
        )
        result = tokenizer(*args, padding=padding, max_length=data_args.max_seq_length, truncation=True)

        # Map labels to model IDs (local files, or GLUE models whose config uses its own label order).
        # Unlabeled test rows carry -1, which is kept as-is instead of raising a KeyError.
        if label_to_id is not None and "label" in examples:
            result["label"] = [(label_to_id[label] if label != -1 else -1) for label in examples["label"]]
        return result

    datasets = datasets.map(preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache)

    train_dataset = datasets["train"]
    eval_dataset = datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
    if data_args.task_name is not None or data_args.test_file is not None:
        test_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"]

    # Log a few random samples from the training set (fewer if the training set is tiny):
    for index in random.sample(range(len(train_dataset)), min(3, len(train_dataset))):
        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    # Get the metric function
    if data_args.task_name is not None:
        metric = load_metric("glue", data_args.task_name)
    # TODO: When datasets metrics include regular accuracy, make an else here and remove special branch from
    # compute_metrics

    # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
    # predictions and label_ids field) and has to return a dictionary string to float.
    def compute_metrics(p: EvalPrediction):
        preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)
        if data_args.task_name is not None:
            result = metric.compute(predictions=preds, references=p.label_ids)
            if len(result) > 1:
                result["combined_score"] = np.mean(list(result.values())).item()
            return result
        elif is_regression:
            return {"mse": ((preds - p.label_ids) ** 2).mean().item()}
        else:
            return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
        # Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding.
        data_collator=default_data_collator if data_args.pad_to_max_length else None,
    )

    # Training
    if training_args.do_train:
        train_result = trainer.train(
            model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
        )
        metrics = train_result.metrics

        trainer.save_model()  # Saves the tokenizer too for easy upload

        output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
        if trainer.is_world_process_zero():
            with open(output_train_file, "w") as writer:
                logger.info("***** Train results *****")
                for key, value in sorted(metrics.items()):
                    logger.info(f"  {key} = {value}")
                    writer.write(f"{key} = {value}\n")

            # Need to save the state, since Trainer.save_model saves only the tokenizer with the model
            trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))

    # Evaluation
    eval_results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        # Loop to handle MNLI double evaluation (matched, mis-matched)
        tasks = [data_args.task_name]
        eval_datasets = [eval_dataset]
        if data_args.task_name == "mnli":
            tasks.append("mnli-mm")
            eval_datasets.append(datasets["validation_mismatched"])

        for eval_dataset, task in zip(eval_datasets, tasks):
            eval_result = trainer.evaluate(eval_dataset=eval_dataset)

            output_eval_file = os.path.join(training_args.output_dir, f"eval_results_{task}.txt")
            if trainer.is_world_process_zero():
                with open(output_eval_file, "w") as writer:
                    logger.info(f"***** Eval results {task} *****")
                    for key, value in sorted(eval_result.items()):
                        logger.info(f"  {key} = {value}")
                        writer.write(f"{key} = {value}\n")

            eval_results.update(eval_result)

    if training_args.do_predict:
        logger.info("*** Test ***")

        # Loop to handle MNLI double evaluation (matched, mis-matched)
        tasks = [data_args.task_name]
        test_datasets = [test_dataset]
        if data_args.task_name == "mnli":
            tasks.append("mnli-mm")
            test_datasets.append(datasets["test_mismatched"])

        # Predictions are model class ids. Translate them back to dataset label names; the two only differ when
        # `label_to_id` re-ordered a GLUE task's labels to follow the model config.
        if not is_regression:
            if label_to_id is not None and data_args.task_name is not None:
                id_to_label = {model_id: label_list[i] for i, model_id in label_to_id.items()}
            else:
                id_to_label = dict(enumerate(label_list))

        for test_dataset, task in zip(test_datasets, tasks):
            # Removing the `label` columns because it contains -1 and Trainer won't like that.
            # Local test files may have no label column at all.
            if "label" in test_dataset.column_names:
                test_dataset.remove_columns_("label")
            predictions = trainer.predict(test_dataset=test_dataset).predictions
            predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)

            output_test_file = os.path.join(training_args.output_dir, f"test_results_{task}.txt")
            if trainer.is_world_process_zero():
                with open(output_test_file, "w") as writer:
                    logger.info(f"***** Test results {task} *****")
                    writer.write("index\tprediction\n")
                    for index, item in enumerate(predictions):
                        if is_regression:
                            writer.write(f"{index}\t{item:3.3f}\n")
                        else:
                            writer.write(f"{index}\t{id_to_label[int(item)]}\n")
    return eval_results
thomwolf's avatar
thomwolf committed
459
460


Lysandre Debut's avatar
Lysandre Debut committed
461
462
463
464
465
def _mp_fn(index):
    """Entry point for ``xla_spawn`` (TPUs).

    ``index`` is accepted for the spawner's calling convention but not used: ``main`` takes its whole
    configuration from ``sys.argv``.
    """
    del index  # unused
    main()


if __name__ == "__main__":
    main()