import bz2
import gzip
import hashlib
import itertools
import lzma
import os
import os.path
import pathlib
import re
import sys
import tarfile
import urllib
import urllib.error
import urllib.request
import warnings
import zipfile
from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar
from urllib.parse import urlparse

import requests
import torch
from torch.utils.model_zoo import tqdm

from .._internally_replaced_utils import (
    _download_file_from_remote_location,
    _is_remote_location_available,
)

USER_AGENT = "pytorch/vision"


def _save_response_content(
    content: Iterator[bytes],
    destination: str,
    length: Optional[int] = None,
) -> None:
    with open(destination, "wb") as fh, tqdm(total=length) as pbar:
        for chunk in content:
            # filter out keep-alive new chunks
            if not chunk:
                continue

            fh.write(chunk)
            pbar.update(len(chunk))


def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None:
    with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
        _save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length)


def gen_bar_updater() -> Callable[[int, int, int], None]:
    warnings.warn("The function `gen_bar_updater` is deprecated since 0.13 and will be removed in 0.15.")
    pbar = tqdm(total=None)

    def bar_update(count, block_size, total_size):
        if pbar.total is None and total_size:
            pbar.total = total_size
        progress_bytes = count * block_size
        pbar.update(progress_bytes - pbar.n)

    return bar_update


def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
    # Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are
    # not using the MD5 checksum for cryptography. This enables its usage in restricted environments like FIPS. Without
    # it torchvision.datasets is unusable in these environments since we perform a MD5 check everywhere.
    md5 = hashlib.md5(**dict(usedforsecurity=False) if sys.version_info >= (3, 9) else dict())
    with open(fpath, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            md5.update(chunk)
    return md5.hexdigest()


def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool:
    return md5 == calculate_md5(fpath, **kwargs)


def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
    if not os.path.isfile(fpath):
        return False
    if md5 is None:
        return True
    return check_md5(fpath, md5)
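
# Illustrative usage of the checksum helpers above (not part of the original file);
# the path and expected hash below are placeholders:
#
#   if check_integrity("/tmp/archive.tar.gz", md5="<expected-md5-hexdigest>"):
#       print("file already present and verified")  # calculate_md5 matched the expected value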


def _get_redirect_url(url: str, max_hops: int = 3) -> str:
    initial_url = url
    headers = {"Method": "HEAD", "User-Agent": USER_AGENT}

    for _ in range(max_hops + 1):
        with urllib.request.urlopen(urllib.request.Request(url, headers=headers)) as response:
            if response.url == url or response.url is None:
                return url

            url = response.url
    else:
        raise RecursionError(
            f"Request to {initial_url} exceeded {max_hops} redirects. The last redirect points to {url}."
        )


def _get_google_drive_file_id(url: str) -> Optional[str]:
    parts = urlparse(url)

    if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
        return None

    match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
    if match is None:
        return None

    return match.group("id")
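
# Illustrative behaviour of the parser above (hypothetical id): a sharing link such as
# "https://drive.google.com/file/d/SOME_FILE_ID/view" yields "SOME_FILE_ID", while a
# non-Google-Drive URL yields None.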


def download_url(
    url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3
) -> None:
    """Download a file from a url and place it in root.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the basename of the URL
        md5 (str, optional): MD5 checksum of the download. If None, do not check
        max_redirect_hops (int, optional): Maximum number of redirect hops allowed
    """
    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    # check if file is already present locally
    if check_integrity(fpath, md5):
        print("Using downloaded and verified file: " + fpath)
        return

    if _is_remote_location_available():
        _download_file_from_remote_location(fpath, url)
    else:
        # expand redirect chain if needed
        url = _get_redirect_url(url, max_hops=max_redirect_hops)

        # check if file is located on Google Drive
        file_id = _get_google_drive_file_id(url)
        if file_id is not None:
            return download_file_from_google_drive(file_id, root, filename, md5)

        # download the file
        try:
            print("Downloading " + url + " to " + fpath)
            _urlretrieve(url, fpath)
        except (urllib.error.URLError, OSError) as e:  # type: ignore[attr-defined]
            if url[:5] == "https":
                url = url.replace("https:", "http:")
                print("Failed download. Trying https -> http instead. Downloading " + url + " to " + fpath)
                _urlretrieve(url, fpath)
            else:
                raise e

    # check integrity of downloaded file
    if not check_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")
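
# Illustrative usage (hypothetical URL and paths, not part of the original file):
#
#   download_url("https://example.com/dataset.tar.gz", root="./data", md5=None)
#
# saves ./data/dataset.tar.gz, following up to three redirects by default, dispatching
# Google Drive links to download_file_from_google_drive, and skipping the download when
# the file is already present and passes the optional MD5 check.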


def list_dir(root: str, prefix: bool = False) -> List[str]:
    """List all directories at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the directories found
    """
    root = os.path.expanduser(root)
    directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))]
    if prefix is True:
        directories = [os.path.join(root, d) for d in directories]
    return directories


def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
    """List all files ending with a suffix at a given root

    Args:
        root (str): Path to directory whose files need to be listed
        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
            It uses the Python "str.endswith" method and is passed directly
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the files found
    """
    root = os.path.expanduser(root)
    files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)]
    if prefix is True:
        files = [os.path.join(root, d) for d in files]
    return files
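
# Illustrative usage of list_dir / list_files above (hypothetical paths): with
# prefix=True the root is prepended to each entry, e.g.
#
#   list_files("./data/images", suffix=(".jpg", ".png"), prefix=True)
#
# might return ["./data/images/cat.jpg", "./data/images/dog.png"].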


def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple[Optional[str], Iterator[bytes]]:
    content = response.iter_content(chunk_size)
    first_chunk = None
    # filter out keep-alive new chunks
    while not first_chunk:
        first_chunk = next(content)
    content = itertools.chain([first_chunk], content)

    try:
        match = re.search("<title>Google Drive - (?P<api_response>.+?)</title>", first_chunk.decode())
        api_response = match["api_response"] if match is not None else None
    except UnicodeDecodeError:
        api_response = None
    return api_response, content


def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
    """Download a Google Drive file and place it in root.

    Args:
        file_id (str): id of file to be downloaded
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the id of the file.
        md5 (str, optional): MD5 checksum of the download. If None, do not check
    """
    # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url

    root = os.path.expanduser(root)
    if not filename:
        filename = file_id
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    if check_integrity(fpath, md5):
        print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}")
        return

    url = "https://drive.google.com/uc"
    params = dict(id=file_id, export="download")
    with requests.Session() as session:
        response = session.get(url, params=params, stream=True)

        for key, value in response.cookies.items():
            if key.startswith("download_warning"):
                token = value
                break
        else:
            api_response, content = _extract_gdrive_api_response(response)
            token = "t" if api_response == "Virus scan warning" else None

        if token is not None:
            response = session.get(url, params=dict(params, confirm=token), stream=True)
            api_response, content = _extract_gdrive_api_response(response)

        if api_response == "Quota exceeded":
            raise RuntimeError(
                f"The daily quota of the file {filename} is exceeded and it "
                f"can't be downloaded. This is a limitation of Google Drive "
                f"and can only be overcome by trying again later."
            )

        _save_response_content(content, fpath)
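
# Illustrative usage (hypothetical file id and paths):
#
#   download_file_from_google_drive("1a2B3c", root="./data", filename="weights.pth")
#
# saves the Drive file to ./data/weights.pth, resending the request with a confirmation
# token when Drive answers with its virus-scan warning page.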


def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
    with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar:
        tar.extractall(to_path)


_ZIP_COMPRESSION_MAP: Dict[str, int] = {
    ".bz2": zipfile.ZIP_BZIP2,
    ".xz": zipfile.ZIP_LZMA,
}


def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None:
    with zipfile.ZipFile(
        from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED
    ) as zip:
        zip.extractall(to_path)


_ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = {
    ".tar": _extract_tar,
    ".zip": _extract_zip,
}
_COMPRESSED_FILE_OPENERS: Dict[str, Callable[..., IO]] = {
    ".bz2": bz2.open,
    ".gz": gzip.open,
    ".xz": lzma.open,
}
_FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = {
    ".tbz": (".tar", ".bz2"),
    ".tbz2": (".tar", ".bz2"),
    ".tgz": (".tar", ".gz"),
}


def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
    """Detect the archive type and/or compression of a file.

    Args:
        file (str): the filename

    Returns:
        (tuple): tuple of suffix, archive type, and compression

    Raises:
        RuntimeError: if file has no suffix or suffix is not supported
    """
    suffixes = pathlib.Path(file).suffixes
    if not suffixes:
        raise RuntimeError(
            f"File '{file}' has no suffixes that could be used to detect the archive type and compression."
        )
    suffix = suffixes[-1]

    # check if the suffix is a known alias
    if suffix in _FILE_TYPE_ALIASES:
        return (suffix, *_FILE_TYPE_ALIASES[suffix])

    # check if the suffix is an archive type
    if suffix in _ARCHIVE_EXTRACTORS:
        return suffix, suffix, None

    # check if the suffix is a compression
    if suffix in _COMPRESSED_FILE_OPENERS:
        # check for suffix hierarchy
        if len(suffixes) > 1:
            suffix2 = suffixes[-2]

            # check if the suffix2 is an archive type
            if suffix2 in _ARCHIVE_EXTRACTORS:
                return suffix2 + suffix, suffix2, suffix

        return suffix, None, suffix

    valid_suffixes = sorted(set(_FILE_TYPE_ALIASES) | set(_ARCHIVE_EXTRACTORS) | set(_COMPRESSED_FILE_OPENERS))
    raise RuntimeError(f"Unknown compression or archive type: '{suffix}'.\nKnown suffixes are: '{valid_suffixes}'.")
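
# Examples of the detection above (derived from the tables _FILE_TYPE_ALIASES,
# _ARCHIVE_EXTRACTORS and _COMPRESSED_FILE_OPENERS):
#   "foo.tar.gz" -> (".tar.gz", ".tar", ".gz")
#   "foo.tgz"    -> (".tgz", ".tar", ".gz")
#   "foo.zip"    -> (".zip", ".zip", None)
#   "foo.gz"     -> (".gz", None, ".gz")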


def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
    r"""Decompress a file.

    The compression is automatically detected from the file name.

    Args:
        from_path (str): Path to the file to be decompressed.
        to_path (str): Path to the decompressed file. If omitted, ``from_path`` without compression extension is used.
        remove_finished (bool): If ``True``, remove the file after the extraction.

    Returns:
        (str): Path to the decompressed file.
    """
    suffix, archive_type, compression = _detect_file_type(from_path)
    if not compression:
        raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.")

    if to_path is None:
        to_path = from_path.replace(suffix, archive_type if archive_type is not None else "")

    # We don't need to check for a missing key here, since this was already done in _detect_file_type()
    compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression]

    with compressed_file_opener(from_path, "rb") as rfh, open(to_path, "wb") as wfh:
        wfh.write(rfh.read())

    if remove_finished:
        os.remove(from_path)

    return to_path


def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
    """Extract an archive.

    The archive type and a possible compression are automatically detected from the file name. If the file is compressed
    but not an archive, the call is dispatched to :func:`_decompress`.

    Args:
        from_path (str): Path to the file to be extracted.
        to_path (str): Path to the directory the file will be extracted to. If omitted, the directory of the file is
            used.
        remove_finished (bool): If ``True``, remove the file after the extraction.

    Returns:
        (str): Path to the directory the file was extracted to.
    """
    if to_path is None:
        to_path = os.path.dirname(from_path)

    suffix, archive_type, compression = _detect_file_type(from_path)
    if not archive_type:
        return _decompress(
            from_path,
            os.path.join(to_path, os.path.basename(from_path).replace(suffix, "")),
            remove_finished=remove_finished,
        )

    # We don't need to check for a missing key here, since this was already done in _detect_file_type()
    extractor = _ARCHIVE_EXTRACTORS[archive_type]

    extractor(from_path, to_path, compression)
    if remove_finished:
        os.remove(from_path)

    return to_path
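
# Illustrative usage (hypothetical paths):
#
#   extract_archive("./data/dataset.tar.gz", to_path="./data/extracted")
#
# unpacks the tarball into ./data/extracted, while a bare "./data/labels.csv.gz" passed
# with the same to_path would instead be decompressed to "./data/extracted/labels.csv".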


def download_and_extract_archive(
    url: str,
    download_root: str,
    extract_root: Optional[str] = None,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
    remove_finished: bool = False,
) -> None:
    """Download a file from a url, then extract it and place the contents in extract_root.

    Args:
        url (str): URL to download file from
        download_root (str): Directory to place downloaded file in
        extract_root (str, optional): Directory to extract the archive to. If None, use download_root
        filename (str, optional): Name to save the file under. If None, use the basename of the URL
        md5 (str, optional): MD5 checksum of the download. If None, do not check
        remove_finished (bool): If ``True``, remove the downloaded archive after extraction
    """
    download_root = os.path.expanduser(download_root)
    if extract_root is None:
        extract_root = download_root
    if not filename:
        filename = os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive = os.path.join(download_root, filename)
    print(f"Extracting {archive} to {extract_root}")
    extract_archive(archive, extract_root, remove_finished)
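
# Illustrative usage (hypothetical URL and directories): fetch and unpack in one call, e.g.
#
#   download_and_extract_archive(
#       "https://example.com/images.zip", download_root="./downloads", extract_root="./data"
#   )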


def iterable_to_str(iterable: Iterable) -> str:
    return "'" + "', '".join([str(item) for item in iterable]) + "'"


T = TypeVar("T", str, bytes)


def verify_str_arg(
    value: T,
    arg: Optional[str] = None,
    valid_values: Optional[Iterable[T]] = None,
    custom_msg: Optional[str] = None,
) -> T:
    if not isinstance(value, torch._six.string_classes):
        if arg is None:
            msg = "Expected type str, but got type {type}."
        else:
            msg = "Expected type str for argument {arg}, but got type {type}."
        msg = msg.format(type=type(value), arg=arg)
        raise ValueError(msg)

    if valid_values is None:
        return value

    if value not in valid_values:
        if custom_msg is not None:
            msg = custom_msg
        else:
            msg = "Unknown value '{value}' for argument {arg}. Valid values are {{{valid_values}}}."
            msg = msg.format(value=value, arg=arg, valid_values=iterable_to_str(valid_values))
        raise ValueError(msg)

    return value
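
# Illustrative usage of verify_str_arg (hypothetical values): dataset constructors
# typically validate a `split` argument like
#
#   split = verify_str_arg(split, "split", ("train", "test"))
#
# which returns the value unchanged if valid and raises a ValueError listing the
# accepted values otherwise.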