"tests/vscode:/vscode.git/clone" did not exist on "f33419261acc1a62cba1cc6b2cadc1a0f4841a0e"
file_utils.py 14.8 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
2
3
4
5
"""
Utilities for working with the local dataset cache.
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
Copyright by the AllenNLP authors.
"""
Aymeric Augustin's avatar
Aymeric Augustin committed
6

Aymeric Augustin's avatar
Aymeric Augustin committed
7
import fnmatch
thomwolf's avatar
thomwolf committed
8
import json
thomwolf's avatar
thomwolf committed
9
import logging
thomwolf's avatar
thomwolf committed
10
import os
Aymeric Augustin's avatar
Aymeric Augustin committed
11
import sys
thomwolf's avatar
thomwolf committed
12
import tempfile
Aymeric Augustin's avatar
Aymeric Augustin committed
13
from contextlib import contextmanager
14
from functools import partial, wraps
thomwolf's avatar
thomwolf committed
15
from hashlib import sha256
16
from typing import Optional
Aymeric Augustin's avatar
Aymeric Augustin committed
17
from urllib.parse import urlparse
thomwolf's avatar
thomwolf committed
18
19

import boto3
Aymeric Augustin's avatar
Aymeric Augustin committed
20
import requests
21
from botocore.config import Config
thomwolf's avatar
thomwolf committed
22
from botocore.exceptions import ClientError
Aymeric Augustin's avatar
Aymeric Augustin committed
23
from filelock import FileLock
24
from tqdm.auto import tqdm
Aymeric Augustin's avatar
Aymeric Augustin committed
25

26
from . import __version__
thomwolf's avatar
thomwolf committed
27

Lysandre's avatar
Lysandre committed
28

thomwolf's avatar
thomwolf committed
29
30
logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

thomwolf's avatar
thomwolf committed
31
try:
32
33
34
    USE_TF = os.environ.get("USE_TF", "AUTO").upper()
    USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
    if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"):
35
        import torch
36

37
38
        _torch_available = True  # pylint: disable=invalid-name
        logger.info("PyTorch version {} available.".format(torch.__version__))
39
    else:
40
        logger.info("Disabling PyTorch because USE_TF is set")
41
        _torch_available = False
thomwolf's avatar
thomwolf committed
42
43
44
except ImportError:
    _torch_available = False  # pylint: disable=invalid-name

Lysandre's avatar
Lysandre committed
45
# TensorFlow is enabled unless the USE_TORCH env var forces a torch-only
# setup; a TF 1.x install is rejected via the version assertion below.
try:
    USE_TF = os.environ.get("USE_TF", "AUTO").upper()
    USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()

    _want_tf = USE_TF in ("1", "ON", "YES", "AUTO")
    _torch_forced = USE_TORCH in ("1", "ON", "YES")
    if _want_tf and not _torch_forced:
        import tensorflow as tf

        # Only TF 2.x is supported; the AssertionError is caught below.
        assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
        _tf_available = True  # pylint: disable=invalid-name
        logger.info("TensorFlow version {} available.".format(tf.__version__))
    else:
        logger.info("Disabling Tensorflow because USE_TORCH is set")
        _tf_available = False
except (ImportError, AssertionError):
    # Not installed, or a TF 1.x install failed the version check.
    _tf_available = False  # pylint: disable=invalid-name
thomwolf's avatar
thomwolf committed
60

61
62
# The default cache lives under torch's cache home so torch and transformers
# artifacts end up side by side; fall back to the XDG convention when torch
# is not installed.
try:
    from torch.hub import _get_torch_home

    torch_cache_home = _get_torch_home()
except ImportError:
    _xdg_cache = os.getenv("XDG_CACHE_HOME", "~/.cache")
    torch_cache_home = os.path.expanduser(os.getenv("TORCH_HOME", os.path.join(_xdg_cache, "torch")))
default_cache_path = os.path.join(torch_cache_home, "transformers")

