utils.py 7 KB
Newer Older
Minjie Wang's avatar
Minjie Wang committed
1
"""Dataset utilities."""
2
from __future__ import absolute_import
Minjie Wang's avatar
Minjie Wang committed
3

VoVAllen's avatar
VoVAllen committed
4
5
import os
import sys
Minjie Wang's avatar
Minjie Wang committed
6
7
8
9
import hashlib
import warnings
import zipfile
import tarfile
VoVAllen's avatar
VoVAllen committed
10
import numpy as np
Minjie Wang's avatar
Minjie Wang committed
11
12
13
14
15
16
17
try:
    import requests
except ImportError:
    class requests_failed_to_import(object):
        pass
    requests = requests_failed_to_import

VoVAllen's avatar
VoVAllen committed
18
19
20
__all__ = ['download', 'check_sha1', 'extract_archive',
           'get_download_dir', 'Subset', 'split_dataset']

21

Haibin Lin's avatar
Haibin Lin committed
22
23
24
25
26
27
28
29
30
def _get_dgl_url(file_url):
    """Get DGL online url for download."""
    dgl_repo_url = 'https://s3.us-east-2.amazonaws.com/dgl.ai/'
    repo_url = os.environ.get('DGL_REPO', dgl_repo_url)
    if repo_url[-1] != '/':
        repo_url = repo_url + '/'
    return repo_url + file_url


VoVAllen's avatar
VoVAllen committed
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
def split_dataset(dataset, frac_list=None, shuffle=False, random_state=None):
    """Split a dataset into consecutive subsets according to given fractions.

    Parameters
    ----------
    dataset
        Any object supporting ``len(dataset)`` and ``dataset[i]``.
    frac_list : list of float, optional
        Fraction of the data assigned to each split. The fractions must sum
        to 1 (checked with ``np.allclose``). Defaults to ``[0.8, 0.1, 0.1]``.
    shuffle : bool, optional
        If True, the indices are randomly permuted before being split.
        Otherwise the splits are contiguous ranges ``0..len(dataset)-1``.
    random_state : int or None, optional
        Seed passed to ``np.random.RandomState`` when ``shuffle`` is True.

    Returns
    -------
    list of Subset
        One ``Subset`` per entry of ``frac_list``. Split ``i`` holds
        ``int(len(dataset) * frac_list[i])`` datapoints (truncated). The last
        split instead takes every datapoint the earlier splits did not.
    """
    fractions = np.array([0.8, 0.1, 0.1] if frac_list is None else frac_list)
    total = np.sum(fractions)
    assert np.allclose(total, 1.), \
        'Expect frac_list sum to 1, got {:.4f}'.format(total)

    num_data = len(dataset)
    sizes = (num_data * fractions).astype(int)
    # Give the truncation remainder to the final split.
    sizes[-1] = num_data - np.sum(sizes[:-1])

    if shuffle:
        order = np.random.RandomState(seed=random_state).permutation(num_data)
    else:
        order = np.arange(num_data)

    # Split k covers order[ends[k] - sizes[k] : ends[k]].
    ends = np.cumsum(sizes)
    starts = ends - sizes
    return [Subset(dataset, order[begin:end]) for begin, end in zip(starts, ends)]


