import bz2
import gzip
import hashlib
import lzma
import os
import os.path
import pathlib
import re
import sys
import tarfile
import urllib
import urllib.error
import urllib.request
import zipfile
from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar, Union
from urllib.parse import urlparse

import numpy as np
import torch
from torch.utils.model_zoo import tqdm

from .._internally_replaced_utils import _download_file_from_remote_location, _is_remote_location_available

USER_AGENT = "pytorch/vision"


def _save_response_content(
    content: Iterator[bytes],
    destination: Union[str, pathlib.Path],
    length: Optional[int] = None,
) -> None:
    with open(destination, "wb") as fh, tqdm(total=length) as pbar:
        for chunk in content:
            # filter out keep-alive new chunks
            if not chunk:
                continue

            fh.write(chunk)
            pbar.update(len(chunk))


def _urlretrieve(url: str, filename: Union[str, pathlib.Path], chunk_size: int = 1024 * 32) -> None:
    with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
        _save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length)


def calculate_md5(fpath: Union[str, pathlib.Path], chunk_size: int = 1024 * 1024) -> str:
    # Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are
    # not using the MD5 checksum for cryptography. This enables its usage in restricted environments like FIPS. Without
    # it torchvision.datasets is unusable in these environments since we perform a MD5 check everywhere.
    if sys.version_info >= (3, 9):
        md5 = hashlib.md5(usedforsecurity=False)
    else:
        md5 = hashlib.md5()
    with open(fpath, "rb") as f:
        while chunk := f.read(chunk_size):
            md5.update(chunk)
    return md5.hexdigest()


def check_md5(fpath: Union[str, pathlib.Path], md5: str, **kwargs: Any) -> bool:
    return md5 == calculate_md5(fpath, **kwargs)


def check_integrity(fpath: Union[str, pathlib.Path], md5: Optional[str] = None) -> bool:
    if not os.path.isfile(fpath):
        return False
    if md5 is None:
        return True
    return check_md5(fpath, md5)


def _get_redirect_url(url: str, max_hops: int = 3) -> str:
    initial_url = url
    headers = {"Method": "HEAD", "User-Agent": USER_AGENT}

    for _ in range(max_hops + 1):
        with urllib.request.urlopen(urllib.request.Request(url, headers=headers)) as response:
            if response.url == url or response.url is None:
                return url

            url = response.url
    else:
        raise RecursionError(
            f"Request to {initial_url} exceeded {max_hops} redirects. The last redirect points to {url}."
        )


def _get_google_drive_file_id(url: str) -> Optional[str]:
    parts = urlparse(url)

    if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
        return None

    match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
    if match is None:
        return None

    return match.group("id")


def download_url(
    url: str,
    root: Union[str, pathlib.Path],
    filename: Optional[Union[str, pathlib.Path]] = None,
    md5: Optional[str] = None,
    max_redirect_hops: int = 3,
) -> None:
    """Download a file from a url and place it in root.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the basename of the URL
        md5 (str, optional): MD5 checksum of the download. If None, do not check
        max_redirect_hops (int, optional): Maximum number of redirect hops allowed
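
    Example (a minimal usage sketch; the URL below is a placeholder, not a real download)::

        download_url(
            "https://example.com/files/dataset.tar.gz",
            root="~/datasets",
            filename="dataset.tar.gz",
            md5=None,  # pass the expected MD5 hex digest to enable the integrity check
        )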
    """
    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.fspath(os.path.join(root, filename))

    os.makedirs(root, exist_ok=True)

    # check if file is already present locally
    if check_integrity(fpath, md5):
        print("Using downloaded and verified file: " + fpath)
        return

    if _is_remote_location_available():
        _download_file_from_remote_location(fpath, url)
    else:
        # expand redirect chain if needed
        url = _get_redirect_url(url, max_hops=max_redirect_hops)

        # check if file is located on Google Drive
        file_id = _get_google_drive_file_id(url)
        if file_id is not None:
            return download_file_from_google_drive(file_id, root, filename, md5)

        # download the file
        try:
            print("Downloading " + url + " to " + fpath)
            _urlretrieve(url, fpath)
        except (urllib.error.URLError, OSError) as e:  # type: ignore[attr-defined]
            if url[:5] == "https":
                url = url.replace("https:", "http:")
                print("Failed download. Trying https -> http instead. Downloading " + url + " to " + fpath)
                _urlretrieve(url, fpath)
            else:
                raise e

    # check integrity of downloaded file
    if not check_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")


def list_dir(root: Union[str, pathlib.Path], prefix: bool = False) -> List[str]:
    """List all directories at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the directories found
    """
    root = os.path.expanduser(root)
    directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))]
    if prefix is True:
        directories = [os.path.join(root, d) for d in directories]
    return directories


def list_files(root: Union[str, pathlib.Path], suffix: str, prefix: bool = False) -> List[str]:
    """List all files ending with a suffix at a given root

    Args:
        root (str): Path to directory whose files need to be listed
        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
            It uses the Python "str.endswith" method and is passed directly
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the files found
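
    Example (a minimal usage sketch; the directory path is a placeholder)::

        image_files = list_files("~/datasets/images", suffix=(".png", ".jpg"), prefix=True)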
    """
    root = os.path.expanduser(root)
    files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)]
    if prefix is True:
        files = [os.path.join(root, d) for d in files]
    return files


def download_file_from_google_drive(
    file_id: str,
    root: Union[str, pathlib.Path],
    filename: Optional[Union[str, pathlib.Path]] = None,
    md5: Optional[str] = None,
):
    """Download a Google Drive file and place it in root.

    Args:
        file_id (str): id of file to be downloaded
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the id of the file.
        md5 (str, optional): MD5 checksum of the download. If None, do not check
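
    Example (a minimal usage sketch; the file id below is a placeholder, not a real file)::

        download_file_from_google_drive("1AbCdEfGhIjKlMnOp", root="~/datasets", filename="data.zip")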
    """
    try:
        import gdown
    except ModuleNotFoundError:
        raise RuntimeError(
            "To download files from GDrive, 'gdown' is required. You can install it with 'pip install gdown'."
        )

    root = os.path.expanduser(root)
    if not filename:
        filename = file_id
    fpath = os.fspath(os.path.join(root, filename))

    os.makedirs(root, exist_ok=True)

    if check_integrity(fpath, md5):
        print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}")
        return

    gdown.download(id=file_id, output=fpath, quiet=False, user_agent=USER_AGENT)

    if not check_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")


def _extract_tar(
    from_path: Union[str, pathlib.Path], to_path: Union[str, pathlib.Path], compression: Optional[str]
) -> None:
    with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar:
        tar.extractall(to_path)


_ZIP_COMPRESSION_MAP: Dict[str, int] = {
    ".bz2": zipfile.ZIP_BZIP2,
    ".xz": zipfile.ZIP_LZMA,
}


def _extract_zip(
    from_path: Union[str, pathlib.Path], to_path: Union[str, pathlib.Path], compression: Optional[str]
) -> None:
    with zipfile.ZipFile(
        from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED
    ) as zfile:
        zfile.extractall(to_path)


_ARCHIVE_EXTRACTORS: Dict[str, Callable[[Union[str, pathlib.Path], Union[str, pathlib.Path], Optional[str]], None]] = {
    ".tar": _extract_tar,
    ".zip": _extract_zip,
}
_COMPRESSED_FILE_OPENERS: Dict[str, Callable[..., IO]] = {
    ".bz2": bz2.open,
    ".gz": gzip.open,
    ".xz": lzma.open,
}
_FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = {
    ".tbz": (".tar", ".bz2"),
    ".tbz2": (".tar", ".bz2"),
    ".tgz": (".tar", ".gz"),
}


def _detect_file_type(file: Union[str, pathlib.Path]) -> Tuple[str, Optional[str], Optional[str]]:
    """Detect the archive type and/or compression of a file.

    Args:
        file (str): the filename

    Returns:
        (tuple): tuple of suffix, archive type, and compression

    Raises:
        RuntimeError: if file has no suffix or suffix is not supported
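
    Example (a behavior sketch for common suffixes)::

        _detect_file_type("archive.tar.gz")  # -> (".tar.gz", ".tar", ".gz")
        _detect_file_type("archive.tgz")     # -> (".tgz", ".tar", ".gz")
        _detect_file_type("archive.zip")     # -> (".zip", ".zip", None)
        _detect_file_type("labels.csv.bz2")  # -> (".bz2", None, ".bz2")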
    """
    suffixes = pathlib.Path(file).suffixes
    if not suffixes:
        raise RuntimeError(
            f"File '{file}' has no suffixes that could be used to detect the archive type and compression."
        )
    suffix = suffixes[-1]

    # check if the suffix is a known alias
    if suffix in _FILE_TYPE_ALIASES:
        return (suffix, *_FILE_TYPE_ALIASES[suffix])

    # check if the suffix is an archive type
    if suffix in _ARCHIVE_EXTRACTORS:
        return suffix, suffix, None

    # check if the suffix is a compression
    if suffix in _COMPRESSED_FILE_OPENERS:
        # check for suffix hierarchy
        if len(suffixes) > 1:
            suffix2 = suffixes[-2]

            # check if the suffix2 is an archive type
            if suffix2 in _ARCHIVE_EXTRACTORS:
                return suffix2 + suffix, suffix2, suffix

        return suffix, None, suffix

    valid_suffixes = sorted(set(_FILE_TYPE_ALIASES) | set(_ARCHIVE_EXTRACTORS) | set(_COMPRESSED_FILE_OPENERS))
    raise RuntimeError(f"Unknown compression or archive type: '{suffix}'.\nKnown suffixes are: '{valid_suffixes}'.")


def _decompress(
    from_path: Union[str, pathlib.Path],
    to_path: Optional[Union[str, pathlib.Path]] = None,
    remove_finished: bool = False,
) -> pathlib.Path:
    r"""Decompress a file.

    The compression is automatically detected from the file name.

    Args:
        from_path (str): Path to the file to be decompressed.
        to_path (str): Path to the decompressed file. If omitted, ``from_path`` without compression extension is used.
        remove_finished (bool): If ``True``, remove the file after the extraction.

    Returns:
        (str): Path to the decompressed file.
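
    Example (a minimal usage sketch; the path below is a placeholder)::

        _decompress("/data/labels.csv.gz")  # writes /data/labels.csv and returns it as a pathlib.Path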
    """
    suffix, archive_type, compression = _detect_file_type(from_path)
    if not compression:
        raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.")

    if to_path is None:
        to_path = pathlib.Path(os.fspath(from_path).replace(suffix, archive_type if archive_type is not None else ""))

    # We don't need to check for a missing key here, since this was already done in _detect_file_type()
    compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression]

    with compressed_file_opener(from_path, "rb") as rfh, open(to_path, "wb") as wfh:
        wfh.write(rfh.read())

    if remove_finished:
        os.remove(from_path)

    return pathlib.Path(to_path)


def extract_archive(
    from_path: Union[str, pathlib.Path],
    to_path: Optional[Union[str, pathlib.Path]] = None,
    remove_finished: bool = False,
) -> Union[str, pathlib.Path]:
    """Extract an archive.

    The archive type and a possible compression are automatically detected from the file name. If the file is compressed
    but not an archive, the call is dispatched to :func:`_decompress`.

    Args:
        from_path (str): Path to the file to be extracted.
        to_path (str): Path to the directory the file will be extracted to. If omitted, the directory of the file is
            used.
        remove_finished (bool): If ``True``, remove the file after the extraction.

    Returns:
        (str): Path to the directory the file was extracted to.
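
    Example (a minimal usage sketch; the paths below are placeholders)::

        extract_archive("/data/images.tar.gz", "/data/images", remove_finished=True)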
    """

    def path_or_str(ret_path: pathlib.Path) -> Union[str, pathlib.Path]:
        if isinstance(from_path, str):
            return os.fspath(ret_path)
        else:
            return ret_path

    if to_path is None:
        to_path = os.path.dirname(from_path)

    suffix, archive_type, compression = _detect_file_type(from_path)
    if not archive_type:
        ret_path = _decompress(
            from_path,
            os.path.join(to_path, os.path.basename(from_path).replace(suffix, "")),
            remove_finished=remove_finished,
        )
        return path_or_str(ret_path)

    # We don't need to check for a missing key here, since this was already done in _detect_file_type()
    extractor = _ARCHIVE_EXTRACTORS[archive_type]

    extractor(from_path, to_path, compression)
    if remove_finished:
        os.remove(from_path)

    return path_or_str(pathlib.Path(to_path))


def download_and_extract_archive(
    url: str,
    download_root: Union[str, pathlib.Path],
    extract_root: Optional[Union[str, pathlib.Path]] = None,
    filename: Optional[Union[str, pathlib.Path]] = None,
    md5: Optional[str] = None,
    remove_finished: bool = False,
) -> None:
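    """Download a file from a url, place it in download_root, and extract it to extract_root.

    This is a thin convenience wrapper around :func:`download_url` followed by :func:`extract_archive`.

    Args:
        url (str): URL to download file from
        download_root (str): Directory to place the downloaded file in
        extract_root (str, optional): Directory the archive is extracted to. If None, use ``download_root``.
        filename (str, optional): Name to save the file under. If None, use the basename of the URL
        md5 (str, optional): MD5 checksum of the download. If None, do not check
        remove_finished (bool): If ``True``, remove the archive after extraction.
    """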
    download_root = os.path.expanduser(download_root)
    if extract_root is None:
        extract_root = download_root
    if not filename:
        filename = os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive = os.path.join(download_root, filename)
    print(f"Extracting {archive} to {extract_root}")
    extract_archive(archive, extract_root, remove_finished)


def iterable_to_str(iterable: Iterable) -> str:
    return "'" + "', '".join([str(item) for item in iterable]) + "'"


T = TypeVar("T", str, bytes)


def verify_str_arg(
    value: T,
    arg: Optional[str] = None,
    valid_values: Optional[Iterable[T]] = None,
    custom_msg: Optional[str] = None,
) -> T:
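    """Verify that ``value`` is a string and, optionally, that it is one of ``valid_values``.

    Returns the value unchanged if the checks pass; otherwise raises a ``ValueError`` built from
    ``arg``/``custom_msg``.

    Example (a behavior sketch)::

        split = verify_str_arg("train", arg="split", valid_values=("train", "test"))  # returns "train"
    """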
    if not isinstance(value, str):
        if arg is None:
            msg = "Expected type str, but got type {type}."
        else:
            msg = "Expected type str for argument {arg}, but got type {type}."
        msg = msg.format(type=type(value), arg=arg)
        raise ValueError(msg)

    if valid_values is None:
        return value

    if value not in valid_values:
        if custom_msg is not None:
            msg = custom_msg
        else:
            msg = "Unknown value '{value}' for argument {arg}. Valid values are {{{valid_values}}}."
            msg = msg.format(value=value, arg=arg, valid_values=iterable_to_str(valid_values))
        raise ValueError(msg)

    return value


def _read_pfm(file_name: Union[str, pathlib.Path], slice_channels: int = 2) -> np.ndarray:
    """Read file in .pfm format. Might contain either 1 or 3 channels of data.

    Args:
        file_name (str): Path to the file.
        slice_channels (int): Number of channels to slice out of the file.
            Useful for reading different data formats stored in .pfm files: Optical Flows, Stereo Disparity Maps, etc.
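
    Returns:
        (np.ndarray): Array of shape ``(slice_channels, H, W)`` as ``float32``. PFM stores scanlines
            bottom-to-top, so the data is flipped along the height dimension after reading.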
    """

    with open(file_name, "rb") as f:
        header = f.readline().rstrip()
        if header not in [b"PF", b"Pf"]:
            raise ValueError("Invalid PFM file")

        dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline())
        if not dim_match:
            raise Exception("Malformed PFM header.")
        w, h = (int(dim) for dim in dim_match.groups())

        scale = float(f.readline().rstrip())
        if scale < 0:  # little-endian
            endian = "<"
            scale = -scale
        else:
            endian = ">"  # big-endian

        data = np.fromfile(f, dtype=endian + "f")

    pfm_channels = 3 if header == b"PF" else 1

    data = data.reshape(h, w, pfm_channels).transpose(2, 0, 1)
    data = np.flip(data, axis=1)  # flip on h dimension
    data = data[:slice_channels, :, :]
    return data.astype(np.float32)


def _flip_byte_order(t: torch.Tensor) -> torch.Tensor:
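    # Reinterpret the tensor as raw bytes, reverse the byte order of every element, and view the
    # result with the original dtype again (i.e. swap endianness without changing the underlying bytes).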
    return (
        t.contiguous().view(torch.uint8).view(*t.shape, t.element_size()).flip(-1).view(*t.shape[:-1], -1).view(t.dtype)
    )