# Honour the historical cache-override env vars, newest name first.
_cache_override = os.getenv("PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path))
try:
    from pathlib import Path

    PYTORCH_PRETRAINED_BERT_CACHE = Path(_cache_override)
except (AttributeError, ImportError):
    PYTORCH_PRETRAINED_BERT_CACHE = _cache_override

PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE  # Kept for backward compatibility
TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE  # Kept for backward compatibility

# Canonical file names used inside a model directory / cache entry.
WEIGHTS_NAME = "pytorch_model.bin"
TF2_WEIGHTS_NAME = "tf_model.h5"
TF_WEIGHTS_NAME = "model.ckpt"
CONFIG_NAME = "config.json"
MODEL_CARD_NAME = "modelcard.json"

# Tiny fixed batch used to build/trace models without real data.
DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]

S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
CLOUDFRONT_DISTRIB_PREFIX = "https://d2ws9o8vfrpkyk.cloudfront.net"
96

Thomas Wolf's avatar
Thomas Wolf committed
97

thomwolf's avatar
thomwolf committed
98
99
100
def is_torch_available():
    """Return whether PyTorch was importable when this module was loaded."""
    return _torch_available

101

thomwolf's avatar
thomwolf committed
102
103
104
def is_tf_available():
    """Return whether TensorFlow 2.x was importable when this module was loaded."""
    return _tf_available

105

Aymeric Augustin's avatar
Aymeric Augustin committed
106
107
def add_start_docstrings(*docstr):
    """Decorator factory that prepends the given fragments to a function's docstring.

    Functions with no docstring are tolerated (treated as having an empty one).
    """

    def docstring_decorator(fn):
        existing = "" if fn.__doc__ is None else fn.__doc__
        fn.__doc__ = "".join(docstr) + existing
        return fn

    return docstring_decorator


def add_start_docstrings_to_callable(*docstr):
    """Decorator factory for ``forward``-style methods.

    Builds an intro line naming the owning class (derived from the function's
    ``__qualname__``), a usage note about calling the module instance instead of
    ``forward`` directly, then the supplied fragments, and prepends everything to
    the wrapped function's own docstring.
    """

    def docstring_decorator(fn):
        owner = fn.__qualname__.split(".")[0]
        class_name = ":class:`~transformers.{}`".format(owner)
        intro = "   The {} forward method, overrides the :func:`__call__` special method.".format(class_name)
        note = r"""

    .. note::
        Although the recipe for forward pass needs to be defined within
        this function, one should call the :class:`Module` instance afterwards
        instead of this since the former takes care of running the
        pre and post processing steps while the latter silently ignores them.
        """
        existing = "" if fn.__doc__ is None else fn.__doc__
        fn.__doc__ = intro + note + "".join(docstr) + existing
        return fn

    return docstring_decorator
130

131

Aymeric Augustin's avatar
Aymeric Augustin committed
132
133
134
135
def add_end_docstrings(*docstr):
    """Decorator factory that appends the given fragments to a function's docstring.

    Unlike the original implementation, a function with no docstring
    (``fn.__doc__`` is ``None``) no longer raises ``TypeError`` — it is treated
    as having an empty docstring, matching ``add_start_docstrings``.
    """

    def docstring_decorator(fn):
        # Guard against None so undocumented functions can still be decorated.
        fn.__doc__ = (fn.__doc__ if fn.__doc__ is not None else "") + "".join(docstr)
        return fn

    return docstring_decorator
thomwolf's avatar
thomwolf committed
138

139
140
141

def is_remote_url(url_or_filename):
    """Return True when the string parses as an http(s) or s3 URL."""
    scheme = urlparse(url_or_filename).scheme
    return scheme in ("http", "https", "s3")

144

145
def hf_bucket_url(identifier, postfix=None, cdn=False) -> str:
    """Build the URL of a file on the HuggingFace model bucket.

    ``cdn=True`` uses the CloudFront mirror instead of the raw S3 bucket;
    ``postfix`` (e.g. a filename) is appended as an extra path segment.
    """
    endpoint = CLOUDFRONT_DISTRIB_PREFIX if cdn else S3_BUCKET_PREFIX
    parts = [endpoint, identifier] if postfix is None else [endpoint, identifier, postfix]
    return "/".join(parts)
151
152


