# io_utils.py (committed by dongchy920)
# Copyright (c) OpenMMLab. All rights reserved.
import hashlib
import os

import click
import mmcv
import requests
import torch.distributed as dist
from mmcv.runner import get_dist_info
from requests.exceptions import InvalidURL, RequestException, Timeout

# Default directory for downloaded files: ``<home>/.cache/openmmlab/mmgen/``,
# with ``<home>`` resolved once at import time.
MMGEN_CACHE_DIR = f"{os.path.expanduser('~')}/.cache/openmmlab/mmgen/"


def get_content_from_url(url, timeout=15, stream=False):
    """Get content from url.

    Thin wrapper around ``requests.get``. No error handling is done here:
    any exception raised by the request propagates to the caller unchanged.

    Args:
        url (str): Url for getting content.
        timeout (int): Set the socket timeout. Default: 15.
        stream (bool): Passed through to ``requests.get``; when True the
            response body is fetched lazily instead of being downloaded
            up front. Default: False.

    Returns:
        requests.Response: The response returned by ``requests.get``.

    Raises:
        requests.exceptions.RequestException: Including its subclasses
            ``InvalidURL`` and ``Timeout``, when the request fails.
    """
    # The previous try/except chain only re-raised whatever it caught, so the
    # exceptions are simply left to propagate on their own.
    return requests.get(url, timeout=timeout, stream=stream)


def _stream_to_file(response, file_path, hasher=None, chunk_size=1024):
    """Write the body of a streamed response to ``file_path`` with a progress
    bar.

    Args:
        response (requests.Response): Response obtained with ``stream=True``.
        file_path (str): File that is created/overwritten with the body.
        hasher: Optional ``hashlib`` object; it is updated with every chunk
            that is written. Default: None.
        chunk_size (int): Number of bytes requested per chunk. Default: 1024.
    """
    # Servers may omit Content-Length (e.g. chunked transfer encoding); in
    # that case the progress bar runs without a known total.
    content_length = response.headers.get('content-length')
    num_chunks = None
    if content_length is not None:
        # ceil division: the last chunk may be shorter than ``chunk_size``
        num_chunks = (int(content_length) + chunk_size - 1) // chunk_size

    content_iter = response.iter_content(chunk_size=chunk_size)
    with open(file_path, 'wb') as fw:
        with click.progressbar(content_iter, length=num_chunks) as chunks:
            for chunk in chunks:
                if chunk:
                    fw.write(chunk)
                    if hasher is not None:
                        hasher.update(chunk)


def download_from_url(url,
                      dest_path=None,
                      dest_dir=MMGEN_CACHE_DIR,
                      hash_prefix=None):
    """Download object at the given URL to a local path.

    The data is first streamed to ``dest_path + '.part'`` and only moved to
    ``dest_path`` after the download (and the optional hash check) succeeded,
    so a failed download never leaves a truncated or corrupt file at
    ``dest_path``.

    Args:
        url (str): URL of the object to download.
        dest_path (str): Path where object will be saved. If None, the last
            ``/``-separated component of ``url`` is used as file name inside
            ``dest_dir``. Default: None.
        dest_dir (str): The directory of the destination. Defaults to
            ``'~/.cache/openmmlab/mmgen/'``.
        hash_prefix (string, optional): If not None, the SHA256 hex digest of
            the downloaded file must start with `hash_prefix`. Default: None.

    Return:
        str: path for the downloaded file.

    Raises:
        RuntimeError: If ``hash_prefix`` is given and does not match the
            digest of the downloaded data.
        requests.exceptions.HTTPError: If the server answers with an error
            status (4xx/5xx).
    """
    # get the exact destination path
    if dest_path is None:
        filename = url.split('/')[-1]
        dest_path = os.path.join(dest_dir, filename)

    if dest_path.startswith('~'):
        dest_path = os.path.expanduser('~') + dest_path[1:]

    # avoid downloading an existing file
    # NOTE(review): a rank that takes this early return skips the barrier
    # below; kept as-is for backward compatibility -- confirm that all ranks
    # see the same file system state when calling this function.
    if os.path.exists(dest_path):
        return dest_path

    rank, ws = get_dist_info()

    # only download from the master process
    if rank == 0:
        mmcv.mkdir_or_exist(os.path.dirname(dest_path))

        hasher = hashlib.sha256() if hash_prefix is not None else None
        tmp_path = dest_path + '.part'

        response = get_content_from_url(url, stream=True)
        try:
            # do not store an HTTP error page as if it were the object
            response.raise_for_status()
            _stream_to_file(response, tmp_path, hasher=hasher)

            if hasher is not None:
                digest = hasher.hexdigest()
                if digest[:len(hash_prefix)] != hash_prefix:
                    raise RuntimeError(
                        f'invalid hash value, expected "{hash_prefix}", but '
                        f'got "{digest}"')

            # publish the verified file in a single step
            os.replace(tmp_path, dest_path)
        finally:
            response.close()
            # only still present when something above failed
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    # sync the other processes
    if ws > 1:
        dist.barrier()

    return dest_path