file_utils.py 81 KB
Newer Older
Sylvain Gugger's avatar
Sylvain Gugger committed
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2020 The HuggingFace Team, the AllenNLP library authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
thomwolf's avatar
thomwolf committed
14
"""
Sylvain Gugger's avatar
Sylvain Gugger committed
15
16
Utilities for working with the local dataset cache. Parts of this file is adapted from the AllenNLP library at
https://github.com/allenai/allennlp.
thomwolf's avatar
thomwolf committed
17
"""
18
import copy
Aymeric Augustin's avatar
Aymeric Augustin committed
19
import fnmatch
20
import functools
21
import importlib.util
Julien Chaumond's avatar
Julien Chaumond committed
22
import io
thomwolf's avatar
thomwolf committed
23
24
import json
import os
25
import re
26
import shutil
27
import subprocess
Aymeric Augustin's avatar
Aymeric Augustin committed
28
import sys
29
import tarfile
thomwolf's avatar
thomwolf committed
30
import tempfile
31
import types
32
from collections import OrderedDict, UserDict
Aymeric Augustin's avatar
Aymeric Augustin committed
33
from contextlib import contextmanager
34
from dataclasses import fields
35
from enum import Enum
36
from functools import partial, wraps
thomwolf's avatar
thomwolf committed
37
from hashlib import sha256
38
from pathlib import Path
39
from types import ModuleType
40
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
Aymeric Augustin's avatar
Aymeric Augustin committed
41
from urllib.parse import urlparse
42
from uuid import uuid4
43
from zipfile import ZipFile, is_zipfile
thomwolf's avatar
thomwolf committed
44

45
import numpy as np
46
from packaging import version
47
48
from tqdm.auto import tqdm

Aymeric Augustin's avatar
Aymeric Augustin committed
49
50
import requests
from filelock import FileLock
Sylvain Gugger's avatar
Sylvain Gugger committed
51
from huggingface_hub import HfApi, HfFolder, Repository
52
from transformers.utils.versions import importlib_metadata
Aymeric Augustin's avatar
Aymeric Augustin committed
53

54
from . import __version__
Lysandre Debut's avatar
Lysandre Debut committed
55
from .utils import logging
thomwolf's avatar
thomwolf committed
56

Lysandre's avatar
Lysandre committed
57

Lysandre Debut's avatar
Lysandre Debut committed
58
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
thomwolf's avatar
thomwolf committed
59

60
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
61
62
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})

63
64
65
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
66

67
68
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
    _torch_available = importlib.util.find_spec("torch") is not None
69
    if _torch_available:
70
71
72
73
74
75
76
77
78
79
80
81
82
        try:
            _torch_version = importlib_metadata.version("torch")
            logger.info(f"PyTorch version {_torch_version} available.")
        except importlib_metadata.PackageNotFoundError:
            _torch_available = False
else:
    logger.info("Disabling PyTorch because USE_TF is set")
    _torch_available = False


if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
    _tf_available = importlib.util.find_spec("tensorflow") is not None
    if _tf_available:
83
84
85
86
87
88
89
90
        candidates = (
            "tensorflow",
            "tensorflow-cpu",
            "tensorflow-gpu",
            "tf-nightly",
            "tf-nightly-cpu",
            "tf-nightly-gpu",
            "intel-tensorflow",
91
            "intel-tensorflow-avx512",
92
            "tensorflow-rocm",
Julien Plu's avatar
Julien Plu committed
93
            "tensorflow-macos",
94
95
        )
        _tf_version = None
96
        # For the metadata, we have to look for both tensorflow and tensorflow-cpu
97
        for pkg in candidates:
98
            try:
99
100
                _tf_version = importlib_metadata.version(pkg)
                break
101
            except importlib_metadata.PackageNotFoundError:
102
103
                pass
        _tf_available = _tf_version is not None
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
    if _tf_available:
        if version.parse(_tf_version) < version.parse("2"):
            logger.info(f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum.")
            _tf_available = False
        else:
            logger.info(f"TensorFlow version {_tf_version} available.")
else:
    logger.info("Disabling Tensorflow because USE_TORCH is set")
    _tf_available = False


if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
    _flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
    if _flax_available:
        try:
            _jax_version = importlib_metadata.version("jax")
            _flax_version = importlib_metadata.version("flax")
            logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
        except importlib_metadata.PackageNotFoundError:
            _flax_available = False
else:
    _flax_available = False


_datasets_available = importlib.util.find_spec("datasets") is not None
Patrick von Platen's avatar
Patrick von Platen committed
129
try:
130
131
132
133
134
135
136
137
    # Check we're not importing a "datasets" directory somewhere but the actual library by trying to grab the version
    # AND checking it has an author field in the metadata that is HuggingFace.
    _ = importlib_metadata.version("datasets")
    _datasets_metadata = importlib_metadata.metadata("datasets")
    if _datasets_metadata.get("author", "") != "HuggingFace Inc.":
        _datasets_available = False
except importlib_metadata.PackageNotFoundError:
    _datasets_available = False
Patrick von Platen's avatar
Patrick von Platen committed
138

Ola Piktus's avatar
Ola Piktus committed
139

140
141
142
143
144
145
146
147
_detectron2_available = importlib.util.find_spec("detectron2") is not None
try:
    _detectron2_version = importlib_metadata.version("detectron2")
    logger.debug(f"Successfully imported detectron2 version {_detectron2_version}")
except importlib_metadata.PackageNotFoundError:
    _detectron2_available = False


148
_faiss_available = importlib.util.find_spec("faiss") is not None
Ola Piktus's avatar
Ola Piktus committed
149
try:
150
151
152
    _faiss_version = importlib_metadata.version("faiss")
    logger.debug(f"Successfully imported faiss version {_faiss_version}")
except importlib_metadata.PackageNotFoundError:
Patrick von Platen's avatar
Patrick von Platen committed
153
154
155
156
157
    try:
        _faiss_version = importlib_metadata.version("faiss-cpu")
        logger.debug(f"Successfully imported faiss version {_faiss_version}")
    except importlib_metadata.PackageNotFoundError:
        _faiss_available = False
Ola Piktus's avatar
Ola Piktus committed
158

159

160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
coloredlogs = importlib.util.find_spec("coloredlogs") is not None
try:
    _coloredlogs_available = importlib_metadata.version("coloredlogs")
    logger.debug(f"Successfully imported sympy version {_coloredlogs_available}")
except importlib_metadata.PackageNotFoundError:
    _coloredlogs_available = False


sympy_available = importlib.util.find_spec("sympy") is not None
try:
    _sympy_available = importlib_metadata.version("sympy")
    logger.debug(f"Successfully imported sympy version {_sympy_available}")
except importlib_metadata.PackageNotFoundError:
    _sympy_available = False


_keras2onnx_available = importlib.util.find_spec("keras2onnx") is not None
try:
    _keras2onnx_version = importlib_metadata.version("keras2onnx")
    logger.debug(f"Successfully imported keras2onnx version {_keras2onnx_version}")
except importlib_metadata.PackageNotFoundError:
    _keras2onnx_available = False

_onnx_available = importlib.util.find_spec("onnxruntime") is not None
184
185
186
187
188
189
190
try:
    _onxx_version = importlib_metadata.version("onnx")
    logger.debug(f"Successfully imported onnx version {_onxx_version}")
except importlib_metadata.PackageNotFoundError:
    _onnx_available = False


191
_scatter_available = importlib.util.find_spec("torch_scatter") is not None
192
try:
193
    _scatter_version = importlib_metadata.version("torch_scatter")
194
195
    logger.debug(f"Successfully imported torch-scatter version {_scatter_version}")
except importlib_metadata.PackageNotFoundError:
NielsRogge's avatar
NielsRogge committed
196
197
198
    _scatter_available = False


Patrick von Platen's avatar
Patrick von Platen committed
199
200
201
202
203
204
205
_soundfile_available = importlib.util.find_spec("soundfile") is not None
try:
    _soundfile_version = importlib_metadata.version("soundfile")
    logger.debug(f"Successfully imported soundfile version {_soundfile_version}")
except importlib_metadata.PackageNotFoundError:
    _soundfile_available = False

206

NielsRogge's avatar
NielsRogge committed
207
208
209
210
211
212
213
214
_timm_available = importlib.util.find_spec("timm") is not None
try:
    _timm_version = importlib_metadata.version("timm")
    logger.debug(f"Successfully imported timm version {_timm_version}")
except importlib_metadata.PackageNotFoundError:
    _timm_available = False


215
_torchaudio_available = importlib.util.find_spec("torchaudio") is not None
Suraj Patil's avatar
Suraj Patil committed
216
217
try:
    _torchaudio_version = importlib_metadata.version("torchaudio")
218
    logger.debug(f"Successfully imported torchaudio version {_torchaudio_version}")
Suraj Patil's avatar
Suraj Patil committed
219
220
221
except importlib_metadata.PackageNotFoundError:
    _torchaudio_available = False

Patrick von Platen's avatar
Patrick von Platen committed
222

223
torch_cache_home = os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
224
225
226
227
228
229
230
231
232
233
old_default_cache_path = os.path.join(torch_cache_home, "transformers")
# New default cache, shared with the Datasets library
hf_cache_home = os.path.expanduser(
    os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
)
default_cache_path = os.path.join(hf_cache_home, "transformers")

# Onetime move from the old location to the new one if no ENV variable has been set.
if (
    os.path.isdir(old_default_cache_path)
234
    and not os.path.isdir(default_cache_path)
235
236
237
238
    and "PYTORCH_PRETRAINED_BERT_CACHE" not in os.environ
    and "PYTORCH_TRANSFORMERS_CACHE" not in os.environ
    and "TRANSFORMERS_CACHE" not in os.environ
):
239
    logger.warning(
240
241
242
243
244
245
246
        "In Transformers v4.0.0, the default path to cache downloaded models changed from "
        "'~/.cache/torch/transformers' to '~/.cache/huggingface/transformers'. Since you don't seem to have overridden "
        "and '~/.cache/torch/transformers' is a directory that exists, we're moving it to "
        "'~/.cache/huggingface/transformers' to avoid redownloading models you have already in the cache. You should "
        "only see this message once."
    )
    shutil.move(old_default_cache_path, default_cache_path)
247

248
249
250
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
251
SESSION_ID = uuid4().hex
252
DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", False) in ENV_VARS_TRUE_VALUES
thomwolf's avatar
thomwolf committed
253

254
WEIGHTS_NAME = "pytorch_model.bin"
255
256
TF2_WEIGHTS_NAME = "tf_model.h5"
TF_WEIGHTS_NAME = "model.ckpt"
257
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
258
CONFIG_NAME = "config.json"
259
FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
260
MODEL_CARD_NAME = "modelcard.json"
Thomas Wolf's avatar
Thomas Wolf committed
261

262
263
SENTENCEPIECE_UNDERLINE = "▁"
SPIECE_UNDERLINE = SENTENCEPIECE_UNDERLINE  # Kept for backward compatibility
Lysandre's avatar
Lysandre committed
264

265
266
267
MULTIPLE_CHOICE_DUMMY_INPUTS = [
    [[0, 1, 0, 1], [1, 0, 0, 1]]
] * 2  # Needs to have 0s and 1s only since XLM uses it for langs too.
268
269
270
DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]

271
S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
Julien Chaumond's avatar
Julien Chaumond committed
272
CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
Sylvain Gugger's avatar
Sylvain Gugger committed
273
274
275
276
277
278

_staging_mode = os.environ.get("HUGGINGFACE_CO_STAGING", "NO").upper() in ENV_VARS_TRUE_VALUES
_default_endpoint = "https://moon-staging.huggingface.co" if _staging_mode else "https://huggingface.co"

HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HUGGINGFACE_CO_RESOLVE_ENDPOINT", _default_endpoint)
HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}"
Julien Chaumond's avatar
Julien Chaumond committed
279

280
281
282
283
PRESET_MIRROR_DICT = {
    "tuna": "https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models",
    "bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models",
}
284

285
# This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs.
286
TORCH_FX_REQUIRED_VERSION = version.parse("1.8")
287
TORCH_ONNX_DICT_INPUTS_MINIMUM_VERSION = version.parse("1.8")
Thomas Wolf's avatar
Thomas Wolf committed
288

289
290
291
292
293
294
295
_is_offline_mode = True if os.environ.get("TRANSFORMERS_OFFLINE", "0").upper() in ENV_VARS_TRUE_VALUES else False


def is_offline_mode():
    return _is_offline_mode


thomwolf's avatar
thomwolf committed
296
297
298
def is_torch_available():
    return _torch_available

299

300
301
302
303
304
305
306
307
308
def is_torch_cuda_available():
    if is_torch_available():
        import torch

        return torch.cuda.is_available()
    else:
        return False


309
_torch_fx_available = _torch_onnx_dict_inputs_support_available = False
310
if _torch_available:
311
312
313
314
315
    torch_version = version.parse(importlib_metadata.version("torch"))
    _torch_fx_available = (torch_version.major, torch_version.minor) == (
        TORCH_FX_REQUIRED_VERSION.major,
        TORCH_FX_REQUIRED_VERSION.minor,
    )
316

317
318
    _torch_onnx_dict_inputs_support_available = torch_version >= TORCH_ONNX_DICT_INPUTS_MINIMUM_VERSION

319
320
321
322
323

def is_torch_fx_available():
    return _torch_fx_available


324
325
326
327
def is_torch_onnx_dict_inputs_support_available():
    return _torch_onnx_dict_inputs_support_available


thomwolf's avatar
thomwolf committed
328
329
330
def is_tf_available():
    return _tf_available

331

332
333
334
335
336
337
338
339
def is_coloredlogs_available():
    return _coloredlogs_available


def is_keras2onnx_available():
    return _keras2onnx_available


340
341
342
343
def is_onnx_available():
    return _onnx_available


344
345
346
347
def is_flax_available():
    return _flax_available


348
def is_torch_tpu_available():
349
350
351
352
353
354
355
356
    if not _torch_available:
        return False
    # This test is probably enough, but just in case, we unpack a bit.
    if importlib.util.find_spec("torch_xla") is None:
        return False
    if importlib.util.find_spec("torch_xla.core") is None:
        return False
    return importlib.util.find_spec("torch_xla.core.xla_model") is not None
357
358


359
360
def is_datasets_available():
    return _datasets_available
361
362


363
364
365
366
def is_detectron2_available():
    return _detectron2_available


yujun's avatar
yujun committed
367
368
369
370
def is_rjieba_available():
    return importlib.util.find_spec("rjieba") is not None


Patrick von Platen's avatar
Patrick von Platen committed
371
def is_psutil_available():
372
    return importlib.util.find_spec("psutil") is not None
Patrick von Platen's avatar
Patrick von Platen committed
373
374
375


def is_py3nvml_available():
376
    return importlib.util.find_spec("py3nvml") is not None
Patrick von Platen's avatar
Patrick von Platen committed
377
378
379


def is_apex_available():
380
    return importlib.util.find_spec("apex") is not None
Patrick von Platen's avatar
Patrick von Platen committed
381
382


Ola Piktus's avatar
Ola Piktus committed
383
384
385
386
def is_faiss_available():
    return _faiss_available


NielsRogge's avatar
NielsRogge committed
387
388
389
390
def is_scipy_available():
    return importlib.util.find_spec("scipy") is not None


391
def is_sklearn_available():
392
393
    if importlib.util.find_spec("sklearn") is None:
        return False
NielsRogge's avatar
NielsRogge committed
394
    return is_scipy_available() and importlib.util.find_spec("sklearn.metrics")
395
396
397


def is_sentencepiece_available():
398
    return importlib.util.find_spec("sentencepiece") is not None
399
400


401
def is_protobuf_available():
402
403
404
    if importlib.util.find_spec("google") is None:
        return False
    return importlib.util.find_spec("google.protobuf") is not None
405
406


407
def is_tokenizers_available():
408
    return importlib.util.find_spec("tokenizers") is not None
409
410


411
412
413
414
def is_vision_available():
    return importlib.util.find_spec("PIL") is not None


415
416
417
418
def is_pytesseract_available():
    return importlib.util.find_spec("pytesseract") is not None


419
def is_in_notebook():
420
421
422
423
424
425
426
427
428
429
430
    try:
        # Test adapted from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
        get_ipython = sys.modules["IPython"].get_ipython
        if "IPKernelApp" not in get_ipython().config:
            raise ImportError("console")
        if "VSCODE_PID" in os.environ:
            raise ImportError("vscode")

        return importlib.util.find_spec("IPython") is not None
    except (AttributeError, ImportError, KeyError):
        return False
431
432


NielsRogge's avatar
NielsRogge committed
433
434
435
436
437
def is_scatter_available():
    return _scatter_available


def is_pandas_available():
438
    return importlib.util.find_spec("pandas") is not None
NielsRogge's avatar
NielsRogge committed
439
440


Sylvain Gugger's avatar
Sylvain Gugger committed
441
def is_sagemaker_dp_enabled():
Sylvain Gugger's avatar
Sylvain Gugger committed
442
443
444
445
446
447
448
449
    # Get the sagemaker specific env variable.
    sagemaker_params = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
    try:
        # Parse it and check the field "sagemaker_distributed_dataparallel_enabled".
        sagemaker_params = json.loads(sagemaker_params)
        if not sagemaker_params.get("sagemaker_distributed_dataparallel_enabled", False):
            return False
    except json.JSONDecodeError:
Sylvain Gugger's avatar
Sylvain Gugger committed
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
        return False
    # Lastly, check if the `smdistributed` module is present.
    return importlib.util.find_spec("smdistributed") is not None


def is_sagemaker_mp_enabled():
    # Get the sagemaker specific mp parameters from smp_options variable.
    smp_options = os.getenv("SM_HP_MP_PARAMETERS", "{}")
    try:
        # Parse it and check the field "partitions" is included, it is required for model parallel.
        smp_options = json.loads(smp_options)
        if "partitions" not in smp_options:
            return False
    except json.JSONDecodeError:
        return False

    # Get the sagemaker specific framework parameters from mpi_options variable.
    mpi_options = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
    try:
        # Parse it and check the field "sagemaker_distributed_dataparallel_enabled".
        mpi_options = json.loads(mpi_options)
        if not mpi_options.get("sagemaker_mpi_enabled", False):
            return False
    except json.JSONDecodeError:
Sylvain Gugger's avatar
Sylvain Gugger committed
474
475
476
477
478
        return False
    # Lastly, check if the `smdistributed` module is present.
    return importlib.util.find_spec("smdistributed") is not None


479
def is_training_run_on_sagemaker():
480
    return "SAGEMAKER_JOB_NAME" in os.environ
481
482


Patrick von Platen's avatar
Patrick von Platen committed
483
484
485
486
def is_soundfile_availble():
    return _soundfile_available


NielsRogge's avatar
NielsRogge committed
487
488
489
490
def is_timm_available():
    return _timm_available


Suraj Patil's avatar
Suraj Patil committed
491
492
493
494
def is_torchaudio_available():
    return _torchaudio_available


495
496
497
498
499
def is_speech_available():
    # For now this depends on torchaudio but the exact dependency might evolve in the future.
    return _torchaudio_available


500
501
502
503
504
505
506
507
508
509
510
511
512
def torch_only_method(fn):
    def wrapper(*args, **kwargs):
        if not _torch_available:
            raise ImportError(
                "You need to install pytorch to use this method or class, "
                "or activate it with environment variables USE_TORCH=1 and USE_TF=0."
            )
        else:
            return fn(*args, **kwargs)

    return wrapper


513
# docstyle-ignore
514
DATASETS_IMPORT_ERROR = """
515
{0} requires the 🤗 Datasets library but it was not found in your environment. You can install it with:
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
```
pip install datasets
```
In a notebook or a colab, you can install it by executing a cell with
```
!pip install datasets
```
then restarting your kernel.

Note that if you have a local folder named `datasets` or a local python file named `datasets.py` in your current
working directory, python may try to import this instead of the 🤗 Datasets library. You should rename this folder or
that python file if that's the case.
"""


531
# docstyle-ignore
532
TOKENIZERS_IMPORT_ERROR = """
533
{0} requires the 🤗 Tokenizers library but it was not found in your environment. You can install it with:
534
535
536
537
538
539
540
541
542
543
```
pip install tokenizers
```
In a notebook or a colab, you can install it by executing a cell with
```
!pip install tokenizers
```
"""


544
# docstyle-ignore
545
SENTENCEPIECE_IMPORT_ERROR = """
546
{0} requires the SentencePiece library but it was not found in your environment. Checkout the instructions on the
547
installation page of its repo: https://github.com/google/sentencepiece#installation and follow the ones
548
that match your environment.
549
550
551
"""


552
553
554
555
556
557
558
559
# docstyle-ignore
PROTOBUF_IMPORT_ERROR = """
{0} requires the protobuf library but it was not found in your environment. Checkout the instructions on the
installation page of its repo: https://github.com/protocolbuffers/protobuf/tree/master/python#installation and follow the ones
that match your environment.
"""


560
# docstyle-ignore
561
FAISS_IMPORT_ERROR = """
562
{0} requires the faiss library but it was not found in your environment. Checkout the instructions on the
563
installation page of its repo: https://github.com/facebookresearch/faiss/blob/master/INSTALL.md and follow the ones
564
that match your environment.
565
566
567
"""


568
# docstyle-ignore
569
PYTORCH_IMPORT_ERROR = """
570
571
{0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
572
573
574
"""


575
# docstyle-ignore
576
SKLEARN_IMPORT_ERROR = """
577
{0} requires the scikit-learn library but it was not found in your environment. You can install it with:
578
579
580
581
582
583
584
585
586
587
```
pip install -U scikit-learn
```
In a notebook or a colab, you can install it by executing a cell with
```
!pip install -U scikit-learn
```
"""


588
# docstyle-ignore
589
TENSORFLOW_IMPORT_ERROR = """
590
591
{0} requires the TensorFlow library but it was not found in your environment. Checkout the instructions on the
installation page: https://www.tensorflow.org/install and follow the ones that match your environment.
592
593
594
"""


595
596
597
598
599
600
601
602
# docstyle-ignore
DETECTRON2_IMPORT_ERROR = """
{0} requires the detectron2 library but it was not found in your environment. Checkout the instructions on the
installation page: https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md and follow the ones
that match your environment.
"""


603
# docstyle-ignore
604
FLAX_IMPORT_ERROR = """
605
606
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
installation page: https://github.com/google/flax and follow the ones that match your environment.
607
608
609
"""


NielsRogge's avatar
NielsRogge committed
610
611
612
613
614
615
616
# docstyle-ignore
SCATTER_IMPORT_ERROR = """
{0} requires the torch-scatter library but it was not found in your environment. You can install it with pip as
explained here: https://github.com/rusty1s/pytorch_scatter.
"""


617
618
619
620
621
622
623
# docstyle-ignore
PANDAS_IMPORT_ERROR = """
{0} requires the pandas library but it was not found in your environment. You can install it with pip as
explained here: https://pandas.pydata.org/pandas-docs/stable/getting_started/install.html.
"""


NielsRogge's avatar
NielsRogge committed
624
625
626
627
628
629
630
# docstyle-ignore
SCIPY_IMPORT_ERROR = """
{0} requires the scipy library but it was not found in your environment. You can install it with pip:
`pip install scipy`
"""


631
632
633
634
635
636
# docstyle-ignore
SPEECH_IMPORT_ERROR = """
{0} requires the torchaudio library but it was not found in your environment. You can install it with pip:
`pip install torchaudio`
"""

NielsRogge's avatar
NielsRogge committed
637
638
639
640
641
# docstyle-ignore
TIMM_IMPORT_ERROR = """
{0} requires the timm library but it was not found in your environment. You can install it with pip:
`pip install timm`
"""
642

643
644
645
646
647
648
649
# docstyle-ignore
VISION_IMPORT_ERROR = """
{0} requires the PIL library but it was not found in your environment. You can install it with pip:
`pip install pillow`
"""


650
651
652
653
654
655
656
# docstyle-ignore
PYTESSERACT_IMPORT_ERROR = """
{0} requires the PyTesseract library but it was not found in your environment. You can install it with pip:
`pip install pytesseract`
"""


657
658
659
BACKENDS_MAPPING = OrderedDict(
    [
        ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)),
660
        ("detectron2", (is_detectron2_available, DETECTRON2_IMPORT_ERROR)),
661
662
663
664
        ("faiss", (is_faiss_available, FAISS_IMPORT_ERROR)),
        ("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
        ("pandas", (is_pandas_available, PANDAS_IMPORT_ERROR)),
        ("protobuf", (is_protobuf_available, PROTOBUF_IMPORT_ERROR)),
665
        ("pytesseract", (is_pytesseract_available, PYTESSERACT_IMPORT_ERROR)),
666
667
668
669
670
        ("scatter", (is_scatter_available, SCATTER_IMPORT_ERROR)),
        ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
        ("sklearn", (is_sklearn_available, SKLEARN_IMPORT_ERROR)),
        ("speech", (is_speech_available, SPEECH_IMPORT_ERROR)),
        ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
NielsRogge's avatar
NielsRogge committed
671
        ("timm", (is_timm_available, TIMM_IMPORT_ERROR)),
672
        ("tokenizers", (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)),
673
674
        ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
        ("vision", (is_vision_available, VISION_IMPORT_ERROR)),
NielsRogge's avatar
NielsRogge committed
675
        ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
676
677
    ]
)
NielsRogge's avatar
NielsRogge committed
678

679

680
681
682
def requires_backends(obj, backends):
    if not isinstance(backends, (list, tuple)):
        backends = [backends]
683

684
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
685
686
    if not all(BACKENDS_MAPPING[backend][0]() for backend in backends):
        raise ImportError("".join([BACKENDS_MAPPING[backend][1].format(name) for backend in backends]))
687
688


Aymeric Augustin's avatar
Aymeric Augustin committed
689
690
def add_start_docstrings(*docstr):
    def docstring_decorator(fn):
691
692
693
694
695
696
        fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
        return fn

    return docstring_decorator


697
def add_start_docstrings_to_model_forward(*docstr):
698
    def docstring_decorator(fn):
699
700
        class_name = f":class:`~transformers.{fn.__qualname__.split('.')[0]}`"
        intro = f"   The {class_name} forward method, overrides the :func:`__call__` special method."
Lysandre's avatar
Lysandre committed
701
702
        note = r"""

703
    .. note::
Sylvain Gugger's avatar
Sylvain Gugger committed
704
705
706
        Although the recipe for forward pass needs to be defined within this function, one should call the
        :class:`Module` instance afterwards instead of this since the former takes care of running the pre and post
        processing steps while the latter silently ignores them.
707
708
        """
        fn.__doc__ = intro + note + "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
Aymeric Augustin's avatar
Aymeric Augustin committed
709
        return fn
710

Aymeric Augustin's avatar
Aymeric Augustin committed
711
    return docstring_decorator
712

713

Aymeric Augustin's avatar
Aymeric Augustin committed
714
715
716
717
def add_end_docstrings(*docstr):
    def docstring_decorator(fn):
        fn.__doc__ = fn.__doc__ + "".join(docstr)
        return fn
718

Aymeric Augustin's avatar
Aymeric Augustin committed
719
    return docstring_decorator
thomwolf's avatar
thomwolf committed
720

721

Sylvain Gugger's avatar
Sylvain Gugger committed
722
PT_RETURN_INTRODUCTION = r"""
723
    Returns:
724
725
726
        :class:`~{full_output_type}` or :obj:`tuple(torch.FloatTensor)`: A :class:`~{full_output_type}` or a tuple of
        :obj:`torch.FloatTensor` (if ``return_dict=False`` is passed or when ``config.return_dict=False``) comprising
        various elements depending on the configuration (:class:`~transformers.{config_class}`) and inputs.
727

728
729
730
"""


Sylvain Gugger's avatar
Sylvain Gugger committed
731
732
TF_RETURN_INTRODUCTION = r"""
    Returns:
733
734
735
        :class:`~{full_output_type}` or :obj:`tuple(tf.Tensor)`: A :class:`~{full_output_type}` or a tuple of
        :obj:`tf.Tensor` (if ``return_dict=False`` is passed or when ``config.return_dict=False``) comprising various
        elements depending on the configuration (:class:`~transformers.{config_class}`) and inputs.
Sylvain Gugger's avatar
Sylvain Gugger committed
736
737
738
739

"""


740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
def _get_indent(t):
    """Returns the indentation in the first line of t"""
    search = re.search(r"^(\s*)\S", t)
    return "" if search is None else search.groups()[0]


def _convert_output_args_doc(output_args_doc):
    """Convert output_args_doc to display properly."""
    # Split output_arg_doc in blocks argument/description
    indent = _get_indent(output_args_doc)
    blocks = []
    current_block = ""
    for line in output_args_doc.split("\n"):
        # If the indent is the same as the beginning, the line is the name of new arg.
        if _get_indent(line) == indent:
            if len(current_block) > 0:
                blocks.append(current_block[:-1])
            current_block = f"{line}\n"
        else:
            # Otherwise it's part of the description of the current arg.
            # We need to remove 2 spaces to the indentation.
            current_block += f"{line[2:]}\n"
    blocks.append(current_block[:-1])

    # Format each block for proper rendering
    for i in range(len(blocks)):
        blocks[i] = re.sub(r"^(\s+)(\S+)(\s+)", r"\1- **\2**\3", blocks[i])
        blocks[i] = re.sub(r":\s*\n\s*(\S)", r" -- \1", blocks[i])

    return "\n".join(blocks)


772
773
774
775
776
777
778
779
780
781
782
783
784
def _prepare_output_docstrings(output_type, config_class):
    """
    Prepares the return part of the docstring using `output_type`.
    """
    docstrings = output_type.__doc__

    # Remove the head of the docstring to keep the list of args only
    lines = docstrings.split("\n")
    i = 0
    while i < len(lines) and re.search(r"^\s*(Args|Parameters):\s*$", lines[i]) is None:
        i += 1
    if i < len(lines):
        docstrings = "\n".join(lines[(i + 1) :])
785
        docstrings = _convert_output_args_doc(docstrings)
786
787

    # Add the return introduction
788
    full_output_type = f"{output_type.__module__}.{output_type.__name__}"
Sylvain Gugger's avatar
Sylvain Gugger committed
789
790
    intro = TF_RETURN_INTRODUCTION if output_type.__name__.startswith("TF") else PT_RETURN_INTRODUCTION
    intro = intro.format(full_output_type=full_output_type, config_class=config_class)
791
792
793
    return intro + docstrings


794
795
796
797
798
799
800
PT_TOKEN_CLASSIFICATION_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import torch

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
801
        >>> model = {model_class}.from_pretrained('{checkpoint}')
802
803
804
805
806

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> labels = torch.tensor([1] * inputs["input_ids"].size(1)).unsqueeze(0)  # Batch size 1

        >>> outputs = model(**inputs, labels=labels)
807
808
        >>> loss = outputs.loss
        >>> logits = outputs.logits
809
810
811
812
813
814
815
816
817
"""

PT_QUESTION_ANSWERING_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import torch

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
818
        >>> model = {model_class}.from_pretrained('{checkpoint}')
819

820
821
        >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
        >>> inputs = tokenizer(question, text, return_tensors='pt')
822
823
824
825
        >>> start_positions = torch.tensor([1])
        >>> end_positions = torch.tensor([3])

        >>> outputs = model(**inputs, start_positions=start_positions, end_positions=end_positions)
826
        >>> loss = outputs.loss
827
828
        >>> start_scores = outputs.start_logits
        >>> end_scores = outputs.end_logits
829
830
831
832
833
834
835
836
837
"""

PT_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import torch

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
838
        >>> model = {model_class}.from_pretrained('{checkpoint}')
839
840
841
842

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1
        >>> outputs = model(**inputs, labels=labels)
843
844
        >>> loss = outputs.loss
        >>> logits = outputs.logits
845
846
847
848
849
850
851
852
853
"""

PT_MASKED_LM_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import torch

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
854
        >>> model = {model_class}.from_pretrained('{checkpoint}')
855

Sylvain Gugger's avatar
Sylvain Gugger committed
856
857
        >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="pt")
        >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]
858

Sylvain Gugger's avatar
Sylvain Gugger committed
859
        >>> outputs = model(**inputs, labels=labels)
860
        >>> loss = outputs.loss
Sylvain Gugger's avatar
Sylvain Gugger committed
861
        >>> logits = outputs.logits
862
863
864
865
866
867
868
869
870
"""

PT_BASE_MODEL_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import torch

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
871
        >>> model = {model_class}.from_pretrained('{checkpoint}')
872
873
874
875

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

876
        >>> last_hidden_states = outputs.last_hidden_state
877
878
879
880
881
882
883
884
885
"""

PT_MULTIPLE_CHOICE_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import torch

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
886
        >>> model = {model_class}.from_pretrained('{checkpoint}')
887
888
889
890
891
892

        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        >>> choice0 = "It is eaten with a fork and a knife."
        >>> choice1 = "It is eaten while held in the hand."
        >>> labels = torch.tensor(0).unsqueeze(0)  # choice0 is correct (according to Wikipedia ;)), batch size 1

893
        >>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors='pt', padding=True)
894
895
896
        >>> outputs = model(**{{k: v.unsqueeze(0) for k,v in encoding.items()}}, labels=labels)  # batch size is 1

        >>> # the linear classifier still needs to be trained
897
898
        >>> loss = outputs.loss
        >>> logits = outputs.logits
899
900
901
902
903
904
905
906
907
"""

PT_CAUSAL_LM_SAMPLE = r"""
    Example::

        >>> import torch
        >>> from transformers import {tokenizer_class}, {model_class}

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
Jungwhan's avatar
Jungwhan committed
908
        >>> model = {model_class}.from_pretrained('{checkpoint}')
909
910
911

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs, labels=inputs["input_ids"])
912
913
        >>> loss = outputs.loss
        >>> logits = outputs.logits
914
915
"""

916
917
918
919
920
921
922
923
924
925
926
PT_SAMPLE_DOCSTRINGS = {
    "SequenceClassification": PT_SEQUENCE_CLASSIFICATION_SAMPLE,
    "QuestionAnswering": PT_QUESTION_ANSWERING_SAMPLE,
    "TokenClassification": PT_TOKEN_CLASSIFICATION_SAMPLE,
    "MultipleChoice": PT_MULTIPLE_CHOICE_SAMPLE,
    "MaskedLM": PT_MASKED_LM_SAMPLE,
    "LMHead": PT_CAUSAL_LM_SAMPLE,
    "BaseModel": PT_BASE_MODEL_SAMPLE,
}


927
928
929
930
931
932
933
TF_TOKEN_CLASSIFICATION_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import tensorflow as tf

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
934
        >>> model = {model_class}.from_pretrained('{checkpoint}')
935
936
937
938
939
940

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
        >>> input_ids = inputs["input_ids"]
        >>> inputs["labels"] = tf.reshape(tf.constant([1] * tf.size(input_ids).numpy()), (-1, tf.size(input_ids))) # Batch size 1

        >>> outputs = model(inputs)
Sylvain Gugger's avatar
Sylvain Gugger committed
941
942
        >>> loss = outputs.loss
        >>> logits = outputs.logits
943
944
945
946
947
948
949
950
951
"""

TF_QUESTION_ANSWERING_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import tensorflow as tf

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
952
        >>> model = {model_class}.from_pretrained('{checkpoint}')
953
954
955

        >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
        >>> input_dict = tokenizer(question, text, return_tensors='tf')
Sylvain Gugger's avatar
Sylvain Gugger committed
956
957
958
        >>> outputs = model(input_dict)
        >>> start_logits = outputs.start_logits
        >>> end_logits = outputs.end_logits
959
960

        >>> all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0])
Sylvain Gugger's avatar
Sylvain Gugger committed
961
        >>> answer = ' '.join(all_tokens[tf.math.argmax(start_logits, 1)[0] : tf.math.argmax(end_logits, 1)[0]+1])
962
963
964
965
966
967
968
969
970
"""

TF_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import tensorflow as tf

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
971
        >>> model = {model_class}.from_pretrained('{checkpoint}')
972
973
974
975
976

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
        >>> inputs["labels"] = tf.reshape(tf.constant(1), (-1, 1)) # Batch size 1

        >>> outputs = model(inputs)
Sylvain Gugger's avatar
Sylvain Gugger committed
977
978
        >>> loss = outputs.loss
        >>> logits = outputs.logits
979
980
981
982
"""

TF_MASKED_LM_SAMPLE = r"""
    Example::
Sylvain Gugger's avatar
Sylvain Gugger committed
983

984
985
986
987
        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import tensorflow as tf

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
988
        >>> model = {model_class}.from_pretrained('{checkpoint}')
989

Sylvain Gugger's avatar
Sylvain Gugger committed
990
991
        >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="tf")
        >>> inputs["labels"] = tokenizer("The capital of France is Paris.", return_tensors="tf")["input_ids"]
992

Sylvain Gugger's avatar
Sylvain Gugger committed
993
994
995
        >>> outputs = model(inputs)
        >>> loss = outputs.loss
        >>> logits = outputs.logits
996
997
998
999
1000
1001
1002
1003
1004
"""

TF_BASE_MODEL_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import tensorflow as tf

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
1005
        >>> model = {model_class}.from_pretrained('{checkpoint}')
1006
1007
1008
1009

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
        >>> outputs = model(inputs)

1010
        >>> last_hidden_states = outputs.last_hidden_state
1011
1012
1013
1014
1015
1016
1017
1018
1019
"""

TF_MULTIPLE_CHOICE_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import tensorflow as tf

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
1020
        >>> model = {model_class}.from_pretrained('{checkpoint}')
1021
1022
1023
1024
1025

        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        >>> choice0 = "It is eaten with a fork and a knife."
        >>> choice1 = "It is eaten while held in the hand."

1026
        >>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors='tf', padding=True)
1027
1028
1029
1030
        >>> inputs = {{k: tf.expand_dims(v, 0) for k, v in encoding.items()}}
        >>> outputs = model(inputs)  # batch size is 1

        >>> # the linear classifier still needs to be trained
Sylvain Gugger's avatar
Sylvain Gugger committed
1031
        >>> logits = outputs.logits
1032
1033
1034
1035
1036
1037
1038
1039
1040
"""

TF_CAUSAL_LM_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import tensorflow as tf

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
1041
        >>> model = {model_class}.from_pretrained('{checkpoint}')
1042
1043
1044

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
        >>> outputs = model(inputs)
Sylvain Gugger's avatar
Sylvain Gugger committed
1045
        >>> logits = outputs.logits
1046
1047
"""

1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
TF_SAMPLE_DOCSTRINGS = {
    "SequenceClassification": TF_SEQUENCE_CLASSIFICATION_SAMPLE,
    "QuestionAnswering": TF_QUESTION_ANSWERING_SAMPLE,
    "TokenClassification": TF_TOKEN_CLASSIFICATION_SAMPLE,
    "MultipleChoice": TF_MULTIPLE_CHOICE_SAMPLE,
    "MaskedLM": TF_MASKED_LM_SAMPLE,
    "LMHead": TF_CAUSAL_LM_SAMPLE,
    "BaseModel": TF_BASE_MODEL_SAMPLE,
}


FLAX_TOKEN_CLASSIFICATION_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors='jax')

        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
"""

FLAX_QUESTION_ANSWERING_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
        >>> inputs = tokenizer(question, text, return_tensors='jax')

        >>> outputs = model(**inputs)
        >>> start_scores = outputs.start_logits
        >>> end_scores = outputs.end_logits
"""

FLAX_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors='jax')

1099
        >>> outputs = model(**inputs)
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
        >>> logits = outputs.logits
"""

FLAX_MASKED_LM_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors='jax')

        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
"""

FLAX_BASE_MODEL_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors='jax')
        >>> outputs = model(**inputs)

        >>> last_hidden_states = outputs.last_hidden_state
"""

FLAX_MULTIPLE_CHOICE_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        >>> choice0 = "It is eaten with a fork and a knife."
        >>> choice1 = "It is eaten while held in the hand."

1143
        >>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors='jax', padding=True)
1144
1145
1146
1147
1148
        >>> outputs = model(**{{k: v[None, :] for k,v in encoding.items()}})

        >>> logits = outputs.logits
"""

Suraj Patil's avatar
Suraj Patil committed
1149
1150
1151
1152
1153
1154
1155
1156
FLAX_CAUSAL_LM_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

1157
        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
1158
        >>> outputs = model(**inputs)
Suraj Patil's avatar
Suraj Patil committed
1159

1160
1161
        >>> # retrieve logts for next token
        >>> next_token_logits = outputs.logits[:, -1]
Suraj Patil's avatar
Suraj Patil committed
1162
1163
"""

1164
1165
1166
1167
1168
1169
1170
FLAX_SAMPLE_DOCSTRINGS = {
    "SequenceClassification": FLAX_SEQUENCE_CLASSIFICATION_SAMPLE,
    "QuestionAnswering": FLAX_QUESTION_ANSWERING_SAMPLE,
    "TokenClassification": FLAX_TOKEN_CLASSIFICATION_SAMPLE,
    "MultipleChoice": FLAX_MULTIPLE_CHOICE_SAMPLE,
    "MaskedLM": FLAX_MASKED_LM_SAMPLE,
    "BaseModel": FLAX_BASE_MODEL_SAMPLE,
Suraj Patil's avatar
Suraj Patil committed
1171
    "LMHead": FLAX_CAUSAL_LM_SAMPLE,
1172
1173
}

1174

Sylvain Gugger's avatar
Sylvain Gugger committed
1175
def add_code_sample_docstrings(
1176
    *docstr, tokenizer_class=None, checkpoint=None, output_type=None, config_class=None, mask=None, model_cls=None
Sylvain Gugger's avatar
Sylvain Gugger committed
1177
):
1178
    def docstring_decorator(fn):
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
        # model_class defaults to function's class if not specified otherwise
        model_class = fn.__qualname__.split(".")[0] if model_cls is None else model_cls

        if model_class[:2] == "TF":
            sample_docstrings = TF_SAMPLE_DOCSTRINGS
        elif model_class[:4] == "Flax":
            sample_docstrings = FLAX_SAMPLE_DOCSTRINGS
        else:
            sample_docstrings = PT_SAMPLE_DOCSTRINGS

Sylvain Gugger's avatar
Sylvain Gugger committed
1189
        doc_kwargs = dict(model_class=model_class, tokenizer_class=tokenizer_class, checkpoint=checkpoint)
1190
1191

        if "SequenceClassification" in model_class:
1192
            code_sample = sample_docstrings["SequenceClassification"]
1193
        elif "QuestionAnswering" in model_class:
1194
            code_sample = sample_docstrings["QuestionAnswering"]
1195
        elif "TokenClassification" in model_class:
1196
            code_sample = sample_docstrings["TokenClassification"]
1197
        elif "MultipleChoice" in model_class:
1198
            code_sample = sample_docstrings["MultipleChoice"]
Sylvain Gugger's avatar
Sylvain Gugger committed
1199
1200
        elif "MaskedLM" in model_class or model_class in ["FlaubertWithLMHeadModel", "XLMWithLMHeadModel"]:
            doc_kwargs["mask"] = "[MASK]" if mask is None else mask
1201
            code_sample = sample_docstrings["MaskedLM"]
Lysandre Debut's avatar
Lysandre Debut committed
1202
        elif "LMHead" in model_class or "CausalLM" in model_class:
1203
            code_sample = sample_docstrings["LMHead"]
1204
        elif "Model" in model_class or "Encoder" in model_class:
1205
            code_sample = sample_docstrings["BaseModel"]
1206
1207
1208
        else:
            raise ValueError(f"Docstring can't be built for model {model_class}")

1209
        output_doc = _prepare_output_docstrings(output_type, config_class) if output_type is not None else ""
Sylvain Gugger's avatar
Sylvain Gugger committed
1210
        built_doc = code_sample.format(**doc_kwargs)
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
        fn.__doc__ = (fn.__doc__ or "") + "".join(docstr) + output_doc + built_doc
        return fn

    return docstring_decorator


def replace_return_docstrings(output_type=None, config_class=None):
    def docstring_decorator(fn):
        docstrings = fn.__doc__
        lines = docstrings.split("\n")
        i = 0
        while i < len(lines) and re.search(r"^\s*Returns?:\s*$", lines[i]) is None:
            i += 1
        if i < len(lines):
            lines[i] = _prepare_output_docstrings(output_type, config_class)
            docstrings = "\n".join(lines)
        else:
            raise ValueError(
                f"The function {fn} should have an empty 'Return:' or 'Returns:' in its docstring as placeholder, current docstring is:\n{docstrings}"
            )
        fn.__doc__ = docstrings
1232
1233
1234
1235
1236
        return fn

    return docstring_decorator


1237
1238
def is_remote_url(url_or_filename):
    parsed = urlparse(url_or_filename)
Julien Chaumond's avatar
Julien Chaumond committed
1239
    return parsed.scheme in ("http", "https")
1240

1241

1242
1243
1244
def hf_bucket_url(
    model_id: str, filename: str, subfolder: Optional[str] = None, revision: Optional[str] = None, mirror=None
) -> str:
Julien Chaumond's avatar
Julien Chaumond committed
1245
    """
Julien Chaumond's avatar
Julien Chaumond committed
1246
1247
    Resolve a model identifier, a file name, and an optional revision id, to a huggingface.co-hosted url, redirecting
    to Cloudfront (a Content Delivery Network, or CDN) for large files.
Sylvain Gugger's avatar
Sylvain Gugger committed
1248
1249

    Cloudfront is replicated over the globe so downloads are way faster for the end user (and it also lowers our
Julien Chaumond's avatar
Julien Chaumond committed
1250
1251
1252
1253
1254
1255
    bandwidth costs).

    Cloudfront aggressively caches files by default (default TTL is 24 hours), however this is not an issue here
    because we migrated to a git-based versioning system on huggingface.co, so we now store the files on S3/Cloudfront
    in a content-addressable way (i.e., the file name is its hash). Using content-addressable filenames means cache
    can't ever be stale.
Sylvain Gugger's avatar
Sylvain Gugger committed
1256

Julien Chaumond's avatar
Julien Chaumond committed
1257
1258
1259
    In terms of client-side caching from this library, we base our caching on the objects' ETag. An object' ETag is:
    its sha1 if stored in git, or its sha256 if stored in git-lfs. Files cached locally from transformers before v3.5.0
    are not shared with those new files, because the cached file's name contains a hash of the url (which changed).
Julien Chaumond's avatar
Julien Chaumond committed
1260
    """
1261
1262
1263
    if subfolder is not None:
        filename = f"{subfolder}/{filename}"

Julien Chaumond's avatar
Julien Chaumond committed
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
    if mirror:
        endpoint = PRESET_MIRROR_DICT.get(mirror, mirror)
        legacy_format = "/" not in model_id
        if legacy_format:
            return f"{endpoint}/{model_id}-{filename}"
        else:
            return f"{endpoint}/{model_id}/{filename}"

    if revision is None:
        revision = "main"
    return HUGGINGFACE_CO_PREFIX.format(model_id=model_id, revision=revision, filename=filename)
1275
1276


Julien Chaumond's avatar
Julien Chaumond committed
1277
def url_to_filename(url: str, etag: Optional[str] = None) -> str:
thomwolf's avatar
thomwolf committed
1278
    """
Sylvain Gugger's avatar
Sylvain Gugger committed
1279
1280
1281
1282
    Convert `url` into a hashed filename in a repeatable way. If `etag` is specified, append its hash to the url's,
    delimited by a period. If the url ends with .h5 (Keras HDF5 weights) adds '.h5' to the name so that TF 2.0 can
    identify it as a HDF5 file (see
    https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
thomwolf's avatar
thomwolf committed
1283
    """
1284
    url_bytes = url.encode("utf-8")
Julien Chaumond's avatar
Julien Chaumond committed
1285
    filename = sha256(url_bytes).hexdigest()

    if etag:
        etag_bytes = etag.encode("utf-8")
        filename += "." + sha256(etag_bytes).hexdigest()

    if url.endswith(".h5"):
        filename += ".h5"

    return filename


def filename_to_url(filename, cache_dir=None):
    """
    Return the url and etag (which may be ``None``) stored for `filename`. Raise ``EnvironmentError`` if `filename` or
    its stored metadata do not exist.
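
    Example (sketch; assumes the file and its ``.json`` metadata were written by a previous download)::

        filename = url_to_filename("https://example.com/model.bin", etag='"abc"')
        url, etag = filename_to_url(filename)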
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    cache_path = os.path.join(cache_dir, filename)
    if not os.path.exists(cache_path):
        raise EnvironmentError(f"file {cache_path} not found")

    meta_path = cache_path + ".json"
    if not os.path.exists(meta_path):
        raise EnvironmentError(f"file {meta_path} not found")

    with open(meta_path, encoding="utf-8") as meta_file:
        metadata = json.load(meta_file)
    url = metadata["url"]
    etag = metadata["etag"]

    return url, etag


def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
    """
    Returns a list of tuples representing model binaries that are cached locally. Each tuple has shape
    :obj:`(model_url, etag, size_MB)`. Filenames in :obj:`cache_dir` are used to get the metadata for each model; only
    urls ending with `.bin` are added.

    Args:
        cache_dir (:obj:`Union[str, Path]`, `optional`):
            The cache directory to search for models within. Will default to the transformers cache if unset.

    Returns:
        List[Tuple]: List of tuples each with shape :obj:`(model_url, etag, size_MB)`
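
    Example (illustrative output; urls, etags and sizes depend on your local cache)::

        get_cached_models()
        # => [('https://huggingface.co/bert-base-uncased/resolve/main/pytorch_model.bin', '"0a1b..."', 420.5), ...]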
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    elif isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    cached_models = []
    for file in os.listdir(cache_dir):
        if file.endswith(".json"):
            meta_path = os.path.join(cache_dir, file)
            with open(meta_path, encoding="utf-8") as meta_file:
                metadata = json.load(meta_file)
                url = metadata["url"]
                etag = metadata["etag"]
                if url.endswith(".bin"):
                    # Drop the ".json" suffix to locate the model file itself (str.strip removes characters, not a suffix).
                    size_MB = os.path.getsize(meta_path[: -len(".json")]) / 1e6
                    cached_models.append((url, etag, size_MB))

    return cached_models


def cached_path(
    url_or_filename,
    cache_dir=None,
    force_download=False,
    proxies=None,
    resume_download=False,
    user_agent: Union[Dict, str, None] = None,
    extract_compressed_file=False,
    force_extract=False,
    use_auth_token: Union[bool, str, None] = None,
    local_files_only=False,
) -> Optional[str]:
    """
    Given something that might be a URL (or might be a local path), determine which. If it's a URL, download the file
    and cache it, and return the path to the cached file. If it's already a local path, make sure the file exists and
    then return the path.

    Args:
        cache_dir: specify a cache directory to save the file to (overrides the default cache dir).
        force_download: if True, re-download the file even if it's already cached in the cache dir.
        resume_download: if True, resume the download if an incompletely received file is found.
        user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
        use_auth_token: Optional string or boolean to use as Bearer token for remote files. If True,
            will get token from ~/.huggingface.
        extract_compressed_file: if True and the path points to a zip or tar file, extract the compressed
            file in a folder alongside the archive.
        force_extract: if True when extract_compressed_file is True and the archive was already extracted,
            re-extract the archive and override the folder where it was extracted.

    Return:
        Local path (string) of the file or, if networking is off, the last version of the file cached on disk.

    Raises:
        In case of a non-recoverable file (non-existent or inaccessible url + no cache on disk).
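
    Example (sketch; the first call downloads the file, later calls hit the cache)::

        local_path = cached_path("https://huggingface.co/bert-base-uncased/resolve/main/config.json")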
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if is_offline_mode() and not local_files_only:
        logger.info("Offline mode: forcing local_files_only=True")
        local_files_only = True

    if is_remote_url(url_or_filename):
        # URL, so get it from the cache (downloading if necessary)
        output_path = get_from_cache(
            url_or_filename,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            user_agent=user_agent,
            use_auth_token=use_auth_token,
            local_files_only=local_files_only,
        )
    elif os.path.exists(url_or_filename):
        # File, and it exists.
        output_path = url_or_filename
    elif urlparse(url_or_filename).scheme == "":
        # File, but it doesn't exist.
        raise EnvironmentError(f"file {url_or_filename} not found")
    else:
        # Something unknown
        raise ValueError(f"unable to parse {url_or_filename} as a URL or as a local path")

    if extract_compressed_file:
        if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
            return output_path

        # Path where we extract compressed archives
        # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
        output_dir, output_file = os.path.split(output_path)
        output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
        output_path_extracted = os.path.join(output_dir, output_extract_dir_name)

        if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
            return output_path_extracted

        # Prevent parallel extractions
        lock_path = output_path + ".lock"
        with FileLock(lock_path):
            shutil.rmtree(output_path_extracted, ignore_errors=True)
            os.makedirs(output_path_extracted)
            if is_zipfile(output_path):
                with ZipFile(output_path, "r") as zip_file:
                    zip_file.extractall(output_path_extracted)
            elif tarfile.is_tarfile(output_path):
                with tarfile.open(output_path) as tar_file:
                    tar_file.extractall(output_path_extracted)
            else:
                raise EnvironmentError(f"Archive format of {output_path} could not be identified")

        return output_path_extracted

    return output_path


def define_sagemaker_information():
    try:
        instance_data = requests.get(os.environ["ECS_CONTAINER_METADATA_URI"]).json()
        dlc_container_used = instance_data["Image"]
        dlc_tag = instance_data["Image"].split(":")[1]
    except Exception:
        dlc_container_used = None
        dlc_tag = None

    sagemaker_params = json.loads(os.getenv("SM_FRAMEWORK_PARAMS", "{}"))
    runs_distributed_training = "sagemaker_distributed_dataparallel_enabled" in sagemaker_params
    account_id = os.getenv("TRAINING_JOB_ARN").split(":")[4] if "TRAINING_JOB_ARN" in os.environ else None

    sagemaker_object = {
        "sm_framework": os.getenv("SM_FRAMEWORK_MODULE", None),
        "sm_region": os.getenv("AWS_REGION", None),
        "sm_number_gpu": os.getenv("SM_NUM_GPUS", 0),
        "sm_number_cpu": os.getenv("SM_NUM_CPUS", 0),
        "sm_distributed_training": runs_distributed_training,
        "sm_deep_learning_container": dlc_container_used,
        "sm_deep_learning_container_tag": dlc_tag,
        "sm_account_id": account_id,
    }
    return sagemaker_object


def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
    """
    Formats a user-agent string with basic info about a request.
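
    Example (illustrative; the exact string depends on the environment)::

        http_user_agent({"file_type": "model"})
        # => 'transformers/<version>; python/<version>; session_id/<id>; file_type/model'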
    """
    ua = f"transformers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}"
    if is_torch_available():
        ua += f"; torch/{_torch_version}"
    if is_tf_available():
        ua += f"; tensorflow/{_tf_version}"
    if DISABLE_TELEMETRY:
        return ua + "; telemetry/off"
    if is_training_run_on_sagemaker():
        ua += "; " + "; ".join(f"{k}/{v}" for k, v in define_sagemaker_information().items())
    # CI will set this value to True
    if os.environ.get("TRANSFORMERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:
        ua += "; is_ci/true"
    if isinstance(user_agent, dict):
        ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items())
    elif isinstance(user_agent, str):
        ua += "; " + user_agent
    return ua


def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None):
    """
    Download remote file. Do not gobble up errors.
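
    Example (minimal sketch; the URL is hypothetical)::

        with open("/tmp/model.bin", "wb") as f:
            http_get("https://example.com/model.bin", f)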
    """
    # Copy the caller's headers (so they are not mutated) and default to an empty dict.
    headers = copy.deepcopy(headers) if headers is not None else {}
    if resume_size > 0:
        headers["Range"] = f"bytes={resume_size}-"
    r = requests.get(url, stream=True, proxies=proxies, headers=headers)
    r.raise_for_status()
    content_length = r.headers.get("Content-Length")
    total = resume_size + int(content_length) if content_length is not None else None
    progress = tqdm(
        unit="B",
        unit_scale=True,
        unit_divisor=1024,
        total=total,
        initial=resume_size,
        desc="Downloading",
        disable=bool(logging.get_verbosity() == logging.NOTSET),
    )
    for chunk in r.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive new chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()


def get_from_cache(
    url: str,
    cache_dir=None,
    force_download=False,
    proxies=None,
    etag_timeout=10,
    resume_download=False,
    user_agent: Union[Dict, str, None] = None,
    use_auth_token: Union[bool, str, None] = None,
    local_files_only=False,
) -> Optional[str]:
    """
    Given a URL, look for the corresponding file in the local cache. If it's not there, download it. Then return the
    path to the cached file.

    Return:
        Local path (string) of the file or, if networking is off, the last version of the file cached on disk.

    Raises:
        In case of a non-recoverable file (non-existent or inaccessible url + no cache on disk).
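
    Example (sketch; most callers should go through :func:`cached_path`, which wraps this function)::

        path = get_from_cache("https://huggingface.co/bert-base-uncased/resolve/main/config.json")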
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    os.makedirs(cache_dir, exist_ok=True)

    headers = {"user-agent": http_user_agent(user_agent)}
    if isinstance(use_auth_token, str):
        headers["authorization"] = f"Bearer {use_auth_token}"
    elif use_auth_token:
        token = HfFolder.get_token()
        if token is None:
            raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
        headers["authorization"] = f"Bearer {token}"

    url_to_download = url
    etag = None
    if not local_files_only:
        try:
            r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
            r.raise_for_status()
            etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
            # We favor a custom header indicating the etag of the linked resource, and
            # we fallback to the regular etag header.
            # If we don't have any of those, raise an error.
            if etag is None:
                raise OSError(
                    "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
                )
            # In case of a redirect,
            # save an extra redirect on the request.get call,
            # and ensure we download the exact atomic version even if it changed
            # between the HEAD and the GET (unlikely, but hey).
            if 300 <= r.status_code <= 399:
                url_to_download = r.headers["Location"]
        except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
            # Actually raise for those subclasses of ConnectionError
            raise
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
            # Otherwise, our Internet connection is down.
            # etag is None
            pass

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    # etag is None == we don't have a connection or we passed local_files_only.
    # try to get the last downloaded one
    if etag is None:
        if os.path.exists(cache_path):
            return cache_path
        else:
            matching_files = [
                file
                for file in fnmatch.filter(os.listdir(cache_dir), filename.split(".")[0] + ".*")
                if not file.endswith(".json") and not file.endswith(".lock")
            ]
            if len(matching_files) > 0:
                return os.path.join(cache_dir, matching_files[-1])
            else:
                # If files cannot be found and local_files_only=True,
                # the models might've been found if local_files_only=False
                # Notify the user about that
                if local_files_only:
                    raise FileNotFoundError(
                        "Cannot find the requested files in the cached path and outgoing traffic has been"
                        " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
                        " to False."
                    )
                else:
                    raise ValueError(
                        "Connection error, and we cannot find the requested files in the cached path."
                        " Please try again or make sure your Internet connection is on."
                    )

    # From now on, etag is not None.
    if os.path.exists(cache_path) and not force_download:
        return cache_path

    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):

        # If the download just completed while the lock was activated.
        if os.path.exists(cache_path) and not force_download:
            # Even if returning early like here, the lock will be released.
            return cache_path

        if resume_download:
            incomplete_path = cache_path + ".incomplete"

            @contextmanager
            def _resumable_file_manager() -> "io.BufferedWriter":
                with open(incomplete_path, "ab") as f:
                    yield f

            temp_file_manager = _resumable_file_manager
            if os.path.exists(incomplete_path):
                resume_size = os.stat(incomplete_path).st_size
            else:
                resume_size = 0
        else:
            temp_file_manager = partial(tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False)
            resume_size = 0

        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with temp_file_manager() as temp_file:
            logger.info(f"{url} not found in cache or force_download set to True, downloading to {temp_file.name}")

            http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, headers=headers)

        logger.info(f"storing {url} in cache at {cache_path}")
        os.replace(temp_file.name, cache_path)

        # NamedTemporaryFile creates a file with hardwired 0600 perms (ignoring umask), so fixing it.
        umask = os.umask(0o666)
        os.umask(umask)
        os.chmod(cache_path, 0o666 & ~umask)

        logger.info(f"creating metadata file for {cache_path}")
        meta = {"url": url, "etag": etag}
        meta_path = cache_path + ".json"
        with open(meta_path, "w") as meta_file:
            json.dump(meta, meta_file)

    return cache_path


def get_list_of_files(
    path_or_repo: Union[str, os.PathLike],
    revision: Optional[str] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    local_files_only: bool = False,
) -> List[str]:
    """
    Gets the list of files inside :obj:`path_or_repo`.

    Args:
        path_or_repo (:obj:`str` or :obj:`os.PathLike`):
            Can be either the id of a repo on huggingface.co or a path to a `directory`.
        revision (:obj:`str`, `optional`, defaults to :obj:`"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id. Since we use a
            git-based system for storing models and other artifacts on huggingface.co, ``revision`` can be any
            identifier allowed by git.
        use_auth_token (:obj:`str` or `bool`, `optional`):
            The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
            generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
        local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to only rely on local files and not to attempt to download any files.

    Returns:
        :obj:`List[str]`: The list of files available in :obj:`path_or_repo`.
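
    Example (illustrative; listing a hub repo requires network access)::

        get_list_of_files("bert-base-uncased")
        # => ['.gitattributes', 'config.json', 'pytorch_model.bin', ...]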
    """
    path_or_repo = str(path_or_repo)
    # If path_or_repo is a folder, we just return what is inside (subdirectories included).
    if os.path.isdir(path_or_repo):
        list_of_files = []
        for path, dir_names, file_names in os.walk(path_or_repo):
            list_of_files.extend([os.path.join(path, f) for f in file_names])
        return list_of_files

    # Can't grab the files if we are in offline mode.
    if is_offline_mode() or local_files_only:
        return []

    # Otherwise we grab the token and use the model_info method.
    if isinstance(use_auth_token, str):
        token = use_auth_token
    elif use_auth_token is True:
        token = HfFolder.get_token()
    else:
        token = None
    model_info = HfApi(endpoint=HUGGINGFACE_CO_RESOLVE_ENDPOINT).model_info(
        path_or_repo, revision=revision, token=token
    )
    return [f.rfilename for f in model_info.siblings]


class cached_property(property):
    """
    Descriptor that mimics @property but caches output in member variable.

    From tensorflow_datasets

    Built into functools since Python 3.8.
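
    Example (minimal sketch; ``expensive_load`` is a hypothetical helper)::

        class Model:
            @cached_property
            def weights(self):
                return expensive_load()  # computed on first access, cached afterwards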
    """

    def __get__(self, obj, objtype=None):
        # See docs.python.org/3/howto/descriptor.html#properties
        if obj is None:
            return self
        if self.fget is None:
            raise AttributeError("unreadable attribute")
        attr = "__cached_" + self.fget.__name__
        cached = getattr(obj, attr, None)
        if cached is None:
            cached = self.fget(obj)
            setattr(obj, attr, cached)
        return cached


def torch_required(func):
    # Chose a different decorator name than in tests so it's clear they are not the same.
    @wraps(func)
    def wrapper(*args, **kwargs):
        if is_torch_available():
            return func(*args, **kwargs)
        else:
            raise ImportError(f"Method `{func.__name__}` requires PyTorch.")

    return wrapper


def tf_required(func):
    # Chose a different decorator name than in tests so it's clear they are not the same.
    @wraps(func)
    def wrapper(*args, **kwargs):
        if is_tf_available():
            return func(*args, **kwargs)
        else:
            raise ImportError(f"Method `{func.__name__}` requires TF.")

    return wrapper


def is_torch_fx_proxy(x):
    if is_torch_fx_available():
        import torch.fx

        return isinstance(x, torch.fx.Proxy)
    return False


def is_tensor(x):
    """
    Tests if ``x`` is a :obj:`torch.Tensor`, :obj:`tf.Tensor`, :obj:`jaxlib.xla_extension.DeviceArray` or
    :obj:`np.ndarray`.
    """
    if is_torch_fx_proxy(x):
        return True
    if is_torch_available():
        import torch

        if isinstance(x, torch.Tensor):
            return True
    if is_tf_available():
        import tensorflow as tf

        if isinstance(x, tf.Tensor):
            return True

    if is_flax_available():
        import jax.numpy as jnp
        from jax.core import Tracer

        if isinstance(x, (jnp.ndarray, Tracer)):
            return True

    return isinstance(x, np.ndarray)


def _is_numpy(x):
    return isinstance(x, np.ndarray)


def _is_torch(x):
    import torch

    return isinstance(x, torch.Tensor)


def _is_torch_device(x):
    import torch

    return isinstance(x, torch.device)


def _is_tensorflow(x):
    import tensorflow as tf

    return isinstance(x, tf.Tensor)


def _is_jax(x):
    import jax.numpy as jnp  # noqa: F811

    return isinstance(x, jnp.ndarray)


def to_py_obj(obj):
    """
    Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a python list.
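
    Example (sketch)::

        to_py_obj({"ids": np.array([[1, 2]])})
        # => {'ids': [[1, 2]]}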
    """
    if isinstance(obj, (dict, UserDict)):
        return {k: to_py_obj(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple)):
        return [to_py_obj(o) for o in obj]
    elif is_tf_available() and _is_tensorflow(obj):
        return obj.numpy().tolist()
    elif is_torch_available() and _is_torch(obj):
        return obj.detach().cpu().tolist()
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    else:
        return obj


class ModelOutput(OrderedDict):
    """
    Base class for all model outputs as dataclass. Has a ``__getitem__`` that allows indexing by integer or slice (like
    a tuple) or strings (like a dictionary), ignoring the ``None`` attributes. Otherwise behaves like a regular
    python dictionary.

    .. warning::
        You can't unpack a :obj:`ModelOutput` directly. Use the :meth:`~transformers.file_utils.ModelOutput.to_tuple`
        method to convert it to a tuple first.
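
    Example (minimal sketch with a hypothetical subclass)::

        @dataclass
        class SampleOutput(ModelOutput):
            loss: Optional[float] = None
            logits: Optional[List[float]] = None

        out = SampleOutput(logits=[0.1, 0.9])
        out.logits, out["logits"], out[0]  # attribute, key and index access all return the logits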
    """

    def __post_init__(self):
        class_fields = fields(self)

        # Safety and consistency checks
        assert len(class_fields), f"{self.__class__.__name__} has no fields."
        assert all(
            field.default is None for field in class_fields[1:]
        ), f"{self.__class__.__name__} should not have more than one required field."

        first_field = getattr(self, class_fields[0].name)
        other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])

        if other_fields_are_none and not is_tensor(first_field):
            if isinstance(first_field, dict):
                iterator = first_field.items()
                first_field_iterator = True
            else:
                try:
                    iterator = iter(first_field)
                    first_field_iterator = True
                except TypeError:
                    first_field_iterator = False

            # if we provided an iterator as first field and the iterator is a (key, value) iterator
            # set the associated fields
            if first_field_iterator:
                for element in iterator:
                    if (
                        not isinstance(element, (list, tuple))
                        or not len(element) == 2
                        or not isinstance(element[0], str)
                    ):
                        break
                    setattr(self, element[0], element[1])
                    if element[1] is not None:
                        self[element[0]] = element[1]
            elif first_field is not None:
                self[class_fields[0].name] = first_field
        else:
            for field in class_fields:
                v = getattr(self, field.name)
                if v is not None:
                    self[field.name] = v

    def __delitem__(self, *args, **kwargs):
        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")

    def setdefault(self, *args, **kwargs):
        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")

    def pop(self, *args, **kwargs):
        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")

    def update(self, *args, **kwargs):
        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")

    def __getitem__(self, k):
        if isinstance(k, str):
            inner_dict = {k: v for (k, v) in self.items()}
            return inner_dict[k]
        else:
            return self.to_tuple()[k]

    def __setattr__(self, name, value):
        if name in self.keys() and value is not None:
            # Don't call self.__setitem__ to avoid recursion errors
            super().__setitem__(name, value)
        super().__setattr__(name, value)

    def __setitem__(self, key, value):
        # Will raise a KeyException if needed
        super().__setitem__(key, value)
        # Don't call self.__setattr__ to avoid recursion errors
        super().__setattr__(key, value)

    def to_tuple(self) -> Tuple[Any]:
        """
        Convert self to a tuple containing all the attributes/keys that are not ``None``.
        """
        return tuple(self[k] for k in self.keys())


class ExplicitEnum(Enum):
    """
    Enum with more explicit error message for missing values.
    """

    @classmethod
    def _missing_(cls, value):
        raise ValueError(
            f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
        )


class PaddingStrategy(ExplicitEnum):
    """
    Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
    in an IDE.
    """

    LONGEST = "longest"
    MAX_LENGTH = "max_length"
    DO_NOT_PAD = "do_not_pad"


class TensorType(ExplicitEnum):
    """
    Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
    tab-completion in an IDE.
    """

    PYTORCH = "pt"
    TENSORFLOW = "tf"
    NUMPY = "np"
    JAX = "jax"


class _LazyModule(ModuleType):
    """
    Module class that surfaces all objects but only performs associated imports when the objects are requested.
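
    Example (sketch of the pattern used by the package ``__init__``; the import structure is illustrative)::

        import sys

        sys.modules[__name__] = _LazyModule(
            __name__, globals()["__file__"], {"file_utils": ["cached_path"]}, module_spec=__spec__
        )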
    """

    # Very heavily inspired by optuna.integration._IntegrationModule
    # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
    def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None):
        super().__init__(name)
        self._modules = set(import_structure.keys())
        self._class_to_module = {}
        for key, values in import_structure.items():
            for value in values:
                self._class_to_module[value] = key
        # Needed for autocompletion in an IDE
        self.__all__ = list(import_structure.keys()) + sum(import_structure.values(), [])
        self.__file__ = module_file
        self.__spec__ = module_spec
        self.__path__ = [os.path.dirname(module_file)]
        self._objects = {} if extra_objects is None else extra_objects
        self._name = name
        self._import_structure = import_structure
    # Needed for autocompletion in an IDE
    def __dir__(self):
        return super().__dir__() + self.__all__

    def __getattr__(self, name: str) -> Any:
        if name in self._objects:
            return self._objects[name]
        if name in self._modules:
            value = self._get_module(name)
        elif name in self._class_to_module.keys():
            module = self._get_module(self._class_to_module[name])
            value = getattr(module, name)
        else:
            raise AttributeError(f"module {self.__name__} has no attribute {name}")

        setattr(self, name, value)
        return value

    def _get_module(self, module_name: str):
        return importlib.import_module("." + module_name, self.__name__)

    def __reduce__(self):
        return (self.__class__, (self._name, self.__file__, self._import_structure))


def copy_func(f):
    """Returns a copy of a function f."""
    # Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)
    g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
    g = functools.update_wrapper(g, f)
    g.__kwdefaults__ = f.__kwdefaults__
    return g


def is_local_clone(repo_path, repo_url):
    """
    Checks if the folder in `repo_path` is a local clone of `repo_url`.
    """
    # First double-check that `repo_path` is a git repo
    if not os.path.exists(os.path.join(repo_path, ".git")):
        return False
    test_git = subprocess.run("git branch".split(), cwd=repo_path)
    if test_git.returncode != 0:
        return False

    # Then look at its remotes
    remotes = subprocess.run(
        "git remote -v".split(),
        stderr=subprocess.PIPE,
        stdout=subprocess.PIPE,
        check=True,
        encoding="utf-8",
        cwd=repo_path,
    ).stdout

    return repo_url in remotes.split()


class PushToHubMixin:
    """
    A Mixin containing the functionality to push a model or tokenizer to the hub.
    """

    def push_to_hub(
        self,
        repo_path_or_name: Optional[str] = None,
        repo_url: Optional[str] = None,
        use_temp_dir: bool = False,
        commit_message: Optional[str] = None,
        organization: Optional[str] = None,
        private: Optional[bool] = None,
        use_auth_token: Optional[Union[bool, str]] = None,
    ) -> str:
        """
        Upload the {object_files} to the 🤗 Model Hub while synchronizing a local clone of the repo in
        :obj:`repo_path_or_name`.

        Parameters:
            repo_path_or_name (:obj:`str`, `optional`):
                Can either be a repository name for your {object} in the Hub or a path to a local folder (in which case
                the repository will have the name of that local folder). If not specified, will default to the name
                given by :obj:`repo_url` and a local directory with that name will be created.
            repo_url (:obj:`str`, `optional`):
                Specify this in case you want to push to an existing repository in the hub. If unspecified, a new
                repository will be created in your namespace (unless you specify an :obj:`organization`) with
                :obj:`repo_name`.
            use_temp_dir (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to clone the distant repo in a temporary directory or in :obj:`repo_path_or_name` inside
                the current working directory. This will slow things down if you are making changes in an existing repo
                since you will need to clone the repo before every push.
            commit_message (:obj:`str`, `optional`):
                Message to commit while pushing. Will default to :obj:`"add {object}"`.
            organization (:obj:`str`, `optional`):
                Organization in which you want to push your {object} (you must be a member of this organization).
            private (:obj:`bool`, `optional`):
                Whether or not the repository created should be private (requires a paying subscription).
            use_auth_token (:obj:`bool` or :obj:`str`, `optional`):
                The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
                generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). Will default to
                :obj:`True` if :obj:`repo_url` is not specified.


        Returns:
            :obj:`str`: The url of the commit of your {object} in the given repository.

        Examples::

            from transformers import {object_class}

            {object} = {object_class}.from_pretrained("bert-base-cased")

            # Push the {object} to your namespace with the name "my-finetuned-bert" and have a local clone in the
            # `my-finetuned-bert` folder.
            {object}.push_to_hub("my-finetuned-bert")

            # Push the {object} to your namespace with the name "my-finetuned-bert" with no local clone.
            {object}.push_to_hub("my-finetuned-bert", use_temp_dir=True)

            # Push the {object} to an organization with the name "my-finetuned-bert" and have a local clone in the
            # `my-finetuned-bert` folder.
            {object}.push_to_hub("my-finetuned-bert", organization="huggingface")

            # Make a change to an existing repo that has been cloned locally in `my-finetuned-bert`.
            {object}.push_to_hub("my-finetuned-bert", repo_url="https://huggingface.co/sgugger/my-finetuned-bert")
        """
        if use_temp_dir:
            # Make sure we use the right `repo_name` for the `repo_url` before replacing it.
            if repo_url is None:
                if use_auth_token is None:
                    use_auth_token = True
                repo_name = Path(repo_path_or_name).name
                repo_url = self._get_repo_url_from_name(
                    repo_name, organization=organization, private=private, use_auth_token=use_auth_token
                )
            repo_path_or_name = tempfile.mkdtemp()

        # Create or clone the repo. If the repo is already cloned, this just retrieves the path to the repo.
        repo = self._create_or_get_repo(
            repo_path_or_name=repo_path_or_name,
            repo_url=repo_url,
            organization=organization,
            private=private,
            use_auth_token=use_auth_token,
        )
        # Save the files in the cloned repo
        self.save_pretrained(repo_path_or_name)
        # Commit and push!
        url = self._push_to_hub(repo, commit_message=commit_message)

        # Clean up! Clean up! Everybody everywhere!
        if use_temp_dir:
            shutil.rmtree(repo_path_or_name)

        return url

    @staticmethod
    def _get_repo_url_from_name(
        repo_name: str,
        organization: Optional[str] = None,
        private: bool = None,
        use_auth_token: Optional[Union[bool, str]] = None,
    ) -> str:
        if isinstance(use_auth_token, str):
            token = use_auth_token
        elif use_auth_token:
            token = HfFolder.get_token()
            if token is None:
                raise ValueError(
                    "You must login to the Hugging Face hub on this computer by typing `transformers-cli login` and "
                    "entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own "
                    "token as the `use_auth_token` argument."
                )
        else:
            token = None

        # Special provision for the test endpoint (CI)
        return HfApi(endpoint=HUGGINGFACE_CO_RESOLVE_ENDPOINT).create_repo(
            token,
            repo_name,
            organization=organization,
            private=private,
            repo_type=None,
            exist_ok=True,
        )

    @classmethod
    def _create_or_get_repo(
        cls,
        repo_path_or_name: Optional[str] = None,
        repo_url: Optional[str] = None,
        organization: Optional[str] = None,
        private: bool = None,
        use_auth_token: Optional[Union[bool, str]] = None,
    ) -> Repository:
        if repo_path_or_name is None and repo_url is None:
            raise ValueError("You need to specify a `repo_path_or_name` or a `repo_url`.")

        if use_auth_token is None and repo_url is None:
            use_auth_token = True

        if repo_path_or_name is None:
            repo_path_or_name = repo_url.split("/")[-1]

        if repo_url is None and not os.path.exists(repo_path_or_name):
            repo_name = Path(repo_path_or_name).name
            repo_url = cls._get_repo_url_from_name(
                repo_name, organization=organization, private=private, use_auth_token=use_auth_token
            )

        # Create a working directory if it does not exist.
        if not os.path.exists(repo_path_or_name):
            os.makedirs(repo_path_or_name)

        repo = Repository(repo_path_or_name, clone_from=repo_url, use_auth_token=use_auth_token)
        repo.git_pull()
        return repo

    @classmethod
    def _push_to_hub(cls, repo: Repository, commit_message: Optional[str] = None) -> str:
        if commit_message is None:
            if "Tokenizer" in cls.__name__:
                commit_message = "add tokenizer"
            elif "Config" in cls.__name__:
                commit_message = "add config"
            else:
                commit_message = "add model"

        return repo.push_to_hub(commit_message=commit_message)