thomwolf's avatar
thomwolf committed
153
def url_to_filename(url, etag=None):
    """Map ``url`` (and optional ``etag``) to a deterministic cache filename.

    The name is ``sha256(url)`` in hex, with ``.sha256(etag)`` appended when an
    etag is given. URLs ending in ``.h5`` keep that suffix so TF 2.0 can
    identify the cached file as HDF5 (see
    https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380).
    """
    pieces = [sha256(url.encode("utf-8")).hexdigest()]
    if etag:
        pieces.append(sha256(etag.encode("utf-8")).hexdigest())
    filename = ".".join(pieces)

    if url.endswith(".h5"):
        filename += ".h5"
    return filename


thomwolf's avatar
thomwolf committed
177
def filename_to_url(filename, cache_dir=None):
    """Recover the ``(url, etag)`` recorded for a cached ``filename``.

    Reads the ``<filename>.json`` sidecar written at download time; ``etag``
    may be ``None``. Raises ``EnvironmentError`` if the cached file or its
    metadata sidecar does not exist.
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    cache_path = os.path.join(cache_dir, filename)
    if not os.path.exists(cache_path):
        raise EnvironmentError("file {} not found".format(cache_path))

    meta_path = cache_path + ".json"
    if not os.path.exists(meta_path):
        raise EnvironmentError("file {} not found".format(meta_path))

    with open(meta_path, encoding="utf-8") as meta_file:
        metadata = json.load(meta_file)
    return metadata["url"], metadata["etag"]


203
204
def cached_path(
    url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, user_agent=None
) -> Optional[str]:
    """Resolve ``url_or_filename`` to a local file path, downloading if needed.

    Remote URLs (http/https/s3) are fetched into the cache and the cached path
    is returned; an existing local path is returned unchanged.

    Args:
        cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
        force_download: if True, re-download the file even if it's already cached in the cache dir.
        resume_download: if True, resume the download if an incompletely received file is found.
        user_agent: Optional string or dict that will be appended to the user-agent on remote requests.

    Return:
        None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
        Local path (string) otherwise

    Raises:
        EnvironmentError: for a local-looking path that does not exist.
        ValueError: when the argument is neither a URL nor a local path.
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if is_remote_url(url_or_filename):
        # URL, so get it from the cache (downloading if necessary).
        return get_from_cache(
            url_or_filename,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            user_agent=user_agent,
        )
    if os.path.exists(url_or_filename):
        # Local file that exists: nothing to do.
        return url_or_filename
    if urlparse(url_or_filename).scheme == "":
        # Looks like a plain path but is missing on disk.
        raise EnvironmentError("file {} not found".format(url_or_filename))
    # Anything else (unknown scheme) is a caller error.
    raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))


thomwolf's avatar
thomwolf committed
249
def split_s3_path(url):
    """Split a full s3 path into the bucket name and path.

    Raises ``ValueError`` when either component is missing.
    """
    parsed = urlparse(url)
    bucket_name, s3_path = parsed.netloc, parsed.path
    if not bucket_name or not s3_path:
        raise ValueError("bad s3 path {}".format(url))
    # Drop the single leading '/' that urlparse leaves on the key.
    if s3_path.startswith("/"):
        s3_path = s3_path[1:]
    return bucket_name, s3_path


thomwolf's avatar
thomwolf committed
262
def s3_request(func):
    """Decorator that translates botocore 404 ClientErrors into
    ``EnvironmentError`` carrying the offending URL; any other S3 error
    propagates unchanged."""

    @wraps(func)
    def wrapper(url, *args, **kwargs):
        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            if int(exc.response["Error"]["Code"]) != 404:
                raise
            raise EnvironmentError("file {} not found".format(url))

    return wrapper


@s3_request
def s3_etag(url, proxies=None):
    """Check ETag on S3 object."""
    bucket_name, s3_path = split_s3_path(url)
    resource = boto3.resource("s3", config=Config(proxies=proxies))
    return resource.Object(bucket_name, s3_path).e_tag


@s3_request
def s3_get(url, temp_file, proxies=None):
    """Pull a file directly from S3 into the open ``temp_file`` handle."""
    bucket_name, s3_path = split_s3_path(url)
    bucket = boto3.resource("s3", config=Config(proxies=proxies)).Bucket(bucket_name)
    bucket.download_fileobj(s3_path, temp_file)


