"""Dataset utilities."""
from __future__ import absolute_import

import errno
import hashlib
import os
import pickle
import sys
import warnings

import networkx.algorithms as A
import numpy as np
import requests

from .. import backend as F
from .graph_serialize import load_graphs, load_labels, save_graphs
from .tensor_serialize import load_tensors, save_tensors

__all__ = [
    "loadtxt",
    "download",
    "check_sha1",
    "extract_archive",
    "get_download_dir",
    "Subset",
    "split_dataset",
    "save_graphs",
    "load_graphs",
    "load_labels",
    "save_tensors",
    "load_tensors",
    "add_nodepred_split",
    "add_node_property_split",
    "mask_nodes_by_property",
]


def loadtxt(path, delimiter, dtype=None):
    """Load a delimited text file into a NumPy array, using pandas when available."""
    try:
        import pandas as pd

        df = pd.read_csv(path, delimiter=delimiter, header=None)
        return df.values
    except ImportError:
        warnings.warn(
            "Pandas is not installed, now using numpy.loadtxt to load data, "
            "which could be extremely slow. Accelerate by installing pandas"
        )
        return np.loadtxt(path, delimiter=delimiter)


def _get_dgl_url(file_url):
    """Get DGL online url for download."""
    dgl_repo_url = "https://data.dgl.ai/"
    repo_url = os.environ.get("DGL_REPO", dgl_repo_url)
    if repo_url[-1] != "/":
        repo_url = repo_url + "/"
    return repo_url + file_url


def split_dataset(dataset, frac_list=None, shuffle=False, random_state=None):
    """Split dataset into training, validation and test set.

    Parameters
    ----------
    dataset
        We assume ``len(dataset)`` gives the number of datapoints and ``dataset[i]``
        gives the ith datapoint.
    frac_list : list or None, optional
        A list of length 3 containing the fraction to use for training,
        validation and test. If None, we will use [0.8, 0.1, 0.1].
    shuffle : bool, optional
        By default we perform a consecutive split of the dataset. If True,
        we will first randomly shuffle the dataset.
    random_state : None, int or array_like, optional
        Random seed used to initialize the pseudo-random number generator.
        Can be any integer between 0 and 2**32 - 1 inclusive, an array
        (or other sequence) of such integers, or None (the default).
        If seed is None, then RandomState will try to read data from /dev/urandom
        (or the Windows analogue) if available or seed from the clock otherwise.

    Returns
    -------
    list of length 3
        Subsets for training, validation and test.
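
    Examples
    --------
    A minimal sketch; any object supporting ``len()`` and integer indexing
    (a plain list here) can stand in for ``dataset``.

    >>> data = list(range(10))
    >>> train_set, val_set, test_set = split_dataset(data, frac_list=[0.8, 0.1, 0.1])
    >>> len(train_set), len(val_set), len(test_set)
    (8, 1, 1)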
    """
    from itertools import accumulate

    if frac_list is None:
        frac_list = [0.8, 0.1, 0.1]
    frac_list = np.asarray(frac_list)
    assert np.allclose(
        np.sum(frac_list), 1.0
    ), "Expect frac_list sum to 1, got {:.4f}".format(np.sum(frac_list))
    num_data = len(dataset)
    lengths = (num_data * frac_list).astype(int)
    lengths[-1] = num_data - np.sum(lengths[:-1])
    if shuffle:
        indices = np.random.RandomState(seed=random_state).permutation(num_data)
    else:
        indices = np.arange(num_data)
    return [
        Subset(dataset, indices[offset - length : offset])
        for offset, length in zip(accumulate(lengths), lengths)
    ]


def download(
    url,
    path=None,
    overwrite=True,
    sha1_hash=None,
    retries=5,
    verify_ssl=True,
    log=True,
):
    """Download a given URL.

    Code borrowed from mxnet/gluon/utils.py

    Parameters
    ----------
    url : str
        URL to download.
    path : str, optional
        Destination path to store downloaded file. By default stores to the
        current directory with the same name as in url.
    overwrite : bool, optional
        Whether to overwrite the destination file if it already exists.
        By default always overwrites the downloaded file.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
        but doesn't match.
    retries : integer, default 5
        The number of times to attempt downloading in case of failure or non-200 return codes.
    verify_ssl : bool, default True
        Verify SSL certificates.
    log : bool, default True
        Whether to print the progress for download.

    Returns
    -------
    str
        The file path of the downloaded file.
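
    Examples
    --------
    A sketch only; the URL is a placeholder and the call performs a real
    network request, so it is not executed as a doctest.

    >>> fpath = download('https://example.com/files/data.zip',
    ...                  path='./data.zip')  # doctest: +SKIP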
    """
    if path is None:
        fname = url.split("/")[-1]
        # Empty filenames are invalid
        assert fname, (
            "Can't construct file-name from this URL. "
            "Please set the `path` option manually."
        )
    else:
        path = os.path.expanduser(path)
        if os.path.isdir(path):
            fname = os.path.join(path, url.split("/")[-1])
        else:
            fname = path
    assert retries >= 0, "Number of retries should be at least 0"

    if not verify_ssl:
        warnings.warn(
            "Unverified HTTPS request is being made (verify_ssl=False). "
            "Adding certificate verification is strongly advised."
        )

    if (
        overwrite
        or not os.path.exists(fname)
        or (sha1_hash and not check_sha1(fname, sha1_hash))
    ):
        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
        if not os.path.exists(dirname):
            os.makedirs(dirname)
        while retries + 1 > 0:
            # Disable pylint's too-broad Exception warning
            # pylint: disable=W0703
            try:
                if log:
                    print("Downloading %s from %s..." % (fname, url))
                r = requests.get(url, stream=True, verify=verify_ssl)
                if r.status_code != 200:
                    raise RuntimeError("Failed downloading url %s" % url)
                with open(fname, "wb") as f:
                    for chunk in r.iter_content(chunk_size=1024):
                        if chunk:  # filter out keep-alive new chunks
                            f.write(chunk)
                if sha1_hash and not check_sha1(fname, sha1_hash):
                    raise UserWarning(
                        "File {} is downloaded but the content hash does not match."
                        " The repo may be outdated or download may be incomplete. "
                        'If the "repo_url" is overridden, consider switching to '
                        "the default repo.".format(fname)
                    )
                break
            except Exception as e:
                retries -= 1
                if retries <= 0:
                    raise e
                else:
                    if log:
                        print(
                            "download failed, retrying, {} attempt{} left".format(
                                retries, "s" if retries > 1 else ""
                            )
                        )

    return fname


def check_sha1(filename, sha1_hash):
    """Check whether the sha1 hash of the file content matches the expected hash.

    Code borrowed from mxnet/gluon/utils.py

    Parameters
    ----------
    filename : str
        Path to the file.
    sha1_hash : str
        Expected sha1 hash in hexadecimal digits.

    Returns
    -------
    bool
        Whether the file content matches the expected hash.
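
    Examples
    --------
    A sketch; the file path and digest below are placeholders.

    >>> check_sha1('data.zip', 'da39a3ee5e6b4b0d3255bfef95601890afd80709')  # doctest: +SKIP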
    """
    sha1 = hashlib.sha1()
    with open(filename, "rb") as f:
        while True:
            data = f.read(1048576)
            if not data:
                break
            sha1.update(data)

    return sha1.hexdigest() == sha1_hash


def extract_archive(file, target_dir, overwrite=False):
    """Extract archive file.

    Parameters
    ----------
    file : str
        Absolute path of the archive file.
    target_dir : str
        Target directory of the archive to be uncompressed.
    overwrite : bool, default False
        Whether to overwrite the contents inside the directory.
        By default, extraction is skipped if the target directory already exists.
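
    Examples
    --------
    A sketch assuming a local archive exists at the placeholder path below.

    >>> extract_archive('./data.zip', './extracted_data')  # doctest: +SKIP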
    """
    if os.path.exists(target_dir) and not overwrite:
        return
    print("Extracting file to {}".format(target_dir))
    # Ensure the target directory exists; the gzip branch below writes into
    # it directly instead of letting the archive library create it.
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    if (
        file.endswith(".tar.gz")
        or file.endswith(".tar")
        or file.endswith(".tgz")
    ):
        import tarfile

        with tarfile.open(file, "r") as archive:

            def is_within_directory(directory, target):
                abs_directory = os.path.abspath(directory)
                abs_target = os.path.abspath(target)
                prefix = os.path.commonprefix([abs_directory, abs_target])
                return prefix == abs_directory

            def safe_extract(
                tar, path=".", members=None, *, numeric_owner=False
            ):
                for member in tar.getmembers():
                    member_path = os.path.join(path, member.name)
                    if not is_within_directory(path, member_path):
                        raise Exception("Attempted Path Traversal in Tar File")
                tar.extractall(path, members, numeric_owner=numeric_owner)

            safe_extract(archive, path=target_dir)
    elif file.endswith(".gz"):
        import gzip
        import shutil

        with gzip.open(file, "rb") as f_in:
            target_file = os.path.join(target_dir, os.path.basename(file)[:-3])
            with open(target_file, "wb") as f_out:
                shutil.copyfileobj(f_in, f_out)
    elif file.endswith(".zip"):
        import zipfile

        with zipfile.ZipFile(file, "r") as archive:
            archive.extractall(path=target_dir)
    else:
        raise Exception("Unrecognized file type: " + file)


def get_download_dir():
    """Get the absolute path to the download directory.

    Returns
    -------
    dirname : str
        Path to the download directory
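
    Examples
    --------
    The location can be redirected through the ``DGL_DOWNLOAD_DIR``
    environment variable; the path below is only an illustration.

    >>> import os
    >>> os.environ['DGL_DOWNLOAD_DIR'] = '/tmp/my_dgl_data'  # doctest: +SKIP
    >>> get_download_dir()  # doctest: +SKIP
    '/tmp/my_dgl_data'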
    """
    default_dir = os.path.join(os.path.expanduser("~"), ".dgl")
    dirname = os.environ.get("DGL_DOWNLOAD_DIR", default_dir)
    if not os.path.exists(dirname):
        os.makedirs(dirname)
    return dirname


def makedirs(path):
    """Create a directory, ignoring the error if it already exists."""
    try:
        os.makedirs(os.path.expanduser(os.path.normpath(path)))
    except OSError as e:
        if e.errno != errno.EEXIST or not os.path.isdir(path):
            raise e


def save_info(path, info):
    """Save dataset related information into disk.

    Parameters
    ----------
    path : str
        File to save information.
    info : dict
        A python dict storing information to save on disk.
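
    Examples
    --------
    A round-trip sketch with :func:`load_info`; ``'info.pkl'`` is a
    placeholder path.

    >>> save_info('info.pkl', {'num_classes': 7})  # doctest: +SKIP
    >>> load_info('info.pkl')  # doctest: +SKIP
    {'num_classes': 7}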
    """
    with open(path, "wb") as pf:
        pickle.dump(info, pf)


def load_info(path):
    """Load dataset related information from disk.

    Parameters
    ----------
    path : str
        File to load information from.

    Returns
    -------
    info : dict
        A python dict storing information loaded from disk.
    """
    with open(path, "rb") as pf:
        info = pickle.load(pf)
    return info


def deprecate_property(old, new):
    """Warn that property ``old`` is deprecated in favor of ``new``."""
    warnings.warn(
        "Property {} will be deprecated, please use {} instead.".format(
            old, new
        )
    )


def deprecate_function(old, new):
    """Warn that function ``old`` is deprecated in favor of ``new``."""
    warnings.warn(
        "Function {} will be deprecated, please use {} instead.".format(
            old, new
        )
    )


def deprecate_class(old, new):
    """Warn that class ``old`` is deprecated in favor of ``new``."""
    warnings.warn(
        "Class {} will be deprecated, please use {} instead.".format(old, new)
    )


def idx2mask(idx, len):
    """Create a binary mask of length ``len`` with ones at the positions in ``idx``."""
    mask = np.zeros(len)
    mask[idx] = 1
    return mask


def generate_mask_tensor(mask):
    """Generate mask tensor according to different backend
    For torch and tensorflow, it will create a bool tensor
    For mxnet, it will create a float tensor
    Parameters
    ----------
    mask: numpy ndarray
        input mask tensor
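
    Examples
    --------
    Typically paired with :func:`idx2mask`; the dtype of the result depends
    on the active backend, as described above.

    >>> mask = generate_mask_tensor(idx2mask([0, 2], 5))  # doctest: +SKIP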
    """
    assert isinstance(mask, np.ndarray), (
        "Input for generate_mask_tensor should be a numpy ndarray."
    )
    if F.backend_name == "mxnet":
        return F.tensor(mask, dtype=F.data_type_dict["float32"])
    else:
        return F.tensor(mask, dtype=F.data_type_dict["bool"])


class Subset(object):
    """Subset of a dataset at specified indices

    Code adapted from PyTorch.

    Parameters
    ----------
    dataset
        dataset[i] should return the ith datapoint
    indices : list
        List of datapoint indices to construct the subset
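
    Examples
    --------
    Any indexable object works as ``dataset``; a plain list is used here for
    illustration.

    >>> subset = Subset(['a', 'b', 'c', 'd'], indices=[0, 2])
    >>> len(subset)
    2
    >>> subset[1]
    'c'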
    """

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, item):
        """Get the datapoint indexed by item

        Returns
        -------
        tuple
            datapoint
        """
        return self.dataset[self.indices[item]]

    def __len__(self):
        """Get subset size

        Returns
        -------
        int
            Number of datapoints in the subset
        """
        return len(self.indices)


def add_nodepred_split(dataset, ratio, ntype=None):
    """Split the given dataset into training, validation and test sets for
    transductive node predction task.

    It adds three node mask arrays ``'train_mask'``, ``'val_mask'`` and ``'test_mask'``,
    to each graph in the dataset. Each sample in the dataset thus must be a :class:`DGLGraph`.

    Fix the random seed of NumPy to make the result deterministic::

        numpy.random.seed(42)

    Parameters
    ----------
    dataset : DGLDataset
        The dataset to modify.
    ratio : (float, float, float)
        Split ratios for training, validation and test sets. Must sum to one.
    ntype : str, optional
        The node type to add mask for.

    Examples
    --------
    >>> dataset = dgl.data.AmazonCoBuyComputerDataset()
    >>> print('train_mask' in dataset[0].ndata)
    False
    >>> dgl.data.utils.add_nodepred_split(dataset, [0.8, 0.1, 0.1])
    >>> print('train_mask' in dataset[0].ndata)
    True
    """
    if len(ratio) != 3:
        raise ValueError(
            f"Split ratio must be a float triplet but got {ratio}."
        )
    for i in range(len(dataset)):
        g = dataset[i]
        n = g.num_nodes(ntype)
        idx = np.arange(0, n)
        np.random.shuffle(idx)
        n_train, n_val, n_test = (
            int(n * ratio[0]),
            int(n * ratio[1]),
            int(n * ratio[2]),
        )
        train_mask = generate_mask_tensor(idx2mask(idx[:n_train], n))
        val_mask = generate_mask_tensor(
            idx2mask(idx[n_train : n_train + n_val], n)
        )
        test_mask = generate_mask_tensor(idx2mask(idx[n_train + n_val :], n))
        g.nodes[ntype].data["train_mask"] = train_mask
        g.nodes[ntype].data["val_mask"] = val_mask
        g.nodes[ntype].data["test_mask"] = test_mask


def mask_nodes_by_property(property_values, part_ratios, random_seed=None):
    """Provide the split masks for a node split with distributional shift based on a given
    node property, as proposed in `Evaluating Robustness and Uncertainty of Graph Models
    Under Structural Distributional Shifts <https://arxiv.org/abs/2302.13875>`__

    It considers the in-distribution (ID) and out-of-distribution (OOD) subsets of nodes.
    The ID subset includes training, validation and testing parts, while the OOD subset
    includes validation and testing parts. It sorts the nodes in the ascending order of
    their property values, splits them into 5 non-intersecting parts, and creates 5
    associated node mask arrays:
        - 3 for the ID nodes: ``'in_train_mask'``, ``'in_valid_mask'``, ``'in_test_mask'``,
        - and 2 for the OOD nodes: ``'out_valid_mask'``, ``'out_test_mask'``.

    Parameters
    ----------
    property_values : numpy ndarray
        The node property (float) values by which the dataset will be split.
        The length of the array must be equal to the number of nodes in graph.
    part_ratios : list
        A list of 5 ratios for training, ID validation, ID test,
        OOD validation, OOD testing parts. The values in the list must sum to one.
    random_seed : int, optional
        Random seed to fix for the initial permutation of nodes. It is
        used to create a random order for the nodes that have the same
        property values or belong to the ID subset. (default: None)

    Returns
    -------
    split_masks : dict
        A python dict storing the mask names as keys and the corresponding
        node mask arrays as values.

    Examples
    --------
    >>> num_nodes = 1000
    >>> property_values = np.random.uniform(size=num_nodes)
    >>> part_ratios = [0.3, 0.1, 0.1, 0.3, 0.2]
    >>> split_masks = dgl.data.utils.mask_nodes_by_property(property_values, part_ratios)
    >>> print('in_valid_mask' in split_masks)
    True
    """

    num_nodes = len(property_values)
    part_sizes = np.round(num_nodes * np.array(part_ratios)).astype(int)
    part_sizes[-1] -= np.sum(part_sizes) - num_nodes

    generator = np.random.RandomState(random_seed)
    permutation = generator.permutation(num_nodes)

    node_indices = np.arange(num_nodes)[permutation]
    property_values = property_values[permutation]
    in_distribution_size = np.sum(part_sizes[:3])

    node_indices_ordered = node_indices[np.argsort(property_values)]
    node_indices_ordered[:in_distribution_size] = generator.permutation(
        node_indices_ordered[:in_distribution_size]
    )

    sections = np.cumsum(part_sizes)
    node_split = np.split(node_indices_ordered, sections)[:-1]
    mask_names = [
        "in_train_mask",
        "in_valid_mask",
        "in_test_mask",
        "out_valid_mask",
        "out_test_mask",
    ]
    split_masks = {}

    for mask_name, node_indices in zip(mask_names, node_split):
        split_mask = idx2mask(node_indices, num_nodes)
        split_masks[mask_name] = generate_mask_tensor(split_mask)

    return split_masks


def add_node_property_split(
    dataset, part_ratios, property_name, ascending=True, random_seed=None
):
    """Create a node split with distributional shift based on a given node property,
    as proposed in `Evaluating Robustness and Uncertainty of Graph Models Under
    Structural Distributional Shifts <https://arxiv.org/abs/2302.13875>`__

    It splits the nodes of each graph in the given dataset into 5 non-intersecting
    parts based on their structural properties. This can be used for transductive node
    prediction task with distributional shifts.

    It considers the in-distribution (ID) and out-of-distribution (OOD) subsets of nodes.
    The ID subset includes training, validation and testing parts, while the OOD subset
    includes validation and testing parts. As a result, it creates 5 associated node mask
    arrays for each graph:
        - 3 for the ID nodes: ``'in_train_mask'``, ``'in_valid_mask'``, ``'in_test_mask'``,
        - and 2 for the OOD nodes: ``'out_valid_mask'``, ``'out_test_mask'``.

    This function implements 3 particular strategies for inducing distributional shifts
    in a graph, based on **popularity**, **locality** or **density**.

    Parameters
    ----------
    dataset : :class:`~DGLDataset` or list of :class:`~dgl.DGLGraph`
        The dataset to induce structural distributional shift.
    part_ratios : list
        A list of 5 ratio values for training, ID validation, ID test,
        OOD validation and OOD test parts. The values must sum to 1.0.
    property_name : str
        The name of the node property to be used, which must be
        ``'popularity'``, ``'locality'`` or ``'density'``.
    ascending : bool, optional
        Whether to sort nodes in the ascending order of the node property,
        so that nodes with greater values of the property are considered
        to be OOD (default: True)
    random_seed : int, optional
        Random seed to fix for the initial permutation of nodes. It is
        used to create a random order for the nodes that have the same
        property values or belong to the ID subset. (default: None)

    Examples
    --------
    >>> dataset = dgl.data.AmazonCoBuyComputerDataset()
    >>> print('in_valid_mask' in dataset[0].ndata)
    False
    >>> part_ratios = [0.3, 0.1, 0.1, 0.3, 0.2]
    >>> property_name = 'popularity'
    >>> dgl.data.utils.add_node_property_split(dataset, part_ratios, property_name)
    >>> print('in_valid_mask' in dataset[0].ndata)
    True
    """

    assert property_name in [
        "popularity",
        "locality",
        "density",
    ], "The name of property has to be 'popularity', 'locality', or 'density'"

    assert len(part_ratios) == 5, "part_ratios must contain 5 values"

    import networkx as nx

    for idx in range(len(dataset)):
        graph_dgl = dataset[idx]
        graph_nx = nx.Graph(graph_dgl.to_networkx())

        compute_property_fn = _property_name_to_compute_fn[property_name]
        property_values = compute_property_fn(graph_nx, ascending)

        node_masks = mask_nodes_by_property(
            property_values, part_ratios, random_seed
        )

        for mask_name, node_mask in node_masks.items():
            graph_dgl.ndata[mask_name] = node_mask


def _compute_popularity_property(graph_nx, ascending=True):
    direction = -1 if ascending else 1
    property_values = direction * np.array(list(A.pagerank(graph_nx).values()))
    return property_values


def _compute_locality_property(graph_nx, ascending=True):
    num_nodes = graph_nx.number_of_nodes()
    pagerank_values = np.array(list(A.pagerank(graph_nx).values()))

    personalization = dict(zip(range(num_nodes), [0.0] * num_nodes))
    personalization[np.argmax(pagerank_values)] = 1.0

    direction = -1 if ascending else 1
    property_values = direction * np.array(
        list(A.pagerank(graph_nx, personalization=personalization).values())
    )
    return property_values


def _compute_density_property(graph_nx, ascending=True):
    direction = -1 if ascending else 1
    property_values = direction * np.array(
        list(A.clustering(graph_nx).values())
    )
    return property_values


_property_name_to_compute_fn = {
    "popularity": _compute_popularity_property,
    "locality": _compute_locality_property,
    "density": _compute_density_property,
}