"""
Utilities for working with the local dataset cache.
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
Copyright by the AllenNLP authors.
"""


import fnmatch
import json
import logging
import os
import sys
import tempfile
from contextlib import contextmanager
from functools import partial, wraps
from hashlib import sha256
from urllib.parse import urlparse

import boto3
import requests
from botocore.config import Config
from botocore.exceptions import ClientError
from filelock import FileLock
from tqdm.auto import tqdm

from . import __version__


logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

try:
    os.environ.setdefault("USE_TORCH", "YES")
    if os.environ["USE_TORCH"].upper() in ("1", "ON", "YES"):
        import torch

        _torch_available = True  # pylint: disable=invalid-name
        logger.info("PyTorch version {} available.".format(torch.__version__))
    else:
        logger.info("USE_TORCH override through env variable, disabling PyTorch")
        _torch_available = False
except ImportError:
    _torch_available = False  # pylint: disable=invalid-name

try:
    os.environ.setdefault("USE_TF", "YES")
    if os.environ["USE_TF"].upper() in ("1", "ON", "YES"):
        import tensorflow as tf

        assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
        _tf_available = True  # pylint: disable=invalid-name
        logger.info("TensorFlow version {} available.".format(tf.__version__))
    else:
        logger.info("USE_TF override through env variable, disabling Tensorflow")
        _tf_available = False
except (ImportError, AssertionError):
    _tf_available = False  # pylint: disable=invalid-name
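
# Note: either framework can be disabled without uninstalling it by setting an
# environment variable before this module is imported (illustrative shell
# usage; any value other than "1", "ON" or "YES" disables the import):
#
#   USE_TORCH=NO USE_TF=NO python my_script.py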

try:
    from torch.hub import _get_torch_home

    torch_cache_home = _get_torch_home()
except ImportError:
    torch_cache_home = os.path.expanduser(
        os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
    )
default_cache_path = os.path.join(torch_cache_home, "transformers")


try:
    from pathlib import Path

    PYTORCH_PRETRAINED_BERT_CACHE = Path(
        os.getenv("PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path))
    )
except (AttributeError, ImportError):
    PYTORCH_PRETRAINED_BERT_CACHE = os.getenv(
        "PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
    )

PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE  # Kept for backward compatibility
TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE  # Kept for backward compatibility
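
# Resolution order for the default cache directory, as implemented above:
#   1. $PYTORCH_TRANSFORMERS_CACHE, if set
#   2. $PYTORCH_PRETRAINED_BERT_CACHE, if set
#   3. <torch cache home>/transformers, where the torch cache home itself
#      honours $TORCH_HOME and $XDG_CACHE_HOME and defaults to ~/.cache/torch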

WEIGHTS_NAME = "pytorch_model.bin"
TF2_WEIGHTS_NAME = "tf_model.h5"
TF_WEIGHTS_NAME = "model.ckpt"
CONFIG_NAME = "config.json"
MODEL_CARD_NAME = "modelcard.json"

DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]

S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
CLOUDFRONT_DISTRIB_PREFIX = "https://d2ws9o8vfrpkyk.cloudfront.net"


def is_torch_available():
    return _torch_available


def is_tf_available():
    return _tf_available


def add_start_docstrings(*docstr):
    def docstring_decorator(fn):
        fn.__doc__ = "".join(docstr) + fn.__doc__
        return fn

    return docstring_decorator


def add_end_docstrings(*docstr):
    def docstring_decorator(fn):
        fn.__doc__ = fn.__doc__ + "".join(docstr)
        return fn

    return docstring_decorator

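# Usage sketch for the docstring decorators (hypothetical function, for
# illustration only):
#
#   @add_start_docstrings("Shared preamble. ")
#   def forward(inputs):
#       """Model-specific details."""
#
#   # forward.__doc__ == "Shared preamble. Model-specific details."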

def is_remote_url(url_or_filename):
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https", "s3")


def hf_bucket_url(identifier, postfix=None, cdn=False):
    endpoint = CLOUDFRONT_DISTRIB_PREFIX if cdn else S3_BUCKET_PREFIX
    if postfix is None:
        return "/".join((endpoint, identifier))
    else:
        return "/".join((endpoint, identifier, postfix))


def url_to_filename(url, etag=None):
    """
    Convert `url` into a hashed filename in a repeatable way.
    If `etag` is specified, append its hash to the url's, delimited
    by a period.
    If the url ends with .h5 (Keras HDF5 weights), '.h5' is appended to the name
    so that TF 2.0 can identify it as an HDF5 file
    (see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
    """
    url_bytes = url.encode("utf-8")
    url_hash = sha256(url_bytes)
    filename = url_hash.hexdigest()

    if etag:
        etag_bytes = etag.encode("utf-8")
        etag_hash = sha256(etag_bytes)
        filename += "." + etag_hash.hexdigest()

    if url.endswith(".h5"):
        filename += ".h5"

    return filename
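
# Usage sketch (hashes abbreviated, for illustration only):
#
#   url_to_filename("https://example.com/weights.h5", etag='"abc"')
#   # -> "<sha256 of url>.<sha256 of etag>.h5"
#
# The mapping is deterministic, so the same url/etag pair always resolves to
# the same cache entry.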


def filename_to_url(filename, cache_dir=None):
    """
    Return the url and etag (which may be ``None``) stored for `filename`.
    Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    cache_path = os.path.join(cache_dir, filename)
    if not os.path.exists(cache_path):
        raise EnvironmentError("file {} not found".format(cache_path))

    meta_path = cache_path + ".json"
    if not os.path.exists(meta_path):
        raise EnvironmentError("file {} not found".format(meta_path))

    with open(meta_path, encoding="utf-8") as meta_file:
        metadata = json.load(meta_file)
    url = metadata["url"]
    etag = metadata["etag"]

    return url, etag
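
# Round-trip sketch: filename_to_url is the inverse lookup of url_to_filename,
# served from the "<filename>.json" metadata that get_from_cache (below) writes
# next to each cache entry, so for a previously cached url:
#
#   url, etag = filename_to_url(url_to_filename(url, etag))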


def cached_path(
    url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, user_agent=None
):
    """
    Given something that might be a URL (or might be a local path),
    determine which. If it's a URL, download the file and cache it, and
    return the path to the cached file. If it's already a local path,
    make sure the file exists and then return the path.
    Args:
        cache_dir: specify a cache directory to save the file to (overrides the default cache dir).
        force_download: if True, re-download the file even if it's already cached in the cache dir.
        resume_download: if True, resume the download if an incompletely received file is found.
        user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if is_remote_url(url_or_filename):
        # URL, so get it from the cache (downloading if necessary)
        return get_from_cache(
            url_or_filename,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            user_agent=user_agent,
        )
    elif os.path.exists(url_or_filename):
        # File, and it exists.
        return url_or_filename
    elif urlparse(url_or_filename).scheme == "":
        # File, but it doesn't exist.
        raise EnvironmentError("file {} not found".format(url_or_filename))
    else:
        # Something unknown
        raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
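
# Usage sketch (hypothetical URL, for illustration only):
#
#   path = cached_path("https://example.com/vocab.txt")
#   # First call downloads into TRANSFORMERS_CACHE; later calls with an
#   # unchanged ETag return the same local path without re-downloading.
#
#   path = cached_path("/tmp/local_vocab.txt")
#   # An existing local path is returned unchanged.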


def split_s3_path(url):
    """Split a full s3 path into the bucket name and path."""
    parsed = urlparse(url)
    if not parsed.netloc or not parsed.path:
        raise ValueError("bad s3 path {}".format(url))
    bucket_name = parsed.netloc
    s3_path = parsed.path
    # Remove '/' at beginning of path.
    if s3_path.startswith("/"):
        s3_path = s3_path[1:]
    return bucket_name, s3_path
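
# Usage sketch (hypothetical bucket, for illustration only):
#
#   split_s3_path("s3://my-bucket/models/weights.bin")
#   # -> ("my-bucket", "models/weights.bin")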


def s3_request(func):
    """
    Wrapper for s3 requests that converts a 404 ``ClientError`` into a more
    helpful ``EnvironmentError``; other client errors are re-raised as-is.
    """

    @wraps(func)
    def wrapper(url, *args, **kwargs):
        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            if int(exc.response["Error"]["Code"]) == 404:
                raise EnvironmentError("file {} not found".format(url))
            else:
                raise

    return wrapper


@s3_request
def s3_etag(url, proxies=None):
    """Check ETag on S3 object."""
    s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
    bucket_name, s3_path = split_s3_path(url)
    s3_object = s3_resource.Object(bucket_name, s3_path)
    return s3_object.e_tag


@s3_request
def s3_get(url, temp_file, proxies=None):
    """Pull a file directly from S3."""
    s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
    bucket_name, s3_path = split_s3_path(url)
    s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)


def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
    ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
    if is_torch_available():
        ua += "; torch/{}".format(torch.__version__)
    if is_tf_available():
        ua += "; tensorflow/{}".format(tf.__version__)
    if isinstance(user_agent, dict):
        ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
    elif isinstance(user_agent, str):
        ua += "; " + user_agent
    headers = {"user-agent": ua}
    if resume_size > 0:
        headers["Range"] = "bytes=%d-" % (resume_size,)
    response = requests.get(url, stream=True, proxies=proxies, headers=headers)
    if response.status_code == 416:  # Range not satisfiable
        return
    content_length = response.headers.get("Content-Length")
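    # On a resumed download the server reports only the size of the remaining
    # bytes (206 response), so resume_size is added back to get the full total
    # for the progress bar.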
    total = resume_size + int(content_length) if content_length is not None else None
    progress = tqdm(
        unit="B",
        unit_scale=True,
        total=total,
        initial=resume_size,
        desc="Downloading",
        disable=bool(logger.getEffectiveLevel() == logging.NOTSET),
    )
    for chunk in response.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive new chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()


def get_from_cache(
    url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None
):
    """
    Given a URL, look for the corresponding file in the local cache.
    If it's not there, download it. Then return the path to the cached file.
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)

    # Get eTag to add to filename, if it exists.
    if url.startswith("s3://"):
        etag = s3_etag(url, proxies=proxies)
    else:
        try:
            response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
            if response.status_code != 200:
                etag = None
            else:
                etag = response.headers.get("ETag")
        except (EnvironmentError, requests.exceptions.Timeout):
            etag = None

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    # If we don't have a connection (etag is None) and can't identify the file
    # try to get the last downloaded one
    if not os.path.exists(cache_path) and etag is None:
        matching_files = [
            file
            for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
            if not file.endswith(".json") and not file.endswith(".lock")
        ]
        if matching_files:
            cache_path = os.path.join(cache_dir, matching_files[-1])

    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):

        if resume_download:
            incomplete_path = cache_path + ".incomplete"

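            # "a+b" keeps the bytes from a previous partial download in place;
            # http_get then requests only the remainder via an HTTP Range header.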
            @contextmanager
            def _resumable_file_manager():
                with open(incomplete_path, "a+b") as f:
                    yield f

            temp_file_manager = _resumable_file_manager
            if os.path.exists(incomplete_path):
                resume_size = os.stat(incomplete_path).st_size
            else:
                resume_size = 0
        else:
            temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
            resume_size = 0

        if etag is not None and (not os.path.exists(cache_path) or force_download):
            # Download to temporary file, then copy to cache dir once finished.
            # Otherwise you get corrupt cache entries if the download gets interrupted.
            with temp_file_manager() as temp_file:
                logger.info(
                    "%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name
                )

                # GET file object
                if url.startswith("s3://"):
                    if resume_download:
                        logger.warning('Resumable downloads are not implemented for "s3://" urls')
                    s3_get(url, temp_file, proxies=proxies)
                else:
                    http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)

                # we are copying the file before closing it, so flush to avoid truncation
                temp_file.flush()

                logger.info("storing %s in cache at %s", url, cache_path)
                os.rename(temp_file.name, cache_path)

                logger.info("creating metadata file for %s", cache_path)
                meta = {"url": url, "etag": etag}
                meta_path = cache_path + ".json"
                with open(meta_path, "w") as meta_file:
                    json.dump(meta, meta_file)

    return cache_path
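

# End-to-end sketch of the caching flow (hypothetical URL, illustration only):
#
#   path = get_from_cache("https://example.com/model.bin")
#   # 1. A HEAD request fetches the ETag.
#   # 2. url_to_filename(url, etag) names the cache entry.
#   # 3. If the entry is missing (or force_download=True), the file is
#   #    downloaded to a temporary file, renamed into place, and a ".json"
#   #    metadata sidecar is written next to it.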