import os
import os.path
import hashlib
import gzip
import tarfile
import zipfile
from typing import Any, Callable, List, Iterable, Optional, TypeVar

import torch
from torch.utils.model_zoo import tqdm


def gen_bar_updater() -> Callable[[int, int, int], None]:
    """Return a tqdm-based reporthook callback for urllib.request.urlretrieve."""
    pbar = tqdm(total=None)

    def bar_update(count, block_size, total_size):
        if pbar.total is None and total_size:
            pbar.total = total_size
        progress_bytes = count * block_size
        pbar.update(progress_bytes - pbar.n)

    return bar_update


def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
    """Compute the MD5 checksum of the file at fpath, reading it in chunks of chunk_size bytes."""
    md5 = hashlib.md5()
    with open(fpath, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            md5.update(chunk)
    return md5.hexdigest()


def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool:
    """Check whether the MD5 checksum of fpath matches the expected md5 string."""
    return md5 == calculate_md5(fpath, **kwargs)


def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
    """Return True if fpath exists and, when md5 is given, its checksum matches."""
    if not os.path.isfile(fpath):
        return False
    if md5 is None:
        return True
    return check_md5(fpath, md5)
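
# Illustrative usage of the integrity helpers above (not part of the original
# module; the path and checksum are made-up placeholders):
#
#   if not check_integrity("/tmp/train.tar.gz", md5="d41d8cd98f00b204e9800998ecf8427e"):
#       print("file missing or corrupted; re-download it")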


def _get_redirect_url(url: str, max_hops: int = 10) -> str:
    """Follow HTTP redirects (up to max_hops) and return the final URL."""
    import requests

    for hop in range(max_hops + 1):
        response = requests.get(url)

        if response.url == url or response.url is None:
            return url

        url = response.url
    else:
        raise RecursionError(f"Too many redirects: {max_hops + 1}")


def download_url(
    url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3
) -> None:
    """Download a file from a url and place it in root.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the basename of the URL
        md5 (str, optional): MD5 checksum of the download. If None, do not check
        max_redirect_hops (int, optional): Maximum number of redirect hops allowed
    """
    import urllib.error
    import urllib.request

    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    # check if file is already present locally
    if check_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
        return

    # expand redirect chain if needed
    url = _get_redirect_url(url, max_hops=max_redirect_hops)

    # download the file
    try:
        print('Downloading ' + url + ' to ' + fpath)
        urllib.request.urlretrieve(
            url, fpath,
            reporthook=gen_bar_updater()
        )
    except (urllib.error.URLError, IOError) as e:
        if url[:5] == 'https':
            url = url.replace('https:', 'http:')
            print('Failed download. Trying https -> http instead.'
                  ' Downloading ' + url + ' to ' + fpath)
            urllib.request.urlretrieve(
                url, fpath,
                reporthook=gen_bar_updater()
            )
        else:
            raise e

    # check integrity of downloaded file
    if not check_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")
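
# Illustrative usage (not part of the original module; the URL and paths are
# made-up placeholders):
#
#   download_url(
#       "https://example.com/datasets/train.tar.gz",
#       root="~/data",
#       filename="train.tar.gz",
#       md5=None,  # pass the expected MD5 string here to enable verification
#   )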


def list_dir(root: str, prefix: bool = False) -> List[str]:
    """List all directories at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the directories found
    """
    root = os.path.expanduser(root)
    directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))]
    if prefix is True:
        directories = [os.path.join(root, d) for d in directories]
    return directories


def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
    """List all files ending with a suffix at a given root

    Args:
        root (str): Path to directory whose files need to be listed
        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
            It uses the Python "str.endswith" method and is passed directly
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the files found
    """
    root = os.path.expanduser(root)
    files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)]
    if prefix is True:
        files = [os.path.join(root, d) for d in files]
    return files
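
# Illustrative usage (not part of the original module; the directories are
# made-up placeholders):
#
#   class_dirs = list_dir("/data/imagenet/train", prefix=True)
#   images = list_files("/data/imagenet/train/n01440764", suffix=('.JPEG', '.png'))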


def _quota_exceeded(response: "requests.models.Response") -> bool:  # type: ignore[name-defined]
    return False
    # See https://github.com/pytorch/vision/issues/2992 for details
    # return "Google Drive - Quota exceeded" in response.text


def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
    """Download a Google Drive file and place it in root.

    Args:
        file_id (str): id of file to be downloaded
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the id of the file.
        md5 (str, optional): MD5 checksum of the download. If None, do not check
    """
    # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
    import requests
    url = "https://docs.google.com/uc?export=download"

    root = os.path.expanduser(root)
    if not filename:
        filename = file_id
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    if os.path.isfile(fpath) and check_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
    else:
        session = requests.Session()

        response = session.get(url, params={'id': file_id}, stream=True)
        token = _get_confirm_token(response)

        if token:
            params = {'id': file_id, 'confirm': token}
            response = session.get(url, params=params, stream=True)

        if _quota_exceeded(response):
            msg = (
                f"The daily quota of the file {filename} is exceeded and it "
                f"can't be downloaded. This is a limitation of Google Drive "
                f"and can only be overcome by trying again later."
            )
            raise RuntimeError(msg)

        _save_response_content(response, fpath)
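
# Illustrative usage (not part of the original module; the file id is a
# made-up placeholder):
#
#   download_file_from_google_drive("1aBcDEFghIJklmnOPqrstUVwxyz012345", root="~/data",
#                                   filename="model.pth", md5=None)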


def _get_confirm_token(response: "requests.models.Response") -> Optional[str]:  # type: ignore[name-defined]
    """Return the Google Drive download-warning token from the response cookies, if present."""
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value

    return None


def _save_response_content(
    response: "requests.models.Response", destination: str, chunk_size: int = 32768,  # type: ignore[name-defined]
) -> None:
    """Stream the response body to destination, updating a tqdm progress bar as bytes arrive."""
    with open(destination, "wb") as f:
        pbar = tqdm(total=None)
        progress = 0
        for chunk in response.iter_content(chunk_size):
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)
                progress += len(chunk)
                pbar.update(progress - pbar.n)
        pbar.close()


def _is_tarxz(filename: str) -> bool:
    return filename.endswith(".tar.xz")


def _is_tar(filename: str) -> bool:
    return filename.endswith(".tar")


def _is_targz(filename: str) -> bool:
    return filename.endswith(".tar.gz")


def _is_tgz(filename: str) -> bool:
    return filename.endswith(".tgz")


def _is_gzip(filename: str) -> bool:
    return filename.endswith(".gz") and not filename.endswith(".tar.gz")


def _is_zip(filename: str) -> bool:
    return filename.endswith(".zip")


def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> None:
    """Extract a .tar/.tar.gz/.tgz/.tar.xz/.gz/.zip archive to to_path.

    If to_path is None, the archive is extracted next to from_path.
    If remove_finished is True, the archive is deleted after extraction.
    """
    if to_path is None:
        to_path = os.path.dirname(from_path)

    if _is_tar(from_path):
        with tarfile.open(from_path, 'r') as tar:
            tar.extractall(path=to_path)
    elif _is_targz(from_path) or _is_tgz(from_path):
        with tarfile.open(from_path, 'r:gz') as tar:
            tar.extractall(path=to_path)
    elif _is_tarxz(from_path):
        with tarfile.open(from_path, 'r:xz') as tar:
            tar.extractall(path=to_path)
    elif _is_gzip(from_path):
        to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
        with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
            out_f.write(zip_f.read())
    elif _is_zip(from_path):
        with zipfile.ZipFile(from_path, 'r') as z:
            z.extractall(to_path)
    else:
        raise ValueError("Extraction of {} not supported".format(from_path))

    if remove_finished:
        os.remove(from_path)
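
# Illustrative usage (not part of the original module; the paths are made-up
# placeholders):
#
#   extract_archive("/tmp/train.tar.gz", to_path="/tmp/train", remove_finished=False)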


def download_and_extract_archive(
    url: str,
    download_root: str,
    extract_root: Optional[str] = None,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
    remove_finished: bool = False,
) -> None:
    """Download an archive from url into download_root, then extract it to extract_root."""
    download_root = os.path.expanduser(download_root)
    if extract_root is None:
        extract_root = download_root
    if not filename:
        filename = os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive = os.path.join(download_root, filename)
    print("Extracting {} to {}".format(archive, extract_root))
    extract_archive(archive, extract_root, remove_finished)
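
# Illustrative usage (not part of the original module; the URL and paths are
# made-up placeholders):
#
#   download_and_extract_archive(
#       "https://example.com/datasets/train.tar.gz",
#       download_root="/tmp/downloads",
#       extract_root="/tmp/data",
#       remove_finished=True,
#   )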


def iterable_to_str(iterable: Iterable) -> str:
    """Render an iterable as a single-quoted, comma-separated string, e.g. 'a', 'b'."""
    return "'" + "', '".join([str(item) for item in iterable]) + "'"


T = TypeVar("T", str, bytes)


def verify_str_arg(
    value: T, arg: Optional[str] = None, valid_values: Optional[Iterable[T]] = None, custom_msg: Optional[str] = None,
) -> T:
    """Verify that value is a string and, optionally, one of valid_values.

    Args:
        value (str): the value to check
        arg (str, optional): name of the argument, used in error messages
        valid_values (iterable of str, optional): allowed values for value
        custom_msg (str, optional): custom error message to use instead of the default one
    """
    if not isinstance(value, torch._six.string_classes):
        if arg is None:
            msg = "Expected type str, but got type {type}."
        else:
            msg = "Expected type str for argument {arg}, but got type {type}."
        msg = msg.format(type=type(value), arg=arg)
        raise ValueError(msg)

    if valid_values is None:
        return value

    if value not in valid_values:
        if custom_msg is not None:
            msg = custom_msg
        else:
            msg = ("Unknown value '{value}' for argument {arg}. "
                   "Valid values are {{{valid_values}}}.")
            msg = msg.format(value=value, arg=arg,
                             valid_values=iterable_to_str(valid_values))
        raise ValueError(msg)

    return value
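
# Illustrative usage (not part of the original module):
#
#   split = verify_str_arg("train", arg="split", valid_values=("train", "val", "test"))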