"tests/flaubert/test_modeling_flaubert.py" did not exist on "d4c2cb402d6674211726fd5f4803d1090664e438"
run_xnli.py 18.3 KB
Newer Older
1
#!/usr/bin/env python
VictorSanh's avatar
VictorSanh committed
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
17
""" Finetuning multi-lingual models on XNLI (e.g. BERT, DistilBERT, XLM).
    Adapted from `examples/text-classification/run_glue.py`"""
VictorSanh's avatar
VictorSanh committed
19
20
21
22

import logging
import os
import random
23
import sys
24
import warnings
25
26
from dataclasses import dataclass, field
from typing import Optional
VictorSanh's avatar
VictorSanh committed
27

28
import datasets
29
import evaluate
VictorSanh's avatar
VictorSanh committed
30
import numpy as np
31
from datasets import load_dataset
VictorSanh's avatar
VictorSanh committed
32

33
import transformers
34
from transformers import (
35
36
37
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
38
39
40
41
42
43
44
    DataCollatorWithPadding,
    EvalPrediction,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    default_data_collator,
    set_seed,
45
)
46
from transformers.trainer_utils import get_last_checkpoint
47
from transformers.utils import check_min_version, send_example_telemetry
48
from transformers.utils.versions import require_version
Aymeric Augustin's avatar
Aymeric Augustin committed
49

VictorSanh's avatar
VictorSanh committed
50

51
# Abort at import time when the installed Transformers is older than this example needs.
# Removing this check is possible, but at your own risk.
check_min_version("4.38.0.dev0")

# Same idea for the `datasets` library; the message tells the user how to install a suitable version.
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

logger = logging.getLogger(__name__)


59
60
61
62
@dataclass
class DataTrainingArguments:
    """
    Data-side options for training and evaluation: tokenized sequence length, padding
    strategy, cache reuse and optional truncation of each dataset split.

    `HfArgumentParser` turns every field of this dataclass into a command-line flag
    of the same name.
    """

    max_seq_length: Optional[int] = field(
        default=128,
        metadata=dict(
            help=(
                "The maximum total input sequence length after tokenization. "
                "Sequences longer than this will be truncated, sequences shorter will be padded."
            )
        ),
    )
    overwrite_cache: bool = field(
        default=False,
        metadata=dict(help="Overwrite the cached preprocessed datasets or not."),
    )
    pad_to_max_length: bool = field(
        default=True,
        metadata=dict(
            help=(
                "Whether to pad all samples to `max_seq_length`. If False, will pad the samples "
                "dynamically when batching to the maximum length in the batch."
            )
        ),
    )
    # The three limits below are applied to the train / validation / test splits respectively.
    max_train_samples: Optional[int] = field(
        default=None,
        metadata=dict(
            help=(
                "For debugging purposes or quicker training, "
                "truncate the number of training examples to this value if set."
            )
        ),
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata=dict(
            help=(
                "For debugging purposes or quicker training, "
                "truncate the number of evaluation examples to this value if set."
            )
        ),
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata=dict(
            help=(
                "For debugging purposes or quicker training, "
                "truncate the number of prediction examples to this value if set."
            )
        ),
    )
VictorSanh's avatar
VictorSanh committed
117
118