Minjie Wang's avatar
Minjie Wang committed
50
def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
    """Download a given URL.

    Codes borrowed from mxnet/gluon/utils.py

    Parameters
    ----------
    url : str
        URL to download.
    path : str, optional
        Destination path to store downloaded file. By default stores to the
        current directory with the same name as in url.
    overwrite : bool, optional
        Whether to overwrite the destination file if it already exists.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
        but doesn't match.
    retries : integer, default 5
        The number of times to attempt downloading in case of failure or non 200 return codes.
        Both 0 and 1 result in a single attempt.
    verify_ssl : bool, default True
        Verify SSL certificates.

    Returns
    -------
    str
        The file path of the downloaded file.

    Raises
    ------
    Exception
        The exception of the final failed attempt is re-raised. Examples are
        ``RuntimeError`` for a non-200 status code, ``UserWarning`` when the
        downloaded content does not match ``sha1_hash``, or any error raised
        by ``requests``.

    Notes
    -----
    If a transfer fails part-way, the partially written file is deleted. A
    later call therefore cannot mistake it for a complete download.
    """
    if path is None:
        fname = url.split('/')[-1]
        # Empty filenames are invalid
        assert fname, 'Can\'t construct file-name from this URL. ' \
            'Please set the `path` option manually.'
    else:
        path = os.path.expanduser(path)
        if os.path.isdir(path):
            fname = os.path.join(path, url.split('/')[-1])
        else:
            fname = path
    assert retries >= 0, "Number of retries should be at least 0"

    if not verify_ssl:
        warnings.warn(
            'Unverified HTTPS request is being made (verify_ssl=False). '
            'Adding certificate verification is strongly advised.')

    if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
        if not os.path.exists(dirname):
            try:
                os.makedirs(dirname)
            except OSError:
                # Another process may have created it since the exists() check.
                if not os.path.isdir(dirname):
                    raise
        while retries + 1 > 0:
            # Disable pylint too broad Exception
            # pylint: disable=W0703
            try:
                print('Downloading %s from %s...' % (fname, url))
                r = requests.get(url, stream=True, verify=verify_ssl)
                try:
                    if r.status_code != 200:
                        raise RuntimeError("Failed downloading url %s" % url)
                    _save_response_content(r, fname)
                finally:
                    # With stream=True the connection is only released once the
                    # body is fully consumed or the response is closed. Close it
                    # on every path so failed attempts do not leak connections.
                    r.close()
                if sha1_hash and not check_sha1(fname, sha1_hash):
                    raise UserWarning('File {} is downloaded but the content hash does not match.'
                                      ' The repo may be outdated or download may be incomplete. '
                                      'If the "repo_url" is overridden, consider switching to '
                                      'the default repo.'.format(fname))
                break
            except Exception:
                retries -= 1
                if retries <= 0:
                    raise
                print("download failed, retrying, {} attempt{} left"
                      .format(retries, 's' if retries > 1 else ''))

    return fname


def _save_response_content(response, fname):
    """Stream the body of a ``requests`` response into ``fname``.

    If writing fails or is interrupted, for example by a network error or
    KeyboardInterrupt, the truncated file is removed and the error is
    re-raised.
    """
    f = open(fname, 'wb')
    try:
        with f:
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)
    except BaseException:
        # The file is closed by ``with`` at this point, so it can be removed
        # (closing first matters on Windows).
        try:
            os.remove(fname)
        except OSError:
            pass
        raise

VoVAllen's avatar
VoVAllen committed
127

Minjie Wang's avatar
Minjie Wang committed
128
129
130
131
132
133
134
135
136
137
138
def check_sha1(filename, sha1_hash):
    """Return whether the SHA-1 digest of a file's bytes equals an expected value.

    Codes borrowed from mxnet/gluon/utils.py

    Parameters
    ----------
    filename : str
        Path to the file.
    sha1_hash : str
        Expected sha1 hash in hexadecimal digits. It is compared verbatim
        against ``hashlib``'s lowercase ``hexdigest()``.

    Returns
    -------
    bool
        Whether the file content matches the expected hash.
    """
    digest = hashlib.sha1()
    block_size = 1048576  # hash 1 MiB at a time so large files are not fully loaded
    with open(filename, 'rb') as stream:
        # iter(callable, sentinel) stops once read() returns b'' at EOF.
        for block in iter(lambda: stream.read(block_size), b''):
            digest.update(block)
    return sha1_hash == digest.hexdigest()

VoVAllen's avatar
VoVAllen committed
155

