hub.py 46.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Hub utilities: utilities related to download and cache models
"""
import json
import os
19
import re
20
21
import shutil
import sys
22
import tempfile
23
import traceback
24
25
import warnings
from pathlib import Path
Sylvain Gugger's avatar
Sylvain Gugger committed
26
from typing import Dict, List, Optional, Tuple, Union
27
from urllib.parse import urlparse
28
29
from uuid import uuid4

30
import huggingface_hub
31
import requests
32
33
34
35
from huggingface_hub import (
    CommitOperationAdd,
    create_commit,
    create_repo,
36
    get_hf_file_metadata,
37
    hf_hub_download,
Sylvain Gugger's avatar
Sylvain Gugger committed
38
    hf_hub_url,
39
40
    whoami,
)
41
from huggingface_hub.file_download import REGEX_COMMIT_HASH, http_get
42
43
44
45
46
from huggingface_hub.utils import (
    EntryNotFoundError,
    LocalEntryNotFoundError,
    RepositoryNotFoundError,
    RevisionNotFoundError,
47
    build_hf_headers,
48
    hf_raise_for_status,
49
)
50
from requests.exceptions import HTTPError
51

52
from . import __version__, logging
53
from .generic import working_or_temp_dir
54
55
56
57
58
59
60
61
from .import_utils import (
    ENV_VARS_TRUE_VALUES,
    _tf_version,
    _torch_version,
    is_tf_available,
    is_torch_available,
    is_training_run_on_sagemaker,
)
62
from .logging import tqdm
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Read once at import time: truthy values of `TRANSFORMERS_OFFLINE` (case-insensitive) enable offline mode.
_is_offline_mode = os.environ.get("TRANSFORMERS_OFFLINE", "0").upper() in ENV_VARS_TRUE_VALUES


def is_offline_mode():
    """
    Returns whether offline mode is enabled, i.e. whether the `TRANSFORMERS_OFFLINE` environment variable was set to
    one of `ENV_VARS_TRUE_VALUES` (compared in upper case) when this module was imported.
    """
    return _is_offline_mode


# Cache folder used by Transformers before v4.0.0 (inside the PyTorch cache folder).
# NOTE(review): unlike `hf_cache_home` below, `~` is not expanded here, so unless `TORCH_HOME`/`XDG_CACHE_HOME`
# hold an absolute path, the `os.path.isdir` check below looks for a literal `~` folder relative to the working
# directory. Left unchanged so the one-time move is not silently activated; confirm the intended behavior.
torch_cache_home = os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
old_default_cache_path = os.path.join(torch_cache_home, "transformers")
# New default cache, shared with the Datasets library
hf_cache_home = os.path.expanduser(
    os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
)
default_cache_path = os.path.join(hf_cache_home, "hub")

# Onetime move from the old location to the new one if no ENV variable has been set.
if (
    os.path.isdir(old_default_cache_path)
    and not os.path.isdir(default_cache_path)
    and "PYTORCH_PRETRAINED_BERT_CACHE" not in os.environ
    and "PYTORCH_TRANSFORMERS_CACHE" not in os.environ
    and "TRANSFORMERS_CACHE" not in os.environ
):
    # Report the folders actually involved in the move: the destination is `default_cache_path`
    # (`<hf_cache_home>/hub`), not the previously hard-coded '~/.cache/huggingface/transformers'.
    logger.warning(
        "In Transformers v4.0.0, the default path to cache downloaded models changed from"
        " '~/.cache/torch/transformers' to '~/.cache/huggingface/transformers'. Since you don't seem to have"
        f" overridden and '{old_default_cache_path}' is a directory that exists, we're moving it to"
        f" '{default_cache_path}' to avoid redownloading models you have already in the cache. You should"
        " only see this message once."
    )
    shutil.move(old_default_cache_path, default_cache_path)

# Cache folder resolution: each variable falls back to the previous one, so `TRANSFORMERS_CACHE` has the highest
# precedence and `default_cache_path` is used when none of the environment variables is set.
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
HUGGINGFACE_HUB_CACHE = os.getenv("HUGGINGFACE_HUB_CACHE", PYTORCH_TRANSFORMERS_CACHE)
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", HUGGINGFACE_HUB_CACHE)
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
TRANSFORMERS_DYNAMIC_MODULE_NAME = "transformers_modules"
# Random identifier generated once per import; it is sent in the user agent of Hub requests.
SESSION_ID = uuid4().hex
# Upper-case the value like the other boolean environment variables read in this module, so that e.g.
# `DISABLE_TELEMETRY=true` is honored and not only values already spelled in upper case.
DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", "").upper() in ENV_VARS_TRUE_VALUES

S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"

# A truthy `HUGGINGFACE_CO_STAGING` (case-insensitive) switches the default endpoint to the CI Hub.
_staging_mode = os.environ.get("HUGGINGFACE_CO_STAGING", "NO").upper() in ENV_VARS_TRUE_VALUES
if _staging_mode:
    _default_endpoint = "https://hub-ci.huggingface.co"
else:
    _default_endpoint = "https://huggingface.co"

# Endpoint resolution, from lowest to highest priority: the default endpoint above, the deprecated
# `HUGGINGFACE_CO_RESOLVE_ENDPOINT` environment variable (which emits a FutureWarning), then `HF_ENDPOINT`.
HUGGINGFACE_CO_RESOLVE_ENDPOINT = _default_endpoint
if "HUGGINGFACE_CO_RESOLVE_ENDPOINT" in os.environ:
    warnings.warn(
        "Using the environment variable `HUGGINGFACE_CO_RESOLVE_ENDPOINT` is deprecated and will be removed in "
        "Transformers v5. Use `HF_ENDPOINT` instead.",
        FutureWarning,
    )
    HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ["HUGGINGFACE_CO_RESOLVE_ENDPOINT"]
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", HUGGINGFACE_CO_RESOLVE_ENDPOINT)

# URLs derived from the resolved endpoint (the braces in the prefix are placeholders filled in later).
HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}"
HUGGINGFACE_CO_EXAMPLES_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/examples"
125

126
127
128
# Return value when trying to load a file from cache but the file does not exist in the distant repo.
_CACHED_NO_EXIST = object()


def is_remote_url(url_or_filename):
    """
    Returns `True` when `url_or_filename` parses with an `http` or `https` scheme, and `False` otherwise (for
    instance for a plain filesystem path or another URL scheme such as `ftp`).
    """
    scheme = urlparse(url_or_filename).scheme
    return scheme == "http" or scheme == "https"


135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
def get_cached_models(cache_dir: Optional[Union[str, Path]] = None) -> List[Tuple]:
    """
    Returns a list of tuples representing model binaries that are cached locally. Each tuple has shape `(model_url,
    etag, size_MB)`. Filenames in `cache_dir` are used to get the metadata for each model, only urls ending with *.bin*
    are added.

    Each cached file is expected to sit next to a `<name>.json` metadata file holding its `url` and `etag`.

    Args:
        cache_dir (`Union[str, Path]`, *optional*):
            The cache directory to search for models within. Will default to the transformers cache if unset.

    Returns:
        List[Tuple]: List of tuples each with shape `(model_url, etag, size_MB)`. Empty if `cache_dir` is not an
        existing directory.
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    elif isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if not os.path.isdir(cache_dir):
        return []

    cached_models = []
    for file in os.listdir(cache_dir):
        if not file.endswith(".json"):
            continue
        meta_path = os.path.join(cache_dir, file)
        with open(meta_path, encoding="utf-8") as meta_file:
            metadata = json.load(meta_file)
        url = metadata["url"]
        etag = metadata["etag"]
        if url.endswith(".bin"):
            # Remove exactly the ".json" suffix to get the cached file's path. `str.strip(".json")` removed any of
            # the characters ".jnos" from both ends, which pointed at the wrong file for names ending in them.
            size_MB = os.path.getsize(meta_path[: -len(".json")]) / 1e6
            cached_models.append((url, etag, size_MB))

    return cached_models


