utils.py 8.33 KB
Newer Older
Minjie Wang's avatar
Minjie Wang committed
1
"""Dataset utilities."""
2
from __future__ import absolute_import
Minjie Wang's avatar
Minjie Wang committed
3

VoVAllen's avatar
VoVAllen committed
4
5
import os
import sys
Minjie Wang's avatar
Minjie Wang committed
6
7
8
9
import hashlib
import warnings
import zipfile
import tarfile
VoVAllen's avatar
VoVAllen committed
10
import numpy as np
VoVAllen's avatar
VoVAllen committed
11
12
13

from .graph_serialize import save_graphs, load_graphs, load_labels

Minjie Wang's avatar
Minjie Wang committed
14
15
16
17
18
19
20
# `requests` is treated as an optional dependency so that importing this
# module does not fail when the package is not installed.
try:
    import requests
except ImportError:
    # Placeholder bound to the name `requests` when the real package is
    # missing. It defines no attributes, so any later `requests.get(...)`
    # raises AttributeError at call time rather than at import time.
    class requests_failed_to_import(object):
        pass
    requests = requests_failed_to_import

VoVAllen's avatar
VoVAllen committed
21
__all__ = ['download', 'check_sha1', 'extract_archive',
VoVAllen's avatar
VoVAllen committed
22
23
           'get_download_dir', 'Subset', 'split_dataset',
           'save_graphs', "load_graphs", "load_labels"]
VoVAllen's avatar
VoVAllen committed
24

25

Haibin Lin's avatar
Haibin Lin committed
26
27
28
29
30
31
32
33
34
def _get_dgl_url(file_url):
    """Get DGL online url for download.

    Parameters
    ----------
    file_url : str
        Location of the file relative to the repository root.

    Returns
    -------
    str
        The repository base URL followed by ``file_url``. The base URL is
        read from the ``DGL_REPO`` environment variable when it is set to a
        non-empty value, otherwise the default DGL S3 bucket is used. A
        trailing ``/`` is appended to the base URL if it is missing.
    """
    dgl_repo_url = 'https://s3.us-east-2.amazonaws.com/dgl.ai/'
    # An empty DGL_REPO is treated as unset; indexing an empty string to
    # look at its last character would raise IndexError.
    repo_url = os.environ.get('DGL_REPO') or dgl_repo_url
    if not repo_url.endswith('/'):
        repo_url += '/'
    return repo_url + file_url


VoVAllen's avatar
VoVAllen committed
35
def split_dataset(dataset, frac_list=None, shuffle=False, random_state=None):
    """Split dataset into training, validation and test set.

    Parameters
    ----------
    dataset
        We assume ``len(dataset)`` gives the number of datapoints and ``dataset[i]``
        gives the ith datapoint.
    frac_list : list or None, optional
        Fractions to use for each split, typically of length 3 for training,
        validation and test. The fractions must sum to 1.
        If None, we will use [0.8, 0.1, 0.1].
    shuffle : bool, optional
        By default we perform a consecutive split of the dataset. If True,
        we will first randomly shuffle the dataset.
    random_state : None, int or array_like, optional
        Random seed used to initialize the pseudo-random number generator.
        Can be any integer between 0 and 2**32 - 1 inclusive, an array
        (or other sequence) of such integers, or None (the default).
        If seed is None, then RandomState will try to read data from /dev/urandom
        (or the Windows analogue) if available or seed from the clock otherwise.

    Returns
    -------
    list of Subset
        One subset per entry of ``frac_list``, in the same order
        (training, validation and test for the default fractions).
    """
    if frac_list is None:
        frac_list = [0.8, 0.1, 0.1]
    frac_list = np.array(frac_list)
    assert np.allclose(np.sum(frac_list), 1.), \
        'Expect frac_list sum to 1, got {:.4f}'.format(np.sum(frac_list))

    num_data = len(dataset)
    # Truncate each share to a whole number of datapoints, then give the
    # last split whatever is left so that the sizes add up to num_data.
    lengths = (num_data * frac_list).astype(int)
    lengths[-1] = num_data - np.sum(lengths[:-1])

    if shuffle:
        indices = np.random.RandomState(
            seed=random_state).permutation(num_data)
    else:
        indices = np.arange(num_data)

    # Each split covers the half-open index range [end - length, end).
    ends = np.cumsum(lengths)
    starts = ends - lengths
    return [Subset(dataset, indices[start:end])
            for start, end in zip(starts, ends)]