119
120
121
122
123
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from,
    and to which XNLI language(s) are used for training and evaluation.

    `HfArgumentParser` exposes every field as a command-line flag of the same name.
    """

    model_name_or_path: str = field(
        default=None, metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    # Evaluation language; it is also the training language unless `train_language` is given.
    language: str = field(
        default=None, metadata={"help": "Evaluation language. Also train language if `train_language` is set to None."}
    )
    train_language: Optional[str] = field(
        default=None, metadata={"help": "Train language if it is different from the evaluation language."}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    do_lower_case: Optional[bool] = field(
        default=False,
        metadata={"help": "arg to indicate if tokenizer should do lower case in AutoTokenizer.from_pretrained()"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    token: str = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    # Deprecated alias of `token`. The `None` default distinguishes "flag not passed" from an explicit value.
    # NOTE(review): kept typed as plain `bool` (not Optional[bool]) on purpose; HfArgumentParser handles the two
    # annotations differently (e.g. whether a bare `--use_auth_token` is accepted) — confirm before changing.
    use_auth_token: bool = field(
        default=None,
        metadata={
            "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                # The trailing space after "This option" was missing, producing "This optionshould" in --help.
                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
                "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
                "execute code present on the Hub on your local machine."
            )
        },
    )
    ignore_mismatched_sizes: bool = field(
        default=False,
        metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
    )
185

186
187
188
189
190
191
192
193
194

def _load_xnli_split(language, split, model_args):
    """Load one `split` of the "xnli" Hub dataset for `language`, using the cache dir and token from `model_args`."""
    return load_dataset(
        "xnli",
        language,
        split=split,
        cache_dir=model_args.cache_dir,
        token=model_args.token,
    )


def _select_first(dataset, max_samples):
    """Return the first `max_samples` rows of `dataset`, or `dataset` unchanged when `max_samples` is None."""
    if max_samples is None:
        return dataset
    return dataset.select(range(min(len(dataset), max_samples)))


def main():
    """Fine-tune and/or evaluate a sequence-classification model on XNLI.

    Command-line arguments are parsed into `ModelArguments`, `DataTrainingArguments` and
    `TrainingArguments`. Depending on `--do_train` / `--do_eval` / `--do_predict` the script
    trains on the XNLI train split (in `train_language` if given, otherwise `language`),
    evaluates on the validation split and/or predicts on the test split, writing the
    predicted label names to `<output_dir>/predictions.txt`.

    Raises:
        ValueError: if both `token` and `use_auth_token` are given; if the output directory
            is non-empty, holds no checkpoint and `--overwrite_output_dir` is not set; or if
            none of `--do_train`, `--do_eval`, `--do_predict` is set.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # `use_auth_token` is a deprecated alias: warn, refuse ambiguity, then forward its value to `token`.
    if model_args.use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.",
            FutureWarning,
        )
        if model_args.token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        model_args.token = model_args.use_auth_token

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_xnli", model_args)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # The label list is read from whichever split(s) get loaded below; with no do_* flag it would be
    # undefined and the script would crash with an UnboundLocalError, so fail with a clear message instead.
    if not (training_args.do_train or training_args.do_eval or training_args.do_predict):
        raise ValueError("At least one of `--do_train`, `--do_eval` or `--do_predict` must be set.")

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    # Downloading and loading xnli dataset from the hub.
    if training_args.do_train:
        train_language = model_args.language if model_args.train_language is None else model_args.train_language
        train_dataset = _load_xnli_split(train_language, "train", model_args)
        label_list = train_dataset.features["label"].names

    if training_args.do_eval:
        eval_dataset = _load_xnli_split(model_args.language, "validation", model_args)
        label_list = eval_dataset.features["label"].names

    if training_args.do_predict:
        predict_dataset = _load_xnli_split(model_args.language, "test", model_args)
        label_list = predict_dataset.features["label"].names

    # Labels
    num_labels = len(label_list)

    # Load pretrained model and tokenizer
    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        id2label={str(i): label for i, label in enumerate(label_list)},
        label2id={label: i for i, label in enumerate(label_list)},
        finetuning_task="xnli",
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=model_args.token,
        trust_remote_code=model_args.trust_remote_code,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        do_lower_case=model_args.do_lower_case,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        token=model_args.token,
        trust_remote_code=model_args.trust_remote_code,
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=model_args.token,
        trust_remote_code=model_args.trust_remote_code,
        ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
    )

    # Preprocessing the datasets
    # Padding strategy
    if data_args.pad_to_max_length:
        padding = "max_length"
    else:
        # We will pad later, dynamically at batch creation, to the max sequence length in each batch
        padding = False

    def preprocess_function(examples):
        # Tokenize the texts as (premise, hypothesis) sentence pairs.
        return tokenizer(
            examples["premise"],
            examples["hypothesis"],
            padding=padding,
            max_length=data_args.max_seq_length,
            truncation=True,
        )

    if training_args.do_train:
        train_dataset = _select_first(train_dataset, data_args.max_train_samples)
        with training_args.main_process_first(desc="train dataset map pre-processing"):
            train_dataset = train_dataset.map(
                preprocess_function,
                batched=True,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on train dataset",
            )
        # Log a few random samples from the training set. random.sample raises ValueError when asked
        # for more items than exist, so cap at the dataset size (e.g. with a tiny --max_train_samples).
        num_samples_to_log = min(3, len(train_dataset))
        for index in random.sample(range(len(train_dataset)), num_samples_to_log):
            logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    if training_args.do_eval:
        eval_dataset = _select_first(eval_dataset, data_args.max_eval_samples)
        with training_args.main_process_first(desc="validation dataset map pre-processing"):
            eval_dataset = eval_dataset.map(
                preprocess_function,
                batched=True,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on validation dataset",
            )

    if training_args.do_predict:
        predict_dataset = _select_first(predict_dataset, data_args.max_predict_samples)
        with training_args.main_process_first(desc="prediction dataset map pre-processing"):
            predict_dataset = predict_dataset.map(
                preprocess_function,
                batched=True,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on prediction dataset",
            )

    # Get the metric function
    metric = evaluate.load("xnli", cache_dir=model_args.cache_dir)

    # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
    # predictions and label_ids field) and has to return a dictionary string to float.
    def compute_metrics(p: EvalPrediction):
        # When the model returns several outputs, predictions is a tuple whose first item is used as the logits.
        preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        preds = np.argmax(preds, axis=1)
        return metric.compute(predictions=preds, references=p.label_ids)

    # Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding.
    if data_args.pad_to_max_length:
        data_collator = default_data_collator
    elif training_args.fp16:
        data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
    else:
        data_collator = None

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.save_model()  # Saves the tokenizer too for easy upload

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate(eval_dataset=eval_dataset)

        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Prediction
    if training_args.do_predict:
        logger.info("*** Predict ***")
        predictions, _, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict")

        max_predict_samples = (
            data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
        )
        metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))

        trainer.log_metrics("predict", metrics)
        trainer.save_metrics("predict", metrics)

        # Unwrap multi-output predictions the same way compute_metrics does, so argmax runs on the logits.
        if isinstance(predictions, tuple):
            predictions = predictions[0]
        predictions = np.argmax(predictions, axis=1)
        output_predict_file = os.path.join(training_args.output_dir, "predictions.txt")
        # Only the main process writes the file, avoiding concurrent writes in distributed runs.
        if trainer.is_world_process_zero():
            with open(output_predict_file, "w", encoding="utf-8") as writer:
                writer.write("index\tprediction\n")
                for index, label_id in enumerate(predictions):
                    writer.write(f"{index}\t{label_list[label_id]}\n")

VictorSanh's avatar
VictorSanh committed
469
470
471

if __name__ == "__main__":
    # Run the full pipeline only when executed as a script, never on import.
    main()