def define_sagemaker_information():
    """
    Collects information about the SageMaker training environment, used to enrich the user agent.

    Values come from environment variables and, when `ECS_CONTAINER_METADATA_URI` is set and reachable, from the
    container metadata (the `Image` entry and its tag). Any failure while fetching the container metadata leaves both
    container fields as `None`.
    """
    try:
        instance_data = requests.get(os.environ["ECS_CONTAINER_METADATA_URI"]).json()
        image = instance_data["Image"]
        dlc_container_used, dlc_tag = image, image.split(":")[1]
    except Exception:
        # Best effort only: the variable may be unset, the request may fail or the image may have no tag.
        dlc_container_used = dlc_tag = None

    sagemaker_params = json.loads(os.getenv("SM_FRAMEWORK_PARAMS", "{}"))
    if "TRAINING_JOB_ARN" in os.environ:
        # The account id is the fifth ":"-separated field of the training job ARN.
        account_id = os.getenv("TRAINING_JOB_ARN").split(":")[4]
    else:
        account_id = None

    return {
        "sm_framework": os.getenv("SM_FRAMEWORK_MODULE", None),
        "sm_region": os.getenv("AWS_REGION", None),
        "sm_number_gpu": os.getenv("SM_NUM_GPUS", 0),
        "sm_number_cpu": os.getenv("SM_NUM_CPUS", 0),
        "sm_distributed_training": "sagemaker_distributed_dataparallel_enabled" in sagemaker_params,
        "sm_deep_learning_container": dlc_container_used,
        "sm_deep_learning_container_tag": dlc_tag,
        "sm_account_id": account_id,
    }


def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
    """
    Formats a user-agent string with basic info about a request.

    The string always lists the Transformers, Python and session identifiers, plus the PyTorch/TensorFlow versions
    when those frameworks are available. If telemetry is disabled it ends with `telemetry/off`; otherwise it may be
    extended with SageMaker information, a CI marker (`TRANSFORMERS_IS_CI`) and the caller-supplied `user_agent`
    (either a dict rendered as `key/value` pairs or a raw string). Segments are separated by `"; "`.
    """
    segments = [
        f"transformers/{__version__}",
        f"python/{sys.version.split()[0]}",
        f"session_id/{SESSION_ID}",
    ]
    if is_torch_available():
        segments.append(f"torch/{_torch_version}")
    if is_tf_available():
        segments.append(f"tensorflow/{_tf_version}")
    if DISABLE_TELEMETRY:
        segments.append("telemetry/off")
        return "; ".join(segments)
    if is_training_run_on_sagemaker():
        sagemaker_info = define_sagemaker_information()
        segments.append("; ".join(f"{k}/{v}" for k, v in sagemaker_info.items()))
    # CI will set this value to True
    if os.environ.get("TRANSFORMERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:
        segments.append("is_ci/true")
    if isinstance(user_agent, dict):
        segments.append("; ".join(f"{k}/{v}" for k, v in user_agent.items()))
    elif isinstance(user_agent, str):
        segments.append(user_agent)
    return "; ".join(segments)


219
220
221
222
223
224
def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]):
    """
    Extracts the commit hash from a resolved filename toward a cache file.

    A `commit_hash` that is already known (or a missing `resolved_file`) is returned unchanged. Otherwise the hash is
    taken from the `snapshots/<hash>/` component of the path, and is returned only if it matches `REGEX_COMMIT_HASH`;
    `None` is returned when no valid hash can be found.
    """
    if resolved_file is not None and commit_hash is None:
        posix_path = str(Path(resolved_file).as_posix())
        match = re.search(r"snapshots/([^/]+)/", posix_path)
        if match is None:
            return None
        candidate = match.group(1)
        return candidate if REGEX_COMMIT_HASH.match(candidate) else None
    return commit_hash