298
299
def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
    """Stream ``url`` into the open ``temp_file`` with a tqdm progress bar.

    ``resume_size`` resumes a partial download via an HTTP Range header; a 416
    response (range not satisfiable) means nothing is left to fetch and is
    silently ignored. ``user_agent`` (str or dict) is appended to the UA string.
    """
    ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
    if is_torch_available():
        ua += "; torch/{}".format(torch.__version__)
    if is_tf_available():
        ua += "; tensorflow/{}".format(tf.__version__)
    if isinstance(user_agent, dict):
        ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
    elif isinstance(user_agent, str):
        ua += "; " + user_agent

    headers = {"user-agent": ua}
    if resume_size > 0:
        headers["Range"] = "bytes=%d-" % (resume_size,)

    response = requests.get(url, stream=True, proxies=proxies, headers=headers)
    if response.status_code == 416:  # Range not satisfiable
        return

    content_length = response.headers.get("Content-Length")
    total = resume_size + int(content_length) if content_length is not None else None
    progress = tqdm(
        unit="B",
        unit_scale=True,
        total=total,
        initial=resume_size,
        desc="Downloading",
        disable=bool(logger.getEffectiveLevel() == logging.NOTSET),
    )
    for chunk in response.iter_content(chunk_size=1024):
        if not chunk:  # keep-alive chunks carry no data
            continue
        progress.update(len(chunk))
        temp_file.write(chunk)
    progress.close()


331
332
def get_from_cache(
    url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None
) -> Optional[str]:
    """
    Given a URL, look for the corresponding file in the local cache.
    If it's not there, download it. Then return the path to the cached file.

    Args:
        cache_dir: cache directory to use (defaults to TRANSFORMERS_CACHE).
        force_download: if True, re-download even when a cached copy exists.
        proxies: proxy mapping passed through to requests / boto3.
        etag_timeout: timeout in seconds for the HEAD request that fetches the ETag.
        resume_download: if True, continue from a partial ".incomplete" file.
        user_agent: string or dict appended to the user-agent header.

    Return:
        None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
        Local path (string) otherwise
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    os.makedirs(cache_dir, exist_ok=True)

    # Get eTag to add to filename, if it exists.
    if url.startswith("s3://"):
        etag = s3_etag(url, proxies=proxies)
    else:
        try:
            response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
            etag = response.headers.get("ETag") if response.status_code == 200 else None
        except (EnvironmentError, requests.exceptions.Timeout):
            # Offline or unreachable host: fall back to whatever is cached.
            etag = None

    filename = url_to_filename(url, etag)

    # Get cache path to put the file.
    cache_path = os.path.join(cache_dir, filename)

    # etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
    # Try to return the last downloaded one.
    if etag is None:
        if os.path.exists(cache_path):
            return cache_path
        matching_files = [
            file
            for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
            if not file.endswith(".json") and not file.endswith(".lock")
        ]
        if len(matching_files) > 0:
            return os.path.join(cache_dir, matching_files[-1])
        return None

    # From now on, etag is not None.
    if os.path.exists(cache_path) and not force_download:
        return cache_path

    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):

        if resume_download:
            incomplete_path = cache_path + ".incomplete"

            @contextmanager
            def _resumable_file_manager():
                # Append mode so an interrupted download picks up where it left off.
                with open(incomplete_path, "a+b") as f:
                    yield f

            temp_file_manager = _resumable_file_manager
            resume_size = os.stat(incomplete_path).st_size if os.path.exists(incomplete_path) else 0
        else:
            temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
            resume_size = 0

        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with temp_file_manager() as temp_file:
            logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)

            # GET file object
            if url.startswith("s3://"):
                if resume_download:
                    # Fix: logger.warn is a deprecated alias of logger.warning.
                    logger.warning('Warning: resumable downloads are not implemented for "s3://" urls')
                s3_get(url, temp_file, proxies=proxies)
            else:
                http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)

        logger.info("storing %s in cache at %s", url, cache_path)
        os.rename(temp_file.name, cache_path)

        logger.info("creating metadata file for %s", cache_path)
        meta = {"url": url, "etag": etag}
        meta_path = cache_path + ".json"
        with open(meta_path, "w") as meta_file:
            json.dump(meta, meta_file)

    return cache_path