import bz2
import contextlib
import gzip
import hashlib
import itertools
import lzma
import os
import os.path
import pathlib
import re
import sys
import tarfile
import urllib
import urllib.error
import urllib.request
import warnings
import zipfile
from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar
from urllib.parse import urlparse

import numpy as np
import requests
import torch
from torch.utils.model_zoo import tqdm

from .._internally_replaced_utils import _download_file_from_remote_location, _is_remote_location_available

USER_AGENT = "pytorch/vision"


def _save_response_content(
    content: Iterator[bytes],
    destination: str,
    length: Optional[int] = None,
) -> None:
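    """Write the byte chunks from ``content`` to ``destination``, showing a tqdm progress bar."""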
    with open(destination, "wb") as fh, tqdm(total=length) as pbar:
        for chunk in content:
            # filter out keep-alive new chunks
            if not chunk:
                continue

            fh.write(chunk)
            pbar.update(len(chunk))


def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None:
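    """Stream ``url`` to ``filename`` in ``chunk_size`` byte blocks, sending the torchvision user agent."""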
    with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
        _save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length)


def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
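    """Return the MD5 hex digest of the file at ``fpath``, read in ``chunk_size`` byte blocks."""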
    # Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are
    # not using the MD5 checksum for cryptography. This enables its usage in restricted environments like FIPS. Without
    # it, torchvision.datasets is unusable in these environments since we perform an MD5 check everywhere.
    if sys.version_info >= (3, 9):
        md5 = hashlib.md5(usedforsecurity=False)
    else:
        md5 = hashlib.md5()
    with open(fpath, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            md5.update(chunk)
    return md5.hexdigest()


def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool:
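    """Return ``True`` if the MD5 checksum of ``fpath`` equals the expected ``md5`` string."""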
    return md5 == calculate_md5(fpath, **kwargs)


def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
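    """Return ``True`` if ``fpath`` exists and, when ``md5`` is given, its checksum matches."""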
    if not os.path.isfile(fpath):
        return False
    if md5 is None:
        return True
    return check_md5(fpath, md5)


def _get_redirect_url(url: str, max_hops: int = 3) -> str:
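    """Follow up to ``max_hops`` HTTP redirects and return the final URL."""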
    initial_url = url
    headers = {"Method": "HEAD", "User-Agent": USER_AGENT}

    for _ in range(max_hops + 1):
        with urllib.request.urlopen(urllib.request.Request(url, headers=headers)) as response:
            if response.url == url or response.url is None:
                return url

            url = response.url
    else:
        raise RecursionError(
            f"Request to {initial_url} exceeded {max_hops} redirects. The last redirect points to {url}."
        )


def _get_google_drive_file_id(url: str) -> Optional[str]:
    parts = urlparse(url)

    if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
        return None

    match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
    if match is None:
        return None

    return match.group("id")


def download_url(
    url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3
) -> None:
    """Download a file from a url and place it in root.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the basename of the URL
        md5 (str, optional): MD5 checksum of the download. If None, do not check
        max_redirect_hops (int, optional): Maximum number of redirect hops allowed
    """
    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    # check if file is already present locally
    if check_integrity(fpath, md5):
        print("Using downloaded and verified file: " + fpath)
        return

    if _is_remote_location_available():
        _download_file_from_remote_location(fpath, url)
    else:
        # expand redirect chain if needed
        url = _get_redirect_url(url, max_hops=max_redirect_hops)

        # check if file is located on Google Drive
        file_id = _get_google_drive_file_id(url)
        if file_id is not None:
            return download_file_from_google_drive(file_id, root, filename, md5)

        # download the file
        try:
            print("Downloading " + url + " to " + fpath)
            _urlretrieve(url, fpath)
        except (urllib.error.URLError, OSError) as e:  # type: ignore[attr-defined]
            if url[:5] == "https":
                url = url.replace("https:", "http:")
                print("Failed download. Trying https -> http instead. Downloading " + url + " to " + fpath)
                _urlretrieve(url, fpath)
            else:
                raise e

    # check integrity of downloaded file
    if not check_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")


def list_dir(root: str, prefix: bool = False) -> List[str]:
    """List all directories at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the directories found
    """
    root = os.path.expanduser(root)
    directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))]
    if prefix is True:
        directories = [os.path.join(root, d) for d in directories]
    return directories


def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
    """List all files ending with a suffix at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
            It uses the Python "str.endswith" method and is passed directly
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the files found
    """
    root = os.path.expanduser(root)
    files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)]
    if prefix is True:
        files = [os.path.join(root, d) for d in files]
    return files


def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple[Optional[str], Iterator[bytes]]:
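    """Return the Google Drive API message embedded in the response (if any) and an iterator over its content."""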
    content = response.iter_content(chunk_size)
    first_chunk = None
    # filter out keep-alive new chunks
    while not first_chunk:
        first_chunk = next(content)
    content = itertools.chain([first_chunk], content)

    try:
        match = re.search("<title>Google Drive - (?P<api_response>.+?)</title>", first_chunk.decode())
        api_response = match["api_response"] if match is not None else None
    except UnicodeDecodeError:
        api_response = None
    return api_response, content


def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
    """Download a Google Drive file from  and place it in root.

    Args:
        file_id (str): id of file to be downloaded
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the id of the file.
        md5 (str, optional): MD5 checksum of the download. If None, do not check
    """
    # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url

    root = os.path.expanduser(root)
    if not filename:
        filename = file_id
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    if check_integrity(fpath, md5):
        print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}")
        return

    url = "https://drive.google.com/uc"
    params = dict(id=file_id, export="download")
    with requests.Session() as session:
        response = session.get(url, params=params, stream=True)

        for key, value in response.cookies.items():
            if key.startswith("download_warning"):
                token = value
                break
        else:
            api_response, content = _extract_gdrive_api_response(response)
            token = "t" if api_response == "Virus scan warning" else None

        if token is not None:
            response = session.get(url, params=dict(params, confirm=token), stream=True)
            api_response, content = _extract_gdrive_api_response(response)

        if api_response == "Quota exceeded":
            raise RuntimeError(
                f"The daily quota of the file {filename} is exceeded and it "
                f"can't be downloaded. This is a limitation of Google Drive "
                f"and can only be overcome by trying again later."
            )

        _save_response_content(content, fpath)

    # In case we deal with an unhandled GDrive API response, the file should be smaller than 10kB and contain only text
    if os.stat(fpath).st_size < 10 * 1024:
        with contextlib.suppress(UnicodeDecodeError), open(fpath) as fh:
            text = fh.read()
            # Regular expression to detect HTML. Copied from https://stackoverflow.com/a/70585604
            if re.search(r"</?\s*[a-z-][^>]*\s*>|(&(?:[\w\d]+|#\d+|#x[a-f\d]+);)", text):
                warnings.warn(
                    f"We detected some HTML elements in the downloaded file. "
                    f"This most likely means that the download triggered an unhandled API response by GDrive. "
                    f"Please report this to torchvision at https://github.com/pytorch/vision/issues including "
                    f"the response:\n\n{text}"
                )

    if md5 and not check_md5(fpath, md5):
        raise RuntimeError(
            f"The MD5 checksum of the download file {fpath} does not match the one on record."
            f"Please delete the file and try again. "
            f"If the issue persists, please report this to torchvision at https://github.com/pytorch/vision/issues."
        )


def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
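    """Extract the tarball at ``from_path`` into the ``to_path`` directory, honoring an optional compression suffix."""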
    with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar:
        tar.extractall(to_path)


_ZIP_COMPRESSION_MAP: Dict[str, int] = {
    ".bz2": zipfile.ZIP_BZIP2,
    ".xz": zipfile.ZIP_LZMA,
}


def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None:
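    """Extract the zip archive at ``from_path`` into the ``to_path`` directory."""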
    with zipfile.ZipFile(
        from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED
    ) as zip:
        zip.extractall(to_path)


_ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = {
    ".tar": _extract_tar,
    ".zip": _extract_zip,
}
_COMPRESSED_FILE_OPENERS: Dict[str, Callable[..., IO]] = {
    ".bz2": bz2.open,
    ".gz": gzip.open,
    ".xz": lzma.open,
}
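# Suffix aliases that stand for an (archive type, compression) pair, e.g. ".tgz" is ".tar" + ".gz"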
_FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = {
    ".tbz": (".tar", ".bz2"),
    ".tbz2": (".tar", ".bz2"),
    ".tgz": (".tar", ".gz"),
}


def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
    """Detect the archive type and/or compression of a file.

    Args:
        file (str): the filename

    Returns:
        (tuple): tuple of suffix, archive type, and compression

    Raises:
        RuntimeError: if file has no suffix or suffix is not supported
    """
    suffixes = pathlib.Path(file).suffixes
    if not suffixes:
        raise RuntimeError(
            f"File '{file}' has no suffixes that could be used to detect the archive type and compression."
        )
    suffix = suffixes[-1]

    # check if the suffix is a known alias
    if suffix in _FILE_TYPE_ALIASES:
        return (suffix, *_FILE_TYPE_ALIASES[suffix])

    # check if the suffix is an archive type
    if suffix in _ARCHIVE_EXTRACTORS:
        return suffix, suffix, None

    # check if the suffix is a compression
    if suffix in _COMPRESSED_FILE_OPENERS:
        # check for suffix hierarchy
        if len(suffixes) > 1:
            suffix2 = suffixes[-2]

            # check if the suffix2 is an archive type
            if suffix2 in _ARCHIVE_EXTRACTORS:
                return suffix2 + suffix, suffix2, suffix

        return suffix, None, suffix

    valid_suffixes = sorted(set(_FILE_TYPE_ALIASES) | set(_ARCHIVE_EXTRACTORS) | set(_COMPRESSED_FILE_OPENERS))
    raise RuntimeError(f"Unknown compression or archive type: '{suffix}'.\nKnown suffixes are: '{valid_suffixes}'.")


def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
    r"""Decompress a file.

    The compression is automatically detected from the file name.

    Args:
        from_path (str): Path to the file to be decompressed.
        to_path (str): Path to the decompressed file. If omitted, ``from_path`` without compression extension is used.
        remove_finished (bool): If ``True``, remove the file after the extraction.

    Returns:
        (str): Path to the decompressed file.
    """
    suffix, archive_type, compression = _detect_file_type(from_path)
    if not compression:
        raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.")

    if to_path is None:
        to_path = from_path.replace(suffix, archive_type if archive_type is not None else "")

    # We don't need to check for a missing key here, since this was already done in _detect_file_type()
    compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression]

    with compressed_file_opener(from_path, "rb") as rfh, open(to_path, "wb") as wfh:
        wfh.write(rfh.read())

    if remove_finished:
        os.remove(from_path)

    return to_path


def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
    """Extract an archive.

    The archive type and a possible compression are automatically detected from the file name. If the file is compressed
    but not an archive, the call is dispatched to :func:`_decompress`.

    Args:
        from_path (str): Path to the file to be extracted.
        to_path (str): Path to the directory the file will be extracted to. If omitted, the directory of the file is
            used.
        remove_finished (bool): If ``True``, remove the file after the extraction.

    Returns:
        (str): Path to the directory the file was extracted to.
    """
    if to_path is None:
        to_path = os.path.dirname(from_path)

    suffix, archive_type, compression = _detect_file_type(from_path)
    if not archive_type:
        return _decompress(
            from_path,
            os.path.join(to_path, os.path.basename(from_path).replace(suffix, "")),
            remove_finished=remove_finished,
        )

    # We don't need to check for a missing key here, since this was already done in _detect_file_type()
    extractor = _ARCHIVE_EXTRACTORS[archive_type]

    extractor(from_path, to_path, compression)
    if remove_finished:
        os.remove(from_path)

    return to_path


def download_and_extract_archive(
    url: str,
    download_root: str,
    extract_root: Optional[str] = None,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
    remove_finished: bool = False,
) -> None:
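    """Download a file from ``url`` into ``download_root`` and extract the archive to ``extract_root``.

    Args:
        url (str): URL to download the archive from
        download_root (str): Directory to place the downloaded archive in
        extract_root (str, optional): Directory to extract the archive to. If None, use ``download_root``
        filename (str, optional): Name to save the archive under. If None, use the basename of the URL
        md5 (str, optional): MD5 checksum of the download. If None, do not check
        remove_finished (bool): If ``True``, remove the archive after extraction
    """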
    download_root = os.path.expanduser(download_root)
    if extract_root is None:
        extract_root = download_root
    if not filename:
        filename = os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive = os.path.join(download_root, filename)
    print(f"Extracting {archive} to {extract_root}")
    extract_archive(archive, extract_root, remove_finished)


def iterable_to_str(iterable: Iterable) -> str:
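    """Join the items of ``iterable`` into a single quoted, comma-separated string for error messages."""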
    return "'" + "', '".join([str(item) for item in iterable]) + "'"


T = TypeVar("T", str, bytes)


def verify_str_arg(
    value: T,
    arg: Optional[str] = None,
    valid_values: Optional[Iterable[T]] = None,
    custom_msg: Optional[str] = None,
) -> T:
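    """Check that ``value`` is a string and, if ``valid_values`` is given, one of the allowed values.

    Args:
        value (str): The value to verify.
        arg (str, optional): Name of the argument, used in the error message.
        valid_values (iterable, optional): Allowed values. If None, only the type is checked.
        custom_msg (str, optional): Custom error message to use instead of the default one.

    Returns:
        The verified value.
    """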
    if not isinstance(value, str):
        if arg is None:
            msg = "Expected type str, but got type {type}."
        else:
            msg = "Expected type str for argument {arg}, but got type {type}."
        msg = msg.format(type=type(value), arg=arg)
        raise ValueError(msg)

    if valid_values is None:
        return value

    if value not in valid_values:
        if custom_msg is not None:
            msg = custom_msg
        else:
            msg = "Unknown value '{value}' for argument {arg}. Valid values are {{{valid_values}}}."
            msg = msg.format(value=value, arg=arg, valid_values=iterable_to_str(valid_values))
        raise ValueError(msg)

    return value


def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray:
    """Read file in .pfm format. Might contain either 1 or 3 channels of data.

    Args:
        file_name (str): Path to the file.
        slice_channels (int): Number of channels to slice out of the file.
            Useful for reading different data formats stored in .pfm files: Optical Flows, Stereo Disparity Maps, etc.
    """

    with open(file_name, "rb") as f:
        header = f.readline().rstrip()
        if header not in [b"PF", b"Pf"]:
            raise ValueError("Invalid PFM file")

        dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline())
        if not dim_match:
            raise Exception("Malformed PFM header.")
        w, h = (int(dim) for dim in dim_match.groups())

        scale = float(f.readline().rstrip())
        if scale < 0:  # little-endian
            endian = "<"
            scale = -scale
        else:
            endian = ">"  # big-endian

        data = np.fromfile(f, dtype=endian + "f")

    pfm_channels = 3 if header == b"PF" else 1

    data = data.reshape(h, w, pfm_channels).transpose(2, 0, 1)
    data = np.flip(data, axis=1)  # flip on h dimension
    data = data[:slice_channels, :, :]
    return data.astype(np.float32)


def _flip_byte_order(t: torch.Tensor) -> torch.Tensor:
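    """Reverse the byte order of every element of ``t`` (little-endian <-> big-endian)."""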
    return (
        t.contiguous().view(torch.uint8).view(*t.shape, t.element_size()).flip(-1).view(*t.shape[:-1], -1).view(t.dtype)
    )