Minjie Wang's avatar
Minjie Wang committed
156
def extract_archive(file, target_dir):
    """Extract archive file.

    Parameters
    ----------
    file : str
        Absolute path of the archive file. Names ending in ``.gz``, ``.tar``
        or ``.tgz`` are opened with ``tarfile``. Names ending in ``.zip`` are
        opened with ``zipfile``.
    target_dir : str
        Target directory of the archive to be uncompressed. If it already
        exists, the function returns without extracting anything.

    Raises
    ------
    Exception
        If the file name has none of the recognized extensions.
    RuntimeError
        If a tar member's path would resolve outside ``target_dir``.
    """
    if os.path.exists(target_dir):
        return
    if file.endswith(('.gz', '.tar', '.tgz')):
        archive = tarfile.open(file, 'r')
    elif file.endswith('.zip'):
        archive = zipfile.ZipFile(file, 'r')
    else:
        raise Exception('Unrecognized file type: ' + file)
    # ``with`` closes the archive even if extraction raises.
    with archive:
        print('Extracting file to {}'.format(target_dir))
        if isinstance(archive, tarfile.TarFile):
            # tarfile.extractall writes member names as-is, including "../" or
            # absolute paths (path traversal), so reject such members first.
            # zipfile.extractall already sanitizes these names itself.
            _check_tar_members(archive, target_dir)
        archive.extractall(path=target_dir)


def _check_tar_members(archive, target_dir):
    """Raise RuntimeError if any tar member name resolves outside target_dir.

    Only member *names* are checked. A symlink member whose link target
    points elsewhere is not inspected here.
    """
    # Trailing separator so "/data/out2" is not accepted as inside "/data/out".
    root = os.path.join(os.path.realpath(target_dir), '')
    for member in archive.getmembers():
        dest = os.path.realpath(os.path.join(root, member.name))
        if not os.path.join(dest, '').startswith(root):
            raise RuntimeError(
                'Refusing to extract {!r}: it would be written outside {}'.format(
                    member.name, target_dir))

VoVAllen's avatar
VoVAllen committed
178

Minjie Wang's avatar
Minjie Wang committed
179
def get_download_dir():
    """Get the absolute path to the download directory.

    The directory comes from the ``DGL_DOWNLOAD_DIR`` environment variable
    when it is set and non-empty, and is ``~/.dgl`` otherwise. A leading
    ``~`` is expanded and relative paths are made absolute. The directory is
    created if it does not exist yet.

    Returns
    -------
    dirname : str
        Path to the download directory
    """
    default_dir = os.path.join(os.path.expanduser('~'), '.dgl')
    # An empty value is treated as unset; os.makedirs('') would otherwise fail.
    dirname = os.environ.get('DGL_DOWNLOAD_DIR') or default_dir
    # Values such as "~/data" (quoted in a shell) or "data" would otherwise
    # be used literally / relative to the current working directory.
    dirname = os.path.abspath(os.path.expanduser(dirname))
    if not os.path.exists(dirname):
        try:
            os.makedirs(dirname)
        except OSError:
            # Another process may have created it since the exists() check.
            if not os.path.isdir(dirname):
                raise
    return dirname
VoVAllen's avatar
VoVAllen committed
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229


class Subset(object):
    """View of a dataset restricted to a given sequence of indices.

    Code adapted from PyTorch.

    Parameters
    ----------
    dataset
        Indexable collection; ``dataset[i]`` should return the i-th datapoint.
    indices : list
        Positions in ``dataset`` that make up the subset, in subset order.
    """

    def __init__(self, dataset, indices):
        self.dataset, self.indices = dataset, indices

    def __getitem__(self, item):
        """Fetch the ``item``-th datapoint of the subset.

        Returns
        -------
        object
            ``dataset[indices[item]]``. Its type is whatever the underlying
            dataset returns.
        """
        source_position = self.indices[item]
        return self.dataset[source_position]

    def __len__(self):
        """Count the datapoints in the subset.

        Returns
        -------
        int
            Number of datapoints in the subset
        """
        size = len(self.indices)
        return size