"""Dataset utilities."""
from __future__ import absolute_import

import errno
import hashlib
import os
import pickle
import sys
import warnings

import networkx.algorithms as A

import numpy as np
import requests
from tqdm.auto import tqdm

from .. import backend as F

from .graph_serialize import load_graphs, load_labels, save_graphs
from .tensor_serialize import load_tensors, save_tensors

__all__ = [
    "loadtxt",
    "download",
    "check_sha1",
    "extract_archive",
    "get_download_dir",
    "Subset",
    "split_dataset",
    "save_graphs",
    "load_graphs",
    "load_labels",
    "save_tensors",
    "load_tensors",
    "add_nodepred_split",
    "add_node_property_split",
    "mask_nodes_by_property",
]


def loadtxt(path, delimiter, dtype=None):
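    """Load a delimited text file into a numpy array.

    Uses pandas for fast parsing when it is available and falls back to
    ``numpy.loadtxt`` otherwise.

    Parameters
    ----------
    path : str
        Path to the text file.
    delimiter : str
        Field delimiter.
    dtype : optional
        Currently unused.

    Returns
    -------
    numpy.ndarray
        The loaded data.
    """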
    try:
        import pandas as pd

        df = pd.read_csv(path, delimiter=delimiter, header=None)
        return df.values
    except ImportError:
        warnings.warn(
            "Pandas is not installed, now using numpy.loadtxt to load data, "
            "which could be extremely slow. Accelerate by installing pandas"
        )
        return np.loadtxt(path, delimiter=delimiter)


def _get_dgl_url(file_url):
    """Get DGL online url for download."""
    dgl_repo_url = "https://data.dgl.ai/"
    repo_url = os.environ.get("DGL_REPO", dgl_repo_url)
    if repo_url[-1] != "/":
        repo_url = repo_url + "/"
    return repo_url + file_url


def split_dataset(dataset, frac_list=None, shuffle=False, random_state=None):
    """Split dataset into training, validation and test set.

    Parameters
    ----------
    dataset
        We assume ``len(dataset)`` gives the number of datapoints and ``dataset[i]``
        gives the ith datapoint.
    frac_list : list or None, optional
        A list of length 3 containing the fraction to use for training,
        validation and test. If None, we will use [0.8, 0.1, 0.1].
    shuffle : bool, optional
        By default we perform a consecutive split of the dataset. If True,
        we will first randomly shuffle the dataset.
    random_state : None, int or array_like, optional
        Random seed used to initialize the pseudo-random number generator.
        Can be any integer between 0 and 2**32 - 1 inclusive, an array
        (or other sequence) of such integers, or None (the default).
        If seed is None, then RandomState will try to read data from /dev/urandom
        (or the Windows analogue) if available or seed from the clock otherwise.

    Returns
    -------
    list of length 3
        Subsets for training, validation and test.
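
    Examples
    --------
    A minimal usage sketch with a plain Python list standing in for the dataset:

    >>> train, val, test = split_dataset(list(range(10)), frac_list=[0.6, 0.2, 0.2])
    >>> len(train), len(val), len(test)
    (6, 2, 2)
    >>> list(train)
    [0, 1, 2, 3, 4, 5]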
    """
    from itertools import accumulate

    if frac_list is None:
        frac_list = [0.8, 0.1, 0.1]
    frac_list = np.asarray(frac_list)
    assert np.allclose(
        np.sum(frac_list), 1.0
    ), "Expect frac_list sum to 1, got {:.4f}".format(np.sum(frac_list))
    num_data = len(dataset)
    lengths = (num_data * frac_list).astype(int)
    lengths[-1] = num_data - np.sum(lengths[:-1])
    if shuffle:
        indices = np.random.RandomState(seed=random_state).permutation(num_data)
    else:
        indices = np.arange(num_data)
    return [
        Subset(dataset, indices[offset - length : offset])
        for offset, length in zip(accumulate(lengths), lengths)
    ]


def download(
    url,
    path=None,
    overwrite=True,
    sha1_hash=None,
    retries=5,
    verify_ssl=True,
    log=True,
):
    """Download a given URL.

    Code borrowed from mxnet/gluon/utils.py

    Parameters
    ----------
    url : str
        URL to download.
    path : str, optional
        Destination path to store downloaded file. By default stores to the
        current directory with the same name as in url.
    overwrite : bool, optional
        Whether to overwrite the destination file if it already exists.
        By default always overwrites the downloaded file.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
        but doesn't match.
    retries : integer, default 5
        The number of times to attempt downloading in case of failure or non-200 return codes.
    verify_ssl : bool, default True
        Verify SSL certificates.
    log : bool, default True
        Whether to print download progress.

    Returns
    -------
    str
        The file path of the downloaded file.
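
    Examples
    --------
    A minimal usage sketch; the URL and destination below are hypothetical:

    >>> local_path = download('https://data.dgl.ai/dataset/some_file.zip',
    ...                       path='/tmp/', overwrite=False)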
    """
    if path is None:
        fname = url.split("/")[-1]
        # Empty filenames are invalid
        assert fname, (
            "Can't construct file-name from this URL. "
            "Please set the `path` option manually."
        )
    else:
        path = os.path.expanduser(path)
        if os.path.isdir(path):
            fname = os.path.join(path, url.split("/")[-1])
        else:
            fname = path
    assert retries >= 0, "Number of retries should be at least 0"

    if not verify_ssl:
        warnings.warn(
            "Unverified HTTPS request is being made (verify_ssl=False). "
            "Adding certificate verification is strongly advised."
        )

    if (
        overwrite
        or not os.path.exists(fname)
        or (sha1_hash and not check_sha1(fname, sha1_hash))
    ):
        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
        if not os.path.exists(dirname):
            os.makedirs(dirname)
        while retries + 1 > 0:
            # Disable pylint too broad Exception
            # pylint: disable=W0703
            try:
                if log:
                    print("Downloading %s from %s..." % (fname, url))
                r = requests.get(url, stream=True, verify=verify_ssl)
                if r.status_code != 200:
                    raise RuntimeError("Failed downloading url %s" % url)
                # Get the total file size.
                total_size = int(r.headers.get("content-length", 0))
                with tqdm(
                    total=total_size, unit="B", unit_scale=True, desc=fname
                ) as bar:
                    with open(fname, "wb") as f:
                        for chunk in r.iter_content(chunk_size=1024):
                            if chunk:  # filter out keep-alive new chunks
                                f.write(chunk)
                                bar.update(len(chunk))
                if sha1_hash and not check_sha1(fname, sha1_hash):
                    raise UserWarning(
                        "File {} is downloaded but the content hash does not match."
                        " The repo may be outdated or download may be incomplete. "
                        'If the "repo_url" is overridden, consider switching to '
                        "the default repo.".format(fname)
                    )
                break
            except Exception as e:
                retries -= 1
                if retries <= 0:
                    raise e
                else:
                    if log:
                        print(
                            "download failed, retrying, {} attempt{} left".format(
                                retries, "s" if retries > 1 else ""
                            )
                        )

    return fname


def check_sha1(filename, sha1_hash):
    """Check whether the sha1 hash of the file content matches the expected hash.

    Code borrowed from mxnet/gluon/utils.py

    Parameters
    ----------
    filename : str
        Path to the file.
    sha1_hash : str
        Expected sha1 hash in hexadecimal digits.

    Returns
    -------
    bool
        Whether the file content matches the expected hash.
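
    Examples
    --------
    A minimal sketch, assuming ``fname`` points to an existing local file:

    >>> import hashlib
    >>> with open(fname, 'rb') as f:
    ...     expected = hashlib.sha1(f.read()).hexdigest()
    >>> check_sha1(fname, expected)
    True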
    """
    sha1 = hashlib.sha1()
    with open(filename, "rb") as f:
        while True:
            data = f.read(1048576)
            if not data:
                break
            sha1.update(data)

    return sha1.hexdigest() == sha1_hash


def extract_archive(file, target_dir, overwrite=True):
    """Extract archive file.

    Parameters
    ----------
    file : str
        Absolute path of the archive file.
    target_dir : str
        Target directory to which the archive will be uncompressed.
    overwrite : bool, default True
        Whether to overwrite the contents inside the directory.
        By default always overwrites.
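
    Examples
    --------
    A minimal usage sketch with hypothetical paths:

    >>> extract_archive('/tmp/dataset.zip', '/tmp/dataset')
    Extracting file to /tmp/dataset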
    """
    if os.path.exists(target_dir) and not overwrite:
        return
    print("Extracting file to {}".format(target_dir))
    if (
        file.endswith(".tar.gz")
        or file.endswith(".tar")
        or file.endswith(".tgz")
    ):
        import tarfile

        with tarfile.open(file, "r") as archive:

            def is_within_directory(directory, target):
                abs_directory = os.path.abspath(directory)
                abs_target = os.path.abspath(target)
                prefix = os.path.commonprefix([abs_directory, abs_target])
                return prefix == abs_directory

            def safe_extract(
                tar, path=".", members=None, *, numeric_owner=False
            ):
                for member in tar.getmembers():
                    member_path = os.path.join(path, member.name)
                    if not is_within_directory(path, member_path):
                        raise Exception("Attempted Path Traversal in Tar File")
                tar.extractall(path, members, numeric_owner=numeric_owner)

            safe_extract(archive, path=target_dir)
    elif file.endswith(".gz"):
        import gzip
        import shutil

        with gzip.open(file, "rb") as f_in:
            target_file = os.path.join(target_dir, os.path.basename(file)[:-3])
            with open(target_file, "wb") as f_out:
                shutil.copyfileobj(f_in, f_out)
    elif file.endswith(".zip"):
        import zipfile

        with zipfile.ZipFile(file, "r") as archive:
            archive.extractall(path=target_dir)
    else:
        raise Exception("Unrecognized file type: " + file)


def get_download_dir():
    """Get the absolute path to the download directory.

    Returns
    -------
    dirname : str
        Path to the download directory.
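
    Examples
    --------
    The directory defaults to ``~/.dgl`` and can be overridden with the
    ``DGL_DOWNLOAD_DIR`` environment variable (the location below is hypothetical):

    >>> import os
    >>> os.environ['DGL_DOWNLOAD_DIR'] = '/tmp/dgl_downloads'
    >>> get_download_dir()
    '/tmp/dgl_downloads'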
    """
    default_dir = os.path.join(os.path.expanduser("~"), ".dgl")
    dirname = os.environ.get("DGL_DOWNLOAD_DIR", default_dir)
    if not os.path.exists(dirname):
        os.makedirs(dirname)
    return dirname


def makedirs(path):
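    """Create the directory ``path``, doing nothing if it already exists."""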
    try:
        os.makedirs(os.path.expanduser(os.path.normpath(path)))
    except OSError as e:
        # Only ignore the error when the directory already exists;
        # re-raise anything else (e.g. permission errors).
        if e.errno != errno.EEXIST:
            raise e


def save_info(path, info):
    """Save dataset related information into disk.

    Parameters
    ----------
    path : str
        File to save information.
    info : dict
        A python dict storing information to save on disk.
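
    Examples
    --------
    A minimal round trip with :func:`load_info`; the path is hypothetical:

    >>> save_info('/tmp/dataset_info.pkl', {'num_classes': 7})
    >>> load_info('/tmp/dataset_info.pkl')
    {'num_classes': 7}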
    """
    with open(path, "wb") as pf:
        pickle.dump(info, pf)


def load_info(path):
    """Load dataset related information from disk.

    Parameters
    ----------
    path : str
        File to load information from.

    Returns
    -------
    info : dict
        A python dict storing information loaded from disk.
    """
    with open(path, "rb") as pf:
        info = pickle.load(pf)
    return info


def deprecate_property(old, new):
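    """Warn that property ``old`` is deprecated in favor of ``new``."""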
    warnings.warn(
        "Property {} will be deprecated, please use {} instead.".format(
            old, new
        )
    )


def deprecate_function(old, new):
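    """Warn that function ``old`` is deprecated in favor of ``new``."""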
    warnings.warn(
        "Function {} will be deprecated, please use {} instead.".format(
            old, new
        )
    )


def deprecate_class(old, new):
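    """Warn that class ``old`` is deprecated in favor of ``new``."""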
    warnings.warn(
        "Class {} will be deprecated, please use {} instead.".format(old, new)
    )


def idx2mask(idx, len):
    """Create mask."""
    mask = np.zeros(len)
    mask[idx] = 1
    return mask


def generate_mask_tensor(mask):
    """Generate mask tensor according to different backend
    For torch and tensorflow, it will create a bool tensor
    For mxnet, it will create a float tensor
    Parameters
    ----------
    mask: numpy ndarray
        input mask tensor
    """
    assert isinstance(mask, np.ndarray), (
        "input for generate_mask_tensor" "should be an numpy ndarray"
    )
    if F.backend_name == "mxnet":
        return F.tensor(mask, dtype=F.data_type_dict["float32"])
    else:
        return F.tensor(mask, dtype=F.data_type_dict["bool"])


class Subset(object):
    """Subset of a dataset at specified indices

    Code adapted from PyTorch.

    Parameters
    ----------
    dataset
        dataset[i] should return the ith datapoint
    indices : list
        List of datapoint indices to construct the subset
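
    Examples
    --------
    A minimal sketch with a plain list as the wrapped dataset:

    >>> sub = Subset(['a', 'b', 'c', 'd'], indices=[0, 2])
    >>> len(sub)
    2
    >>> sub[1]
    'c'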
    """

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, item):
        """Get the datapoint indexed by item

        Returns
        -------
        tuple
            datapoint
        """
        return self.dataset[self.indices[item]]

    def __len__(self):
        """Get subset size

        Returns
        -------
        int
            Number of datapoints in the subset
        """
        return len(self.indices)


def add_nodepred_split(dataset, ratio, ntype=None):
    """Split the given dataset into training, validation and test sets for
    transductive node predction task.

    It adds three node mask arrays ``'train_mask'``, ``'val_mask'`` and ``'test_mask'``,
    to each graph in the dataset. Each sample in the dataset thus must be a :class:`DGLGraph`.

    Fix the random seed of NumPy to make the result deterministic::

        numpy.random.seed(42)

    Parameters
    ----------
    dataset : DGLDataset
        The dataset to modify.
    ratio : (float, float, float)
        Split ratios for training, validation and test sets. Must sum to one.
    ntype : str, optional
        The node type to add mask for.

    Examples
    --------
    >>> dataset = dgl.data.AmazonCoBuyComputerDataset()
    >>> print('train_mask' in dataset[0].ndata)
    False
    >>> dgl.data.utils.add_nodepred_split(dataset, [0.8, 0.1, 0.1])
    >>> print('train_mask' in dataset[0].ndata)
    True
    """
    if len(ratio) != 3:
        raise ValueError(
            f"Split ratio must be a float triplet but got {ratio}."
        )
    for i in range(len(dataset)):
        g = dataset[i]
        n = g.num_nodes(ntype)
        idx = np.arange(0, n)
        np.random.shuffle(idx)
        n_train, n_val, n_test = (
            int(n * ratio[0]),
            int(n * ratio[1]),
            int(n * ratio[2]),
        )
        train_mask = generate_mask_tensor(idx2mask(idx[:n_train], n))
        val_mask = generate_mask_tensor(
            idx2mask(idx[n_train : n_train + n_val], n)
        )
        test_mask = generate_mask_tensor(idx2mask(idx[n_train + n_val :], n))
        g.nodes[ntype].data["train_mask"] = train_mask
        g.nodes[ntype].data["val_mask"] = val_mask
        g.nodes[ntype].data["test_mask"] = test_mask


def mask_nodes_by_property(property_values, part_ratios, random_seed=None):
    """Provide the split masks for a node split with distributional shift based on a given
    node property, as proposed in `Evaluating Robustness and Uncertainty of Graph Models
    Under Structural Distributional Shifts <https://arxiv.org/abs/2302.13875>`__

    It considers the in-distribution (ID) and out-of-distribution (OOD) subsets of nodes.
    The ID subset includes training, validation and testing parts, while the OOD subset
    includes validation and testing parts. It sorts the nodes in the ascending order of
    their property values, splits them into 5 non-intersecting parts, and creates 5
    associated node mask arrays:
        - 3 for the ID nodes: ``'in_train_mask'``, ``'in_valid_mask'``, ``'in_test_mask'``,
        - and 2 for the OOD nodes: ``'out_valid_mask'``, ``'out_test_mask'``.

    Parameters
    ----------
    property_values : numpy ndarray
        The node property (float) values by which the dataset will be split.
        The length of the array must be equal to the number of nodes in the graph.
    part_ratios : list
        A list of 5 ratios for training, ID validation, ID test,
        OOD validation and OOD test parts. The values in the list must sum to one.
    random_seed : int, optional
        Random seed to fix for the initial permutation of nodes. It is
        used to create a random order for the nodes that have the same
        property values or belong to the ID subset. (default: None)

    Returns
    -------
    split_masks : dict
        A python dict storing the mask names as keys and the corresponding
        node mask arrays as values.

    Examples
    --------
    >>> num_nodes = 1000
    >>> property_values = np.random.uniform(size=num_nodes)
    >>> part_ratios = [0.3, 0.1, 0.1, 0.3, 0.2]
    >>> split_masks = dgl.data.utils.mask_nodes_by_property(property_values, part_ratios)
    >>> print('in_valid_mask' in split_masks)
    True
    """

    num_nodes = len(property_values)
    part_sizes = np.round(num_nodes * np.array(part_ratios)).astype(int)
    part_sizes[-1] -= np.sum(part_sizes) - num_nodes

    generator = np.random.RandomState(random_seed)
    permutation = generator.permutation(num_nodes)

    node_indices = np.arange(num_nodes)[permutation]
    property_values = property_values[permutation]
    in_distribution_size = np.sum(part_sizes[:3])

    node_indices_ordered = node_indices[np.argsort(property_values)]
    node_indices_ordered[:in_distribution_size] = generator.permutation(
        node_indices_ordered[:in_distribution_size]
    )

    sections = np.cumsum(part_sizes)
    node_split = np.split(node_indices_ordered, sections)[:-1]
    mask_names = [
        "in_train_mask",
        "in_valid_mask",
        "in_test_mask",
        "out_valid_mask",
        "out_test_mask",
    ]
    split_masks = {}

    for mask_name, node_indices in zip(mask_names, node_split):
        split_mask = idx2mask(node_indices, num_nodes)
        split_masks[mask_name] = generate_mask_tensor(split_mask)

    return split_masks


def add_node_property_split(
    dataset, part_ratios, property_name, ascending=True, random_seed=None
):
    """Create a node split with distributional shift based on a given node property,
    as proposed in `Evaluating Robustness and Uncertainty of Graph Models Under
    Structural Distributional Shifts <https://arxiv.org/abs/2302.13875>`__

    It splits the nodes of each graph in the given dataset into 5 non-intersecting
    parts based on their structural properties. This can be used for a transductive
    node prediction task with distributional shifts.

    It considers the in-distribution (ID) and out-of-distribution (OOD) subsets of nodes.
    The ID subset includes training, validation and testing parts, while the OOD subset
    includes validation and testing parts. As a result, it creates 5 associated node mask
    arrays for each graph:
        - 3 for the ID nodes: ``'in_train_mask'``, ``'in_valid_mask'``, ``'in_test_mask'``,
        - and 2 for the OOD nodes: ``'out_valid_mask'``, ``'out_test_mask'``.

    This function implements 3 particular strategies for inducing distributional shifts
    in a graph, based on **popularity**, **locality** or **density**.

    Parameters
    ----------
    dataset : :class:`~DGLDataset` or list of :class:`~dgl.DGLGraph`
        The dataset to induce structural distributional shift.
    part_ratios : list
        A list of 5 ratio values for training, ID validation, ID test,
        OOD validation and OOD test parts. The values must sum to 1.0.
    property_name : str
        The name of the node property to be used, which must be
        ``'popularity'``, ``'locality'`` or ``'density'``.
    ascending : bool, optional
        Whether to sort nodes in the ascending order of the node property,
        so that nodes with greater values of the property are considered
        to be OOD (default: True)
    random_seed : int, optional
        Random seed to fix for the initial permutation of nodes. It is
        used to create a random order for the nodes that have the same
        property values or belong to the ID subset. (default: None)

    Examples
    --------
    >>> dataset = dgl.data.AmazonCoBuyComputerDataset()
    >>> print('in_valid_mask' in dataset[0].ndata)
    False
    >>> part_ratios = [0.3, 0.1, 0.1, 0.3, 0.2]
    >>> property_name = 'popularity'
    >>> dgl.data.utils.add_node_property_split(dataset, part_ratios, property_name)
    >>> print('in_valid_mask' in dataset[0].ndata)
    True
    """

    assert property_name in [
        "popularity",
        "locality",
        "density",
    ], "The name of property has to be 'popularity', 'locality', or 'density'"

    assert len(part_ratios) == 5, "part_ratios must contain 5 values"

    import networkx as nx

    for idx in range(len(dataset)):
        graph_dgl = dataset[idx]
        graph_nx = nx.Graph(graph_dgl.to_networkx())

        compute_property_fn = _property_name_to_compute_fn[property_name]
        property_values = compute_property_fn(graph_nx, ascending)

        node_masks = mask_nodes_by_property(
            property_values, part_ratios, random_seed
        )

        for mask_name, node_mask in node_masks.items():
            graph_dgl.ndata[mask_name] = node_mask


def _compute_popularity_property(graph_nx, ascending=True):
    direction = -1 if ascending else 1
    property_values = direction * np.array(list(A.pagerank(graph_nx).values()))
    return property_values


def _compute_locality_property(graph_nx, ascending=True):
    num_nodes = graph_nx.number_of_nodes()
    pagerank_values = np.array(list(A.pagerank(graph_nx).values()))

    personalization = dict(zip(range(num_nodes), [0.0] * num_nodes))
    personalization[np.argmax(pagerank_values)] = 1.0

    direction = -1 if ascending else 1
    property_values = direction * np.array(
        list(A.pagerank(graph_nx, personalization=personalization).values())
    )
    return property_values


def _compute_density_property(graph_nx, ascending=True):
    direction = -1 if ascending else 1
    property_values = direction * np.array(
        list(A.clustering(graph_nx).values())
    )
    return property_values


_property_name_to_compute_fn = {
    "popularity": _compute_popularity_property,
    "locality": _compute_locality_property,
    "density": _compute_density_property,
}