233
234
235
236
237
def try_to_load_from_cache(
    repo_id: str,
    filename: str,
    cache_dir: Union[str, Path, None] = None,
    revision: Optional[str] = None,
    repo_type: Optional[str] = None,
) -> Optional[str]:
    """
    Explores the cache to return the latest cached file for a given revision if found.

    This function does not raise an exception if the file is not cached.

    Args:
        repo_id (`str`):
            The ID of the repo on huggingface.co.
        filename (`str`):
            The filename to look for inside `repo_id`.
        cache_dir (`str` or `os.PathLike`, *optional*):
            The folder where the cached files lie. Defaults to `TRANSFORMERS_CACHE`.
        revision (`str`, *optional*):
            The specific model version to use. Will default to `"main"` if it's not provided. Branch or tag names
            (including names containing a slash, such as `"refs/pr/1"`) are resolved through the cached `refs`.
        repo_type (`str`, *optional*):
            The type of the repo. Defaults to `"model"`.

    Returns:
        `Optional[str]` or `_CACHED_NO_EXIST`:
            Will return `None` if the file was not cached. Otherwise:
            - The exact path to the cached file if it's found in the cache
            - A special value `_CACHED_NO_EXIST` if the file does not exist at the given commit hash and this fact was
              cached.
    """
    if revision is None:
        revision = "main"
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if repo_type is None:
        repo_type = "model"

    object_id = repo_id.replace("/", "--")
    repo_cache = os.path.join(cache_dir, f"{repo_type}s--{object_id}")
    if not os.path.isdir(repo_cache):
        # No cache for this model
        return None
    for subfolder in ["refs", "snapshots"]:
        if not os.path.isdir(os.path.join(repo_cache, subfolder)):
            return None

    # Resolve refs (for instance to convert main to the associated commit sha). Checking for the ref file directly,
    # rather than listing only the top level of `refs`, also resolves refs containing a slash (e.g. "refs/pr/1"),
    # which live in nested folders.
    ref_path = os.path.join(repo_cache, "refs", revision)
    if os.path.isfile(ref_path):
        with open(ref_path) as f:
            # A commit sha never contains whitespace; strip e.g. a trailing newline so it matches a snapshot folder.
            revision = f.read().strip()

    if os.path.isfile(os.path.join(repo_cache, ".no_exist", revision, filename)):
        return _CACHED_NO_EXIST

    cached_shas = os.listdir(os.path.join(repo_cache, "snapshots"))
    if revision not in cached_shas:
        # No cache for this revision and we won't try to return a random revision
        return None

    cached_file = os.path.join(repo_cache, "snapshots", revision, filename)
    return cached_file if os.path.isfile(cached_file) else None


def cached_file(
    path_or_repo_id: Union[str, os.PathLike],
    filename: str,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    subfolder: str = "",
    repo_type: Optional[str] = None,
    user_agent: Optional[Union[str, Dict[str, str]]] = None,
    _raise_exceptions_for_missing_entries: bool = True,
    _raise_exceptions_for_connection_errors: bool = True,
    _commit_hash: Optional[str] = None,
):
    """
    Tries to locate a file in a local folder and repo, downloads and cache it if necessary.

    Args:
        path_or_repo_id (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a model repo on huggingface.co.
            - a path to a *directory* potentially containing the file.
        filename (`str`):
            The name of the file to locate in `path_or_repo`.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force to (re-)download the configuration files and override the cached versions if they
            exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        use_auth_token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `huggingface-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will not try to download the file and only look for it locally. Forced to `True` in offline
            mode.
        subfolder (`str`, *optional*, defaults to `""`):
            In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
            specify the folder name here.
        repo_type (`str`, *optional*):
            Specify the repo type (useful when downloading from a space for instance).
        user_agent (`str` or `Dict[str, str]`, *optional*):
            Extra information appended to the user-agent string of the request (see `http_user_agent`).

    <Tip>

    Passing `use_auth_token=True` is required when you want to use a private model.

    </Tip>

    Returns:
        `Optional[str]`: Returns the resolved file (to the cache folder if downloaded from a repo), or `None` when the
        file is missing or unreachable and the matching private `_raise_exceptions_for_*` flag is `False`.

    Raises:
        `EnvironmentError`: if the repo or the revision does not exist, or, unless silenced by the private flags, if
        the file is missing or could not be reached.

    Examples:

    ```python
    # Download a model weight from the Hub and cache it.
    model_weights_file = cached_file("bert-base-uncased", "pytorch_model.bin")
    ```"""
    # Private arguments
    #     _raise_exceptions_for_missing_entries: if False, do not raise an exception for missing entries but return
    #         None.
    #     _raise_exceptions_for_connection_errors: if False, do not raise an exception for connection errors but return
    #         None.
    #     _commit_hash: passed when we are chaining several calls to various files (e.g. when loading a tokenizer or
    #         a pipeline). If files are cached for this commit hash, avoid calls to head and get from the cache.
    if is_offline_mode() and not local_files_only:
        logger.info("Offline mode: forcing local_files_only=True")
        local_files_only = True
    if subfolder is None:
        subfolder = ""

    path_or_repo_id = str(path_or_repo_id)
    full_filename = os.path.join(subfolder, filename)
    if os.path.isdir(path_or_repo_id):
        # Local folder: never touches the network or the cache.
        resolved_file = os.path.join(os.path.join(path_or_repo_id, subfolder), filename)
        if not os.path.isfile(resolved_file):
            if _raise_exceptions_for_missing_entries:
                raise EnvironmentError(
                    f"{path_or_repo_id} does not appear to have a file named {full_filename}. Checkout "
                    f"'https://huggingface.co/{path_or_repo_id}/{revision}' for available files."
                )
            else:
                return None
        return resolved_file

    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if _commit_hash is not None and not force_download:
        # If the file is cached under that commit hash, we return it directly.
        resolved_file = try_to_load_from_cache(
            path_or_repo_id, full_filename, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
        )
        if resolved_file is not None:
            if resolved_file is not _CACHED_NO_EXIST:
                return resolved_file
            elif not _raise_exceptions_for_missing_entries:
                return None
            else:
                raise EnvironmentError(f"Could not locate {full_filename} inside {path_or_repo_id}.")

    user_agent = http_user_agent(user_agent)
    try:
        # Load from URL or cache if already cached
        resolved_file = hf_hub_download(
            path_or_repo_id,
            filename,
            subfolder=None if len(subfolder) == 0 else subfolder,
            repo_type=repo_type,
            revision=revision,
            cache_dir=cache_dir,
            user_agent=user_agent,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            use_auth_token=use_auth_token,
            local_files_only=local_files_only,
        )

    except RepositoryNotFoundError:
        raise EnvironmentError(
            f"{path_or_repo_id} is not a local folder and is not a valid model identifier "
            "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to "
            "pass a token having permission to this repo with `use_auth_token` or log in with "
            "`huggingface-cli login` and pass `use_auth_token=True`."
        )
    except RevisionNotFoundError:
        raise EnvironmentError(
            f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
            "for this model name. Check the model page at "
            f"'https://huggingface.co/{path_or_repo_id}' for available revisions."
        )
    except LocalEntryNotFoundError:
        # We try to see if we have a cached version (not up to date). `repo_type` is forwarded so that datasets and
        # spaces are looked up in their own cache folder, as in the `_commit_hash` lookup above.
        resolved_file = try_to_load_from_cache(
            path_or_repo_id, full_filename, cache_dir=cache_dir, revision=revision, repo_type=repo_type
        )
        if resolved_file is not None and resolved_file is not _CACHED_NO_EXIST:
            return resolved_file
        if not _raise_exceptions_for_missing_entries or not _raise_exceptions_for_connection_errors:
            return None
        raise EnvironmentError(
            f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this file, couldn't find it in the"
            f" cached files and it looks like {path_or_repo_id} is not the path to a directory containing a file named"
            f" {full_filename}.\nCheckout your internet connection or see how to run the library in offline mode at"
            " 'https://huggingface.co/docs/transformers/installation#offline-mode'."
        )
    except EntryNotFoundError:
        if not _raise_exceptions_for_missing_entries:
            return None
        if revision is None:
            revision = "main"
        raise EnvironmentError(
            f"{path_or_repo_id} does not appear to have a file named {full_filename}. Checkout "
            f"'https://huggingface.co/{path_or_repo_id}/{revision}' for available files."
        )
    except HTTPError as err:
        # First we try to see if we have a cached version (not up to date), in the cache folder of this repo type:
        resolved_file = try_to_load_from_cache(
            path_or_repo_id, full_filename, cache_dir=cache_dir, revision=revision, repo_type=repo_type
        )
        if resolved_file is not None and resolved_file is not _CACHED_NO_EXIST:
            return resolved_file
        if not _raise_exceptions_for_connection_errors:
            return None

        raise EnvironmentError(f"There was a specific connection error when trying to load {path_or_repo_id}:\n{err}")

    return resolved_file