78
def download(url, path=None, overwrite=False, sha1_hash=None, retries=5,
             verify_ssl=True, log=True):
    """Download a given URL.

    Codes borrowed from mxnet/gluon/utils.py

    Parameters
    ----------
    url : str
        URL to download.
    path : str, optional
        Destination path to store downloaded file. By default stores to the
        current directory with the same name as in url.
    overwrite : bool, optional
        Whether to overwrite the destination file if it already exists.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
        but doesn't match.
    retries : integer, default 5
        The number of times to attempt downloading in case of failure or non 200 return codes.
        At least one attempt is always made, so 0 and 1 both result in a
        single attempt.
    verify_ssl : bool, default True
        Verify SSL certificates.
    log : bool, default True
        Whether to print the progress for download

    Returns
    -------
    str
        The file path of the downloaded file.

    Raises
    ------
    RuntimeError
        If the last attempt received a non 200 status code.
    UserWarning
        If the last attempt produced a file whose sha1 hash does not match
        ``sha1_hash``.
    Exception
        Any other error raised by the last attempt is propagated unchanged.
    """
    if path is None:
        fname = url.split('/')[-1]
        # Empty filenames are invalid
        assert fname, 'Can\'t construct file-name from this URL. ' \
            'Please set the `path` option manually.'
    else:
        path = os.path.expanduser(path)
        if os.path.isdir(path):
            fname = os.path.join(path, url.split('/')[-1])
        else:
            fname = path
    assert retries >= 0, "Number of retries should be at least 0"

    if not verify_ssl:
        warnings.warn(
            'Unverified HTTPS request is being made (verify_ssl=False). '
            'Adding certificate verification is strongly advised.')

    if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
        if not os.path.exists(dirname):
            # exist_ok tolerates another process creating the directory
            # between the existence check and this call.
            os.makedirs(dirname, exist_ok=True)
        while retries+1 > 0:
            # Disable pylint too broad Exception
            # pylint: disable=W0703
            try:
                if log:
                    print('Downloading %s from %s...' % (fname, url))
                r = requests.get(url, stream=True, verify=verify_ssl)
                try:
                    if r.status_code != 200:
                        raise RuntimeError("Failed downloading url %s" % url)
                    with open(fname, 'wb') as f:
                        for chunk in r.iter_content(chunk_size=1024):
                            if chunk:  # filter out keep-alive new chunks
                                f.write(chunk)
                finally:
                    # A streamed response keeps its connection open until
                    # the body is fully consumed or the response is closed;
                    # close it explicitly so failed attempts do not leak it.
                    r.close()
                if sha1_hash and not check_sha1(fname, sha1_hash):
                    raise UserWarning('File {} is downloaded but the content hash does not match.'
                                      ' The repo may be outdated or download may be incomplete. '
                                      'If the "repo_url" is overridden, consider switching to '
                                      'the default repo.'.format(fname))
                break
            except Exception:
                retries -= 1
                if retries <= 0:
                    # Out of attempts: surface the last failure as-is.
                    raise
                if log:
                    print("download failed, retrying, {} attempt{} left"
                          .format(retries, 's' if retries > 1 else ''))

    return fname

VoVAllen's avatar
VoVAllen committed
159

Minjie Wang's avatar
Minjie Wang committed
160
161
162
163
164
165
166
167
168
169
170
def check_sha1(filename, sha1_hash):
    """Check whether the sha1 hash of the file content matches the expected hash.

    Codes borrowed from mxnet/gluon/utils.py

    Parameters
    ----------
    filename : str
        Path to the file.
    sha1_hash : str
        Expected sha1 hash in hexadecimal digits. It is compared verbatim
        against the lowercase hexadecimal digest of the file content.

    Returns
    -------
    bool
        Whether the file content matches the expected hash.
    """
    sha1 = hashlib.sha1()
    with open(filename, 'rb') as f:
        # Read in 1 MiB chunks so the whole file is never held in memory;
        # iteration stops at the empty bytes object returned at end of file.
        for data in iter(lambda: f.read(1048576), b''):
            sha1.update(data)

    return sha1.hexdigest() == sha1_hash

VoVAllen's avatar
VoVAllen committed
187

Minjie Wang's avatar
Minjie Wang committed
188
def extract_archive(file, target_dir):
    """Extract archive file.

    Nothing is done when ``target_dir`` already exists.

    Parameters
    ----------
    file : str
        Absolute path of the archive file. Names ending in ``.gz``, ``.tar``
        or ``.tgz`` are opened as tar archives and names ending in ``.zip``
        as zip archives.
    target_dir : str
        Target directory of the archive to be uncompressed.

    Raises
    ------
    Exception
        If the file name has none of the recognized extensions.
    """
    if os.path.exists(target_dir):
        return
    if file.endswith(('.gz', '.tar', '.tgz')):
        archive = tarfile.open(file, 'r')
    elif file.endswith('.zip'):
        archive = zipfile.ZipFile(file, 'r')
    else:
        raise Exception('Unrecognized file type: ' + file)
    # The context manager closes the archive even if extraction fails.
    with archive:
        print('Extracting file to {}'.format(target_dir))
        # NOTE(review): extractall trusts the member paths stored in the
        # archive; only use this on archives from a trusted source.
        archive.extractall(path=target_dir)

VoVAllen's avatar
VoVAllen committed
210

Minjie Wang's avatar
Minjie Wang committed
211
def get_download_dir():
    """Get the path to the download directory, creating it if needed.

    The directory is taken from the ``DGL_DOWNLOAD_DIR`` environment
    variable when it is set, and defaults to ``~/.dgl`` otherwise.

    Returns
    -------
    dirname : str
        Path to the download directory
    """
    default_dir = os.path.join(os.path.expanduser('~'), '.dgl')
    dirname = os.environ.get('DGL_DOWNLOAD_DIR', default_dir)
    if not os.path.exists(dirname):
        # exist_ok tolerates another process creating the directory between
        # the existence check and this call.
        os.makedirs(dirname, exist_ok=True)
    return dirname
VoVAllen's avatar
VoVAllen committed
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261


class Subset(object):
    """View of a dataset restricted to the given indices.

    Code adapted from PyTorch.

    Parameters
    ----------
    dataset
        Underlying dataset; ``dataset[i]`` should return the ith datapoint.
    indices : list
        Indices into ``dataset`` that make up the subset, in subset order.
    """

    def __init__(self, dataset, indices):
        self.dataset, self.indices = dataset, indices

    def __len__(self):
        """Return the number of datapoints in the subset.

        Returns
        -------
        int
            Number of indices held by the subset.
        """
        return len(self.indices)

    def __getitem__(self, item):
        """Return the datapoint at position ``item`` of the subset.

        Returns
        -------
        datapoint
            ``dataset[indices[item]]``
        """
        # Translate the position within the subset to an index into the
        # underlying dataset before looking the datapoint up.
        source_index = self.indices[item]
        return self.dataset[source_index]