480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
def get_file_from_repo(
    path_or_repo: Union[str, os.PathLike],
    filename: str,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    subfolder: str = "",
):
    """
    Tries to locate a file in a local folder and repo, downloads and cache it if necessary.

    This is a thin wrapper around `cached_file` that returns `None` instead of raising when the file does not exist
    or cannot be reached.

    Args:
        path_or_repo (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a model repo on huggingface.co.
            - a path to a *directory* potentially containing the file.
        filename (`str`):
            The name of the file to locate in `path_or_repo`.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force to (re-)download the configuration files and override the cached versions if they
            exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        use_auth_token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `huggingface-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the file from local files.
        subfolder (`str`, *optional*, defaults to `""`):
            In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
            specify the folder name here.

    <Tip>

    Passing `use_auth_token=True` is required when you want to use a private model.

    </Tip>

    Returns:
        `Optional[str]`: Returns the resolved file (to the cache folder if downloaded from a repo) or `None` if the
        file does not exist.

    Examples:

    ```python
    # Download a tokenizer configuration from huggingface.co and cache.
    tokenizer_config = get_file_from_repo("bert-base-uncased", "tokenizer_config.json")
    # This model does not have a tokenizer config so the result will be None.
    tokenizer_config = get_file_from_repo("xlm-roberta-base", "tokenizer_config.json")
    ```"""
    # Missing files and connection errors are silenced (reported as `None`); invalid repos/revisions still raise.
    return cached_file(
        path_or_repo,
        filename,
        _raise_exceptions_for_missing_entries=False,
        _raise_exceptions_for_connection_errors=False,
        subfolder=subfolder,
        local_files_only=local_files_only,
        revision=revision,
        use_auth_token=use_auth_token,
        proxies=proxies,
        resume_download=resume_download,
        force_download=force_download,
        cache_dir=cache_dir,
    )


561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
def download_url(url, proxies=None):
    """
    Downloads a given url in a temporary file. This function is not safe to use in multiple processes. Its only use is
    for deprecated behavior allowing to download config/models with a single url instead of using the Hub.

    Args:
        url (`str`): The url of the file to download.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.

    Returns:
        `str`: The location of the temporary file where the url was downloaded. The caller owns the file; it is not
        removed automatically.
    """
    warnings.warn(
        f"Using `from_pretrained` with the url of a file (here {url}) is deprecated and won't be possible anymore in"
        " v5 of Transformers. You should host your file on the Hub (hf.co) instead and use the repository ID. Note"
        " that this is not compatible with the caching system (your file will be downloaded at each execution) or"
        " multiple processes (each process will download the file in a different temporary file)."
    )
    # `tempfile.mktemp` is deprecated: another process can claim the returned name before it is opened.
    # `mkstemp` creates the file atomically; its descriptor is closed and the path reopened so the file object passed
    # to `http_get` is a regular named file, as before.
    fd, tmp_file = tempfile.mkstemp()
    os.close(fd)
    with open(tmp_file, "wb") as f:
        http_get(url, f, proxies=proxies)
    return tmp_file


587
588
589
590
591
592
593
594
def has_file(
    path_or_repo: Union[str, os.PathLike],
    filename: str,
    revision: Optional[str] = None,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
):
    """
    Checks if a repo contains a given file without downloading it. Works for remote repos and local folders.

    <Tip warning={false}>

    This function will raise an error if the repository `path_or_repo` is not valid or if `revision` does not exist for
    this repo, but will return False for regular connection errors.

    </Tip>
    """
    if os.path.isdir(path_or_repo):
        return os.path.isfile(os.path.join(path_or_repo, filename))

    url = hf_hub_url(path_or_repo, filename=filename, revision=revision)
    headers = build_hf_headers(use_auth_token=use_auth_token, user_agent=http_user_agent())

    try:
        # The request is inside the `try` so that network failures (connection errors, timeouts) end up in the final
        # handler and return False as documented, instead of escaping as raw `requests` exceptions.
        r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=10)
        hf_raise_for_status(r)
        return True
    except RepositoryNotFoundError as e:
        logger.error(e)
        raise EnvironmentError(f"{path_or_repo} is not a local folder or a valid repository name on 'https://hf.co'.")
    except RevisionNotFoundError as e:
        logger.error(e)
        raise EnvironmentError(
            f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
            f"model name. Check the model page at 'https://huggingface.co/{path_or_repo}' for available revisions."
        )
    except requests.exceptions.RequestException:
        # We return false for EntryNotFoundError (logical) as well as any HTTP or connection error. The more specific
        # repository/revision errors above are matched first.
        return False


class PushToHubMixin:
    """
    A Mixin containing the functionality to push a model or tokenizer to the hub.
    """

    def _create_repo(
        self,
        repo_id: str,
        private: Optional[bool] = None,
        use_auth_token: Optional[Union[bool, str]] = None,
        repo_url: Optional[str] = None,
        organization: Optional[str] = None,
    ) -> str:
        """
        Creates the repo on the Hub if needed (existing repos are accepted) and returns the repo id to push to.

        The deprecated `repo_url` and `organization` kwargs are still honored (with a warning) to build `repo_id`.

        Returns:
            `str`: The repo id. When `repo_id` has no namespace and the repo was not created at the top level of the
            Hub, the namespace returned by `get_full_repo_name` is prepended.
        """
        if repo_url is not None:
            warnings.warn(
                "The `repo_url` argument is deprecated and will be removed in v5 of Transformers. Use `repo_id` "
                "instead."
            )
            repo_id = repo_url.replace(f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/", "")
        if organization is not None:
            warnings.warn(
                "The `organization` argument is deprecated and will be removed in v5 of Transformers. Set your "
                "organization directly in the `repo_id` passed instead (`repo_id={organization}/{model_id}`)."
            )
            # Compare with the trailing "/" so that e.g. "huggingface-course" is not mistaken for a repo that
            # already belongs to the "huggingface" organization.
            if not repo_id.startswith(f"{organization}/"):
                if "/" in repo_id:
                    repo_id = repo_id.split("/")[-1]
                repo_id = f"{organization}/{repo_id}"

        url = create_repo(repo_id=repo_id, token=use_auth_token, private=private, exist_ok=True)

        # If the namespace is not there, add it or `upload_file` will complain
        if "/" not in repo_id and url != f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/{repo_id}":
            repo_id = get_full_repo_name(repo_id, token=use_auth_token)
        return repo_id

    def _get_files_timestamps(self, working_dir: Union[str, os.PathLike]):
        """
        Returns a dict mapping each entry at the top level of `working_dir` to its last modification timestamp.
        """
        return {f: os.path.getmtime(os.path.join(working_dir, f)) for f in os.listdir(working_dir)}

    def _upload_modified_files(
        self,
        working_dir: Union[str, os.PathLike],
        repo_id: str,
        files_timestamps: Dict[str, float],
        commit_message: Optional[str] = None,
        token: Optional[Union[bool, str]] = None,
        create_pr: bool = False,
    ):
        """
        Uploads all modified files in `working_dir` to `repo_id`, based on `files_timestamps`.

        An entry at the top level of `working_dir` counts as modified when it is absent from `files_timestamps` or
        its modification time is more recent than the recorded one. All modified entries go in a single commit.

        Returns:
            The value returned by `create_commit`.
        """
        if commit_message is None:
            # Pick a default message from the name of the class the mixin is used in.
            if "Model" in self.__class__.__name__:
                commit_message = "Upload model"
            elif "Config" in self.__class__.__name__:
                commit_message = "Upload config"
            elif "Tokenizer" in self.__class__.__name__:
                commit_message = "Upload tokenizer"
            elif "FeatureExtractor" in self.__class__.__name__:
                commit_message = "Upload feature extractor"
            elif "Processor" in self.__class__.__name__:
                commit_message = "Upload processor"
            else:
                commit_message = f"Upload {self.__class__.__name__}"
        modified_files = [
            f
            for f in os.listdir(working_dir)
            if f not in files_timestamps or os.path.getmtime(os.path.join(working_dir, f)) > files_timestamps[f]
        ]
        operations = [
            CommitOperationAdd(path_or_fileobj=os.path.join(working_dir, file), path_in_repo=file)
            for file in modified_files
        ]
        logger.info(f"Uploading the following files to {repo_id}: {','.join(modified_files)}")
        return create_commit(
            repo_id=repo_id, operations=operations, commit_message=commit_message, token=token, create_pr=create_pr
        )

    def push_to_hub(
        self,
        repo_id: str,
        use_temp_dir: Optional[bool] = None,
        commit_message: Optional[str] = None,
        private: Optional[bool] = None,
        use_auth_token: Optional[Union[bool, str]] = None,
        max_shard_size: Optional[Union[int, str]] = "10GB",
        create_pr: bool = False,
        **deprecated_kwargs,
    ) -> str:
        """
        Upload the {object_files} to the 🤗 Model Hub.

        The files are first written with `save_pretrained` in a working folder (a local folder named like `repo_id`,
        or a temporary directory), then every file created or modified by that save is uploaded in a single commit.

        Parameters:
            repo_id (`str`):
                The name of the repository you want to push your {object} to. It should contain your organization name
                when pushing to a given organization.
            use_temp_dir (`bool`, *optional*):
                Whether or not to use a temporary directory to store the files saved before they are pushed to the Hub.
                Will default to `True` if there is no directory named like `repo_id`, `False` otherwise.
            commit_message (`str`, *optional*):
                Message to commit while pushing. Will default to `"Upload {object}"`.
            private (`bool`, *optional*):
                Whether or not the repository created should be private.
            use_auth_token (`bool` or `str`, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
                when running `huggingface-cli login` (stored in `~/.huggingface`). Will default to `True` if `repo_url`
                is not specified.
            max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
                Only applicable for models. The maximum size for a checkpoint before being sharded. Checkpoints shard
                will then be each of size lower than this size. If expressed as a string, needs to be digits followed
                by a unit (like `"5MB"`).
            create_pr (`bool`, *optional*, defaults to `False`):
                Whether or not to create a PR with the uploaded files or directly commit.

        Examples:

        ```python
        from transformers import {object_class}

        {object} = {object_class}.from_pretrained("bert-base-cased")

        # Push the {object} to your namespace with the name "my-finetuned-bert".
        {object}.push_to_hub("my-finetuned-bert")

        # Push the {object} to an organization with the name "my-finetuned-bert".
        {object}.push_to_hub("huggingface/my-finetuned-bert")
        ```
        """
        if "repo_path_or_name" in deprecated_kwargs:
            warnings.warn(
                "The `repo_path_or_name` argument is deprecated and will be removed in v5 of Transformers. Use "
                "`repo_id` instead."
            )
            repo_id = deprecated_kwargs.pop("repo_path_or_name")
        # Deprecation warning will be sent after for repo_url and organization
        repo_url = deprecated_kwargs.pop("repo_url", None)
        organization = deprecated_kwargs.pop("organization", None)

        if os.path.isdir(repo_id):
            working_dir = repo_id
            # Normalize first: splitting on the separator would give an empty repo name for "my-model/".
            repo_id = os.path.basename(os.path.normpath(repo_id))
        else:
            working_dir = repo_id.split("/")[-1]

        repo_id = self._create_repo(
            repo_id, private=private, use_auth_token=use_auth_token, repo_url=repo_url, organization=organization
        )

        if use_temp_dir is None:
            use_temp_dir = not os.path.isdir(working_dir)

        with working_or_temp_dir(working_dir=working_dir, use_temp_dir=use_temp_dir) as work_dir:
            # Snapshot timestamps before saving so only what `save_pretrained` touches gets uploaded.
            files_timestamps = self._get_files_timestamps(work_dir)

            # Save all files.
            self.save_pretrained(work_dir, max_shard_size=max_shard_size)

            return self._upload_modified_files(
                work_dir,
                repo_id,
                files_timestamps,
                commit_message=commit_message,
                token=use_auth_token,
                create_pr=create_pr,
            )


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    """
    Build the fully qualified `namespace/model_id` name of a repo.

    Args:
        model_id (`str`): The name of the repo, without namespace.
        organization (`str`, *optional*):
            The namespace to use. When not provided, the `"name"` field returned by `whoami(token)` is used instead.
        token (`str`, *optional*): The token passed to `whoami` when `organization` is not set.

    Returns:
        `str`: The repo name prefixed by its namespace.
    """
    if organization is not None:
        return f"{organization}/{model_id}"
    namespace = whoami(token)["name"]
    return f"{namespace}/{model_id}"
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845


def send_example_telemetry(example_name, *example_args, framework="pytorch"):
    """
    Sends telemetry that helps track the usage of the examples.

    Nothing is sent in offline mode, and any failure while sending is silently ignored (telemetry is best-effort).

    Args:
        example_name (`str`): The name of the example.
        *example_args (dataclasses or `argparse.Namespace`): The arguments to the script. This function will only
            try to extract the model and dataset name from those. Nothing else is tracked.
        framework (`str`, *optional*, defaults to `"pytorch"`): The framework for the example.
    """
    if is_offline_mode():
        return

    data = {"example": example_name, "framework": framework}
    for args in example_args:
        args_as_dict = {k: v for k, v in args.__dict__.items() if not k.startswith("_") and v is not None}
        if "model_name_or_path" in args_as_dict:
            model_name = args_as_dict["model_name_or_path"]
            # Filter out local paths
            if not os.path.isdir(model_name):
                data["model_name"] = model_name
        if "dataset_name" in args_as_dict:
            data["dataset_name"] = args_as_dict["dataset_name"]
        elif "task_name" in args_as_dict:
            # Extract script name from the example_name
            script_name = example_name.replace("tf_", "").replace("flax_", "").replace("run_", "")
            script_name = script_name.replace("_no_trainer", "")
            data["dataset_name"] = f"{script_name}-{args_as_dict['task_name']}"

    headers = {"user-agent": http_user_agent(data)}
    try:
        # Without a timeout, an unreachable telemetry endpoint could block the example script indefinitely.
        r = requests.head(HUGGINGFACE_CO_EXAMPLES_TELEMETRY, headers=headers, timeout=10)
        r.raise_for_status()
    except Exception:
        # We don't want to error in case of connection errors of any kind.
        pass
Arthur's avatar
Arthur committed
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891


def convert_file_size_to_int(size: Union[int, str]):
    """
    Convert a size written as digits followed by a unit (like `"5MB"`) into a number of bytes.

    Binary units (`"KiB"`, `"MiB"`, `"GiB"`) are powers of 1024 and decimal units (`"KB"`, `"MB"`, `"GB"`) are powers
    of 1000. Units are case-insensitive, except that a decimal unit ending with a lowercase `"b"` is read as bits, so
    the result is divided by 8.

    Args:
        size (`int` or `str`): The size to convert. Will be directly returned if an `int`.

    Raises:
        `ValueError`: if `size` does not end with one of the supported units.

    Example:
    ```py
    >>> convert_file_size_to_int("1MiB")
    1048576
    ```
    """
    if isinstance(size, int):
        return size

    normalized = size.upper()
    for unit, factor in (("GIB", 2**30), ("MIB", 2**20), ("KIB", 2**10)):
        if normalized.endswith(unit):
            return int(size[:-3]) * factor
    for unit, factor in (("GB", 10**9), ("MB", 10**6), ("KB", 10**3)):
        if normalized.endswith(unit):
            total = int(size[:-2]) * factor
            if size.endswith("b"):
                return total // 8
            return total
    raise ValueError("`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'.")


def get_checkpoint_shard_files(
    pretrained_model_name_or_path,
    index_filename,
    cache_dir=None,
    force_download=False,
    proxies=None,
    resume_download=False,
    local_files_only=False,
    use_auth_token=None,
    user_agent=None,
    revision=None,
    subfolder="",
    _commit_hash=None,
):
    """
    For a given model:

    - download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on the
      Hub
    - returns the list of paths to all the shards, as well as some metadata.

    For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the
    index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub).

    Returns:
        `Tuple[List[str], Dict]`: The paths to the shard files (deduplicated and sorted by shard name) and the sharded
        metadata: the `"metadata"` entry of the index (empty if the index has none) with the extra keys
        `"all_checkpoint_keys"` (the weight names) and `"weight_map"` (a copy of the index weight map).

    Raises:
        `ValueError`: if `index_filename` is not an existing file.
        `EnvironmentError`: if a shard listed in the index is missing from the repo or cannot be downloaded.
    """
    if not os.path.isfile(index_filename):
        raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")

    # JSON files are UTF-8; don't depend on the platform's locale encoding.
    with open(index_filename, "r", encoding="utf-8") as f:
        index = json.load(f)

    weight_map = index["weight_map"]
    # Many weights share a shard: deduplicate, and sort for a deterministic loading order.
    shard_filenames = sorted(set(weight_map.values()))
    sharded_metadata = index.get("metadata", {})
    sharded_metadata["all_checkpoint_keys"] = list(weight_map.keys())
    sharded_metadata["weight_map"] = weight_map.copy()

    # First, let's deal with local folder.
    if os.path.isdir(pretrained_model_name_or_path):
        shard_filenames = [os.path.join(pretrained_model_name_or_path, subfolder, f) for f in shard_filenames]
        return shard_filenames, sharded_metadata

    # At this stage pretrained_model_name_or_path is a model identifier on the Hub
    cached_filenames = []
    # Check if the model is already cached or not. We only try the last checkpoint, this should cover most cases of
    # downloaded (if interrupted). Shards live inside `subfolder` in the repo, so that is where we look for them.
    last_shard = try_to_load_from_cache(
        pretrained_model_name_or_path,
        os.path.join(subfolder, shard_filenames[-1]),
        cache_dir=cache_dir,
        revision=_commit_hash,
    )
    show_progress_bar = last_shard is None or force_download
    for shard_filename in tqdm(shard_filenames, desc="Downloading shards", disable=not show_progress_bar):
        try:
            # Load from URL
            cached_filename = cached_file(
                pretrained_model_name_or_path,
                shard_filename,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                user_agent=user_agent,
                revision=revision,
                subfolder=subfolder,
                _commit_hash=_commit_hash,
            )
        # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
        # we don't have to catch them here.
        except EntryNotFoundError as e:
            raise EnvironmentError(
                f"{pretrained_model_name_or_path} does not appear to have a file named {shard_filename} which is "
                "required according to the checkpoint index."
            ) from e
        except HTTPError as e:
            raise EnvironmentError(
                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {shard_filename}. You should try"
                " again after checking your internet connection."
            ) from e

        cached_filenames.append(cached_filename)

    return cached_filenames, sharded_metadata
964
965
966
967
968
969
970
971
972
973
974
975
976


# Everything below handles the conversion from the old cache format to the new cache format.


def get_all_cached_files(cache_dir=None):
    """
    List the files of an old-format cache folder, together with their metadata.

    An entry of `cache_dir` is only reported when a sibling `<entry>.json` metadata file exists.

    Args:
        cache_dir (`str` or `os.PathLike`, *optional*):
            The folder to scan. Defaults to `TRANSFORMERS_CACHE`.

    Returns:
        `List[Dict[str, str]]`: One dict per cached file with the keys `"file"` (its name inside `cache_dir`),
        `"url"` and `"etag"` (with double quotes stripped), as read from the metadata file. Empty when `cache_dir` is
        not an existing folder.
    """
    cache_dir = TRANSFORMERS_CACHE if cache_dir is None else str(cache_dir)
    if not os.path.isdir(cache_dir):
        return []

    entries = []
    for name in os.listdir(cache_dir):
        metadata_path = os.path.join(cache_dir, f"{name}.json")
        if os.path.isfile(metadata_path):
            with open(metadata_path, encoding="utf-8") as handle:
                metadata = json.load(handle)
            entries.append({"file": name, "url": metadata["url"], "etag": metadata["etag"].replace('"', "")})
    return entries


def extract_info_from_url(url):
    """
    Split a `https://huggingface.co/<repo>/resolve/<revision>/<filename>` url into its parts.

    Returns:
        `Optional[Dict[str, str]]`: `None` when `url` does not follow that pattern. Otherwise a dict with the keys
        `"repo"` (the repo id turned into a cache folder name, e.g. `"models--org--name"`), `"revision"` and
        `"filename"`.
    """
    match = re.search(r"^https://huggingface\.co/(.*)/resolve/([^/]*)/(.*)$", url)
    if match is None:
        return None
    repo_path, revision, filename = match.group(1), match.group(2), match.group(3)
    cache_repo = "--".join(["models", *repo_path.split("/")])
    return {"repo": cache_repo, "revision": revision, "filename": filename}


def clean_files_for(file):
    """
    Delete `file` along with its `.json` and `.lock` companions.

    Any of those paths that is not an existing regular file is skipped.
    """
    for suffix in ("", ".json", ".lock"):
        candidate = f"{file}{suffix}" if suffix else file
        if os.path.isfile(candidate):
            os.remove(candidate)


def move_to_new_cache(file, repo, filename, revision, etag, commit_hash):
    """
    Move file to repo following the new huggingface hub cache organization.

    The file becomes `repo/blobs/<etag>`, and `repo/snapshots/<commit_hash>/<filename>` points to it through
    `huggingface_hub.file_download._create_relative_symlink`. When `revision` differs from `commit_hash`,
    `repo/refs/<revision>` stores the commit hash. The `.json`/`.lock` companions of `file` are removed at the end.

    Args:
        file (`str`): Path of the file in the old cache.
        repo (`str`): Folder of the repo in the new cache.
        filename (`str`): Path of the file inside the repo; may include subfolders (e.g. `"sub/config.json"`).
        revision (`str`): The revision the file was downloaded from.
        etag (`str`): The etag of the file, used as the blob name.
        commit_hash (`str`): The commit hash `revision` resolves to.
    """
    os.makedirs(repo, exist_ok=True)

    # refs
    os.makedirs(os.path.join(repo, "refs"), exist_ok=True)
    if revision != commit_hash:
        ref_path = os.path.join(repo, "refs", revision)
        with open(ref_path, "w") as f:
            f.write(commit_hash)

    # blobs
    os.makedirs(os.path.join(repo, "blobs"), exist_ok=True)
    blob_path = os.path.join(repo, "blobs", etag)
    shutil.move(file, blob_path)

    # snapshots
    pointer_path = os.path.join(repo, "snapshots", commit_hash, filename)
    # `filename` may live in a subfolder of the repo, so create every parent folder of the pointer, not just
    # `snapshots/<commit_hash>`, otherwise the symlink cannot be created.
    os.makedirs(os.path.dirname(pointer_path), exist_ok=True)
    huggingface_hub.file_download._create_relative_symlink(blob_path, pointer_path)
    clean_files_for(file)


1042
1043
1044
def move_cache(cache_dir=None, new_cache_dir=None, token=None):
    """
    Migrate the files of an old-format cache to the new huggingface_hub cache organization.

    Args:
        cache_dir (`str`, *optional*):
            The folder holding the old-format cache. When not set, a `transformers` folder next to
            `TRANSFORMERS_CACHE` is used if it exists, and `new_cache_dir` otherwise.
        new_cache_dir (`str`, *optional*):
            The folder of the new cache. Defaults to `TRANSFORMERS_CACHE`.
        token (`str`, *optional*):
            The token passed to `get_hf_file_metadata` when fetching each file's etag and commit hash from the Hub.

    Each cached file goes through these checks in order:
    - its Hub metadata cannot be fetched (HTTP error) or lacks an etag or commit hash: the file is left in place;
    - its local etag differs from the Hub one: the file is deleted, since a new version will be downloaded anyway;
    - its url is not a huggingface.co resolve url: the file is left in place;
    - otherwise the file is moved into the new cache.
    """
    new_cache_dir = TRANSFORMERS_CACHE if new_cache_dir is None else new_cache_dir
    if cache_dir is None:
        # Look for an old-format cache in a `transformers` folder sitting next to `TRANSFORMERS_CACHE`.
        legacy_cache = Path(TRANSFORMERS_CACHE).parent / "transformers"
        cache_dir = str(legacy_cache) if os.path.isdir(str(legacy_cache)) else new_cache_dir

    cached_files = get_all_cached_files(cache_dir=cache_dir)
    logger.info(f"Moving {len(cached_files)} files to the new cache system")

    # Several cached files can share a url: only query the Hub once per url that succeeded.
    hub_metadata = {}
    for file_info in tqdm(cached_files):
        url = file_info.pop("url")
        if url not in hub_metadata:
            try:
                hub_metadata[url] = get_hf_file_metadata(url, token=token)
            except requests.HTTPError:
                continue

        remote = hub_metadata[url]
        etag, commit_hash = remote.etag, remote.commit_hash
        if etag is None or commit_hash is None:
            continue

        old_path = os.path.join(cache_dir, file_info["file"])
        if file_info["etag"] != etag:
            # Cached file is not up to date, we just throw it as a new version will be downloaded anyway.
            clean_files_for(old_path)
            continue

        url_info = extract_info_from_url(url)
        if url_info is None:
            # Not a file from huggingface.co
            continue

        move_to_new_cache(
            file=old_path,
            repo=os.path.join(new_cache_dir, url_info["repo"]),
            filename=url_info["filename"],
            revision=url_info["revision"],
            etag=etag,
            commit_hash=commit_hash,
        )


# One-time migration of the old cache format, tracked through a `version.txt` file in `TRANSFORMERS_CACHE`.
cache_version_file = os.path.join(TRANSFORMERS_CACHE, "version.txt")
if not os.path.isfile(cache_version_file):
    cache_version = 0
else:
    with open(cache_version_file) as f:
        try:
            cache_version = int(f.read())
        except ValueError:
            cache_version = 0

cache_is_not_empty = os.path.isdir(TRANSFORMERS_CACHE) and len(os.listdir(TRANSFORMERS_CACHE)) > 0

# Set when an old cache is found while offline: the migration queries the Hub, so it is skipped and the version file is
# left untouched, letting the migration run automatically at the next import done online (as the warning promises).
_cache_migration_postponed = False
if cache_version < 1 and cache_is_not_empty:
    if is_offline_mode():
        logger.warning(
            "You are offline and the cache for model files in Transformers v4.22.0 has been updated while your local "
            "cache seems to be the one of a previous version. It is very likely that all your calls to any "
            "`from_pretrained()` method will fail. Remove the offline mode and enable internet connection to have "
            "your cache be updated automatically, then you can go back to offline mode."
        )
        _cache_migration_postponed = True
    else:
        logger.warning(
            "The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a "
            "one-time only operation. You can interrupt this and resume the migration later on by calling "
            "`transformers.utils.move_cache()`."
        )
        try:
            if TRANSFORMERS_CACHE != default_cache_path:
                # Users set some env variable to customize cache storage
                move_cache(TRANSFORMERS_CACHE, TRANSFORMERS_CACHE)
            else:
                move_cache()
        except Exception as e:
            trace = "\n".join(traceback.format_tb(e.__traceback__))
            logger.error(
                f"There was a problem when trying to move your cache:\n\n{trace}\n{e.__class__.__name__}: {e}\n\nPlease "
                "file an issue at https://github.com/huggingface/transformers/issues/new/choose and copy paste this whole "
                "message and we will do our best to help."
            )

# Record the new cache version (even after a failed migration, so it is not retried at every import), unless the
# migration was postponed because we are offline.
if cache_version < 1 and not _cache_migration_postponed:
    try:
        os.makedirs(TRANSFORMERS_CACHE, exist_ok=True)
        with open(cache_version_file, "w") as f:
            f.write("1")
    except Exception:
        logger.warning(
            f"There was a problem when trying to write in your cache folder ({TRANSFORMERS_CACHE}). You should set "
            "the environment variable TRANSFORMERS_CACHE to a writable directory."
        )