citation_graph.py 28.2 KB
Newer Older
Minjie Wang's avatar
Minjie Wang committed
1
2
3
4
5
"""Cora, citeseer, pubmed dataset.

(lingfan): following dataset loading and preprocessing code from tkipf/gcn
https://github.com/tkipf/gcn/blob/master/gcn/utils.py
"""
6
7
from __future__ import absolute_import

8
import os, sys
Minjie Wang's avatar
Minjie Wang committed
9
import pickle as pkl
10

Minjie Wang's avatar
Minjie Wang committed
11
import networkx as nx
12
13

import numpy as np
Minjie Wang's avatar
Minjie Wang committed
14
15
import scipy.sparse as sp

16
17
from .. import backend as F, batch, convert
from ..convert import from_networkx, graph as dgl_graph, to_networkx
18
from ..transforms import reorder_graph
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from .dgl_dataset import DGLBuiltinDataset

from .utils import (
    _get_dgl_url,
    deprecate_function,
    deprecate_property,
    generate_mask_tensor,
    load_graphs,
    load_info,
    makedirs,
    save_graphs,
    save_info,
)

backend = os.environ.get("DGLBACKEND", "pytorch")
Minjie Wang's avatar
Minjie Wang committed
34
35


HQ's avatar
HQ committed
36
37
def _pickle_load(pkl_file):
    """Unpickle a single object from an already-open binary file handle.

    Under Python 3 the payload is decoded with ``encoding="latin1"`` so
    that pickles produced by Python 2 can be read back.
    """
    if sys.version_info <= (3, 0):
        return pkl.load(pkl_file)
    return pkl.load(pkl_file, encoding="latin1")

42

43
class CitationGraphDataset(DGLBuiltinDataset):
    r"""The citation graph dataset, including cora, citeseer and pubmed.

    Nodes mean papers and edges mean citation relationships.

    Parameters
    -----------
    name: str
      name can be 'cora', 'citeseer' or 'pubmed' (case-insensitive).
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose : bool
        Whether to print out progress information. Default: True.
    reverse_edge : bool
        Whether to add reverse edges in graph. Default: True.
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access.
    reorder : bool
        Whether to reorder the graph using :func:`~dgl.reorder_graph`. Default: False.

    Notes
    -----
    The processed graph is cached separately for non-default ``reverse_edge``
    / ``reorder`` settings, so a cache built with different flags is not
    reused.
    """
    _urls = {
        "cora_v2": "dataset/cora_v2.zip",
        "citeseer": "dataset/citeseer.zip",
        "pubmed": "dataset/pubmed.zip",
    }

    def __init__(
        self,
        name,
        raw_dir=None,
        force_reload=False,
        verbose=True,
        reverse_edge=True,
        transform=None,
        reorder=False,
    ):
        # Normalise first: the ``_urls`` keys and the raw file names built in
        # process() all use the lower-case name, so "Cora" must become "cora".
        name = name.lower()
        assert name in ["cora", "citeseer", "pubmed"], (
            "name must be 'cora', 'citeseer' or 'pubmed'"
        )

        # Previously we use the pre-processing in pygcn (https://github.com/tkipf/pygcn)
        # for Cora, which is slightly different from the one used in the GCN paper
        if name == "cora":
            name = "cora_v2"

        url = _get_dgl_url(self._urls[name])
        # Set before the base constructor runs: save_name/process() read
        # these, and the base constructor presumably triggers loading.
        self._reverse_edge = reverse_edge
        self._reorder = reorder
        super(CitationGraphDataset, self).__init__(
            name,
            url=url,
            raw_dir=raw_dir,
            force_reload=force_reload,
            verbose=verbose,
            transform=transform,
        )

    def process(self):
        """Loads input data from data directory and reorder graph for better locality

        ind.name.x => the feature vectors of the training instances as scipy.sparse.csr.csr_matrix object;
        ind.name.tx => the feature vectors of the test instances as scipy.sparse.csr.csr_matrix object;
        ind.name.allx => the feature vectors of both labeled and unlabeled training instances
            (a superset of ind.name.x) as scipy.sparse.csr.csr_matrix object;
        ind.name.y => the one-hot labels of the labeled training instances as numpy.ndarray object;
        ind.name.ty => the one-hot labels of the test instances as numpy.ndarray object;
        ind.name.ally => the labels for instances in ind.name.allx as numpy.ndarray object;
        ind.name.graph => a dict in the format {index: [index_of_neighbor_nodes]} as collections.defaultdict
            object;
        ind.name.test.index => the indices of test instances in graph, for the inductive setting as list object.
        """
        root = self.raw_path
        objnames = ["x", "y", "tx", "ty", "allx", "ally", "graph"]
        objects = []
        for objname in objnames:
            with open(
                "{}/ind.{}.{}".format(root, self.name, objname), "rb"
            ) as f:
                # NOTE: pickle files -- only load them from a trusted source.
                objects.append(_pickle_load(f))

        x, y, tx, ty, allx, ally, graph = tuple(objects)
        test_idx_reorder = _parse_index_file(
            "{}/ind.{}.test.index".format(root, self.name)
        )
        test_idx_range = np.sort(test_idx_reorder)

        if self.name == "citeseer":
            # Fix citeseer dataset (there are some isolated nodes in the graph)
            # Find isolated nodes, add them as zero-vecs into the right position
            test_idx_range_full = range(
                min(test_idx_reorder), max(test_idx_reorder) + 1
            )
            tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
            tx_extended[test_idx_range - min(test_idx_range), :] = tx
            tx = tx_extended
            ty_extended = np.zeros((len(test_idx_range_full), y.shape[1]))
            ty_extended[test_idx_range - min(test_idx_range), :] = ty
            ty = ty_extended

        # Move the test rows to the node positions listed in the test index file.
        features = sp.vstack((allx, tx)).tolil()
        features[test_idx_reorder, :] = features[test_idx_range, :]

        if self.reverse_edge:
            graph = nx.DiGraph(nx.from_dict_of_lists(graph))
            g = from_networkx(graph)
        else:
            graph = nx.Graph(nx.from_dict_of_lists(graph))
            edges = list(graph.edges())
            u, v = map(list, zip(*edges))
            # Pass num_nodes explicitly (as from_networkx does implicitly):
            # otherwise the node count is inferred from the largest edge
            # endpoint and trailing isolated nodes would be dropped, making
            # the ndata assignments below fail.
            g = dgl_graph((u, v), num_nodes=graph.number_of_nodes())

        onehot_labels = np.vstack((ally, ty))
        onehot_labels[test_idx_reorder, :] = onehot_labels[test_idx_range, :]
        labels = np.argmax(onehot_labels, 1)

        idx_test = test_idx_range.tolist()
        idx_train = range(len(y))
        idx_val = range(len(y), len(y) + 500)

        num_nodes = labels.shape[0]
        g.ndata["train_mask"] = generate_mask_tensor(
            _sample_mask(idx_train, num_nodes)
        )
        g.ndata["val_mask"] = generate_mask_tensor(
            _sample_mask(idx_val, num_nodes)
        )
        g.ndata["test_mask"] = generate_mask_tensor(
            _sample_mask(idx_test, num_nodes)
        )
        g.ndata["label"] = F.tensor(labels)
        g.ndata["feat"] = F.tensor(
            _preprocess_features(features), dtype=F.data_type_dict["float32"]
        )
        self._num_classes = onehot_labels.shape[1]
        self._labels = labels
        if self._reorder:
            self._g = reorder_graph(
                g,
                node_permute_algo="rcmk",
                edge_permute_algo="dst",
                store_ids=False,
            )
        else:
            self._g = g

        if self.verbose:
            print("Finished data loading and preprocessing.")
            self._print_stats()

    def _print_stats(self):
        """Print node/edge/feature/class counts and split sizes of ``self._g``."""
        print("  NumNodes: {}".format(self._g.num_nodes()))
        print("  NumEdges: {}".format(self._g.num_edges()))
        print("  NumFeats: {}".format(self._g.ndata["feat"].shape[1]))
        print("  NumClasses: {}".format(self.num_classes))
        for title, key in (
            ("NumTrainingSamples", "train_mask"),
            ("NumValidationSamples", "val_mask"),
            ("NumTestSamples", "test_mask"),
        ):
            print(
                "  {}: {}".format(
                    title, F.nonzero_1d(self._g.ndata[key]).shape[0]
                )
            )

    def _cache_paths(self):
        """Return the ``(graph_path, info_path)`` of the processed cache."""
        graph_path = os.path.join(self.save_path, self.save_name + ".bin")
        info_path = os.path.join(self.save_path, self.save_name + ".pkl")
        return graph_path, info_path

    def has_cache(self):
        """Whether both cached files (graph and info) exist."""
        graph_path, info_path = self._cache_paths()
        return os.path.exists(graph_path) and os.path.exists(info_path)

    def save(self):
        """save the graph list and the labels"""
        graph_path, info_path = self._cache_paths()
        save_graphs(str(graph_path), self._g)
        save_info(str(info_path), {"num_classes": self.num_classes})

    def load(self):
        """Restore the processed graph and the class count from the cache."""
        graph_path, info_path = self._cache_paths()
        graphs, _ = load_graphs(str(graph_path))
        info = load_info(str(info_path))
        self._g = graphs[0]
        self._num_classes = info["num_classes"]
        # hack for mxnet compatability: rebuild the masks through
        # generate_mask_tensor, as process() does.
        for key in ("train_mask", "val_mask", "test_mask"):
            self._g.ndata[key] = generate_mask_tensor(
                F.asnumpy(self._g.ndata[key])
            )

        if self.verbose:
            self._print_stats()

    def __getitem__(self, idx):
        assert idx == 0, "This dataset has only one graph"
        if self._transform is None:
            return self._g
        else:
            return self._transform(self._g)

    def __len__(self):
        return 1

    @property
    def save_name(self):
        """Base name of the cache files.

        The default configuration keeps the historical ``<name>_dgl_graph``
        name (existing caches stay valid); non-default ``reverse_edge`` /
        ``reorder`` settings build a different graph and get their own name.
        """
        suffix = ""
        if not self._reverse_edge:
            suffix += "_noreverse"
        if self._reorder:
            suffix += "_reorder"
        return self.name + "_dgl_graph" + suffix

    @property
    def num_labels(self):
        deprecate_property("dataset.num_labels", "dataset.num_classes")
        return self.num_classes

    @property
    def num_classes(self):
        return self._num_classes

    # Citation graph is used in many examples.
    # We preserve these properties for compatability.

    @property
    def reverse_edge(self):
        return self._reverse_edge
309

310

Minjie Wang's avatar
Minjie Wang committed
311
312
def _preprocess_features(features):
    """Row-normalize a feature matrix and return it as a dense ndarray.

    Every row is divided by its sum so it sums to 1. All-zero rows (e.g. the
    zero-vectors inserted for isolated citeseer nodes) stay all-zero instead
    of producing ``inf``/``nan`` or a divide-by-zero warning.

    Parameters
    ----------
    features : scipy.sparse matrix or numpy.ndarray
        2-D feature matrix (integer or floating point).

    Returns
    -------
    numpy.ndarray
        Dense, row-normalized feature matrix.
    """
    rowsum = np.asarray(features.sum(1))
    # Float exponent: integer inputs would otherwise raise ("Integers to
    # negative integer powers are not allowed"). Zero rows give 1/0 = inf,
    # which is expected here and zeroed out right after.
    with np.errstate(divide="ignore"):
        r_inv = np.power(rowsum, -1.0).flatten()
    r_inv[np.isinf(r_inv)] = 0.0
    r_mat_inv = sp.diags(r_inv)
    features = r_mat_inv.dot(features)
    # A sparse input yields a sparse product, a dense input a dense one.
    if sp.issparse(features):
        features = features.toarray()
    return np.asarray(features)
Minjie Wang's avatar
Minjie Wang committed
319

320

Minjie Wang's avatar
Minjie Wang committed
321
322
323
324
325
326
327
def _parse_index_file(filename):
    """Parse an index file holding one integer per line.

    The file is closed deterministically, and blank lines (e.g. a trailing
    empty line) are skipped instead of raising ``ValueError``.

    Returns
    -------
    list of int
        The indices in file order.
    """
    with open(filename) as f:
        return [int(line.strip()) for line in f if line.strip()]

328

Minjie Wang's avatar
Minjie Wang committed
329
330
331
332
333
334
def _sample_mask(idx, l):
    """Return a float vector of length ``l`` with 1 at positions ``idx``, 0 elsewhere."""
    selected = np.zeros(l)
    selected[idx] = 1.0
    return selected

335

336
class CoraGraphDataset(CitationGraphDataset):
    r"""Cora citation network dataset.

    Nodes mean papers and edges mean citation
    relationships. Each node has a predefined
    feature with 1433 dimensions. The dataset is
    designed for the node classification task.
    The task is to predict the category of
    certain paper.

    Statistics:

    - Nodes: 2708
    - Edges: 10556
    - Number of Classes: 7
    - Label split:

        - Train: 140
        - Valid: 500
        - Test: 1000

    Parameters
    ----------
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose : bool
        Whether to print out progress information. Default: True.
    reverse_edge : bool
        Whether to add reverse edges in graph. Default: True.
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access.
    reorder : bool
        Whether to reorder the graph using :func:`~dgl.reorder_graph`. Default: False.

    Attributes
    ----------
    num_classes: int
        Number of label classes

    Notes
    -----
    The node feature is row-normalized.

    Examples
    --------
    >>> dataset = CoraGraphDataset()
    >>> g = dataset[0]
    >>> num_class = dataset.num_classes
    >>>
    >>> # get node feature
    >>> feat = g.ndata['feat']
    >>>
    >>> # get data split
    >>> train_mask = g.ndata['train_mask']
    >>> val_mask = g.ndata['val_mask']
    >>> test_mask = g.ndata['test_mask']
    >>>
    >>> # get labels
    >>> label = g.ndata['label']

    """

    def __init__(
        self,
        raw_dir=None,
        force_reload=False,
        verbose=True,
        reverse_edge=True,
        transform=None,
        reorder=False,
    ):
        super(CoraGraphDataset, self).__init__(
            name="cora",
            raw_dir=raw_dir,
            force_reload=force_reload,
            verbose=verbose,
            reverse_edge=reverse_edge,
            transform=transform,
            reorder=reorder,
        )

    def __getitem__(self, idx):
        r"""Gets the graph object

        Parameters
        -----------
        idx: int
            Item index, CoraGraphDataset has only one graph object

        Returns
        -------
        :class:`dgl.DGLGraph`

            graph structure, node features and labels.

            - ``ndata['train_mask']``: mask for training node set
            - ``ndata['val_mask']``: mask for validation node set
            - ``ndata['test_mask']``: mask for test node set
            - ``ndata['feat']``: node feature
            - ``ndata['label']``: ground truth labels
        """
        graph = super(CoraGraphDataset, self).__getitem__(idx)
        return graph

    def __len__(self):
        r"""The number of graphs in the dataset."""
        count = super(CoraGraphDataset, self).__len__()
        return count

450

451
class CiteseerGraphDataset(CitationGraphDataset):
    r"""Citeseer citation network dataset.

    Nodes mean scientific publications and edges
    mean citation relationships. Each node has a
    predefined feature with 3703 dimensions. The
    dataset is designed for the node classification
    task. The task is to predict the category of
    certain publication.

    Statistics:

    - Nodes: 3327
    - Edges: 9228
    - Number of Classes: 6
    - Label Split:

        - Train: 120
        - Valid: 500
        - Test: 1000

    Parameters
    -----------
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose : bool
        Whether to print out progress information. Default: True.
    reverse_edge : bool
        Whether to add reverse edges in graph. Default: True.
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access.
    reorder : bool
        Whether to reorder the graph using :func:`~dgl.reorder_graph`. Default: False.

    Attributes
    ----------
    num_classes: int
        Number of label classes

    Notes
    -----
    The node feature is row-normalized.

    In citeseer dataset, there are some isolated nodes in the graph.
    These isolated nodes are added as zero-vecs into the right position.

    Examples
    --------
    >>> dataset = CiteseerGraphDataset()
    >>> g = dataset[0]
    >>> num_class = dataset.num_classes
    >>>
    >>> # get node feature
    >>> feat = g.ndata['feat']
    >>>
    >>> # get data split
    >>> train_mask = g.ndata['train_mask']
    >>> val_mask = g.ndata['val_mask']
    >>> test_mask = g.ndata['test_mask']
    >>>
    >>> # get labels
    >>> label = g.ndata['label']

    """

    def __init__(
        self,
        raw_dir=None,
        force_reload=False,
        verbose=True,
        reverse_edge=True,
        transform=None,
        reorder=False,
    ):
        super(CiteseerGraphDataset, self).__init__(
            name="citeseer",
            raw_dir=raw_dir,
            force_reload=force_reload,
            verbose=verbose,
            reverse_edge=reverse_edge,
            transform=transform,
            reorder=reorder,
        )

    def __getitem__(self, idx):
        r"""Gets the graph object

        Parameters
        -----------
        idx: int
            Item index, CiteseerGraphDataset has only one graph object

        Returns
        -------
        :class:`dgl.DGLGraph`

            graph structure, node features and labels.

            - ``ndata['train_mask']``: mask for training node set
            - ``ndata['val_mask']``: mask for validation node set
            - ``ndata['test_mask']``: mask for test node set
            - ``ndata['feat']``: node feature
            - ``ndata['label']``: ground truth labels
        """
        graph = super(CiteseerGraphDataset, self).__getitem__(idx)
        return graph

    def __len__(self):
        r"""The number of graphs in the dataset."""
        count = super(CiteseerGraphDataset, self).__len__()
        return count

568

569
class PubmedGraphDataset(CitationGraphDataset):
    r"""Pubmed citation network dataset.

    Nodes mean scientific publications and edges
    mean citation relationships. Each node has a
    predefined feature with 500 dimensions. The
    dataset is designed for the node classification
    task. The task is to predict the category of
    certain publication.

    Statistics:

    - Nodes: 19717
    - Edges: 88651
    - Number of Classes: 3
    - Label Split:

        - Train: 60
        - Valid: 500
        - Test: 1000

    Parameters
    -----------
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose : bool
        Whether to print out progress information. Default: True.
    reverse_edge : bool
        Whether to add reverse edges in graph. Default: True.
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access.
    reorder : bool
        Whether to reorder the graph using :func:`~dgl.reorder_graph`. Default: False.

    Attributes
    ----------
    num_classes: int
        Number of label classes

    Notes
    -----
    The node feature is row-normalized.

    Examples
    --------
    >>> dataset = PubmedGraphDataset()
    >>> g = dataset[0]
    >>> num_class = dataset.num_classes
    >>>
    >>> # get node feature
    >>> feat = g.ndata['feat']
    >>>
    >>> # get data split
    >>> train_mask = g.ndata['train_mask']
    >>> val_mask = g.ndata['val_mask']
    >>> test_mask = g.ndata['test_mask']
    >>>
    >>> # get labels
    >>> label = g.ndata['label']

    """

    def __init__(
        self,
        raw_dir=None,
        force_reload=False,
        verbose=True,
        reverse_edge=True,
        transform=None,
        reorder=False,
    ):
        super(PubmedGraphDataset, self).__init__(
            name="pubmed",
            raw_dir=raw_dir,
            force_reload=force_reload,
            verbose=verbose,
            reverse_edge=reverse_edge,
            transform=transform,
            reorder=reorder,
        )

    def __getitem__(self, idx):
        r"""Gets the graph object

        Parameters
        -----------
        idx: int
            Item index, PubmedGraphDataset has only one graph object

        Returns
        -------
        :class:`dgl.DGLGraph`

            graph structure, node features and labels.

            - ``ndata['train_mask']``: mask for training node set
            - ``ndata['val_mask']``: mask for validation node set
            - ``ndata['test_mask']``: mask for test node set
            - ``ndata['feat']``: node feature
            - ``ndata['label']``: ground truth labels
        """
        return super(PubmedGraphDataset, self).__getitem__(idx)

    def __len__(self):
        r"""The number of graphs in the dataset."""
        return super(PubmedGraphDataset, self).__len__()

683
684
685
686
687
688
689
690

def load_cora(
    raw_dir=None,
    force_reload=False,
    verbose=True,
    reverse_edge=True,
    transform=None,
    reorder=False,
):
    """Get CoraGraphDataset

    Parameters
    -----------
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose : bool
        Whether to print out progress information. Default: True.
    reverse_edge : bool
        Whether to add reverse edges in graph. Default: True.
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access.
    reorder : bool
        Whether to reorder the graph using :func:`~dgl.reorder_graph`. Default: False.

    Returns
    -------
    CoraGraphDataset
    """
    return CoraGraphDataset(
        raw_dir=raw_dir,
        force_reload=force_reload,
        verbose=verbose,
        reverse_edge=reverse_edge,
        transform=transform,
        reorder=reorder,
    )

718
719
720
721
722
723
724
725

def load_citeseer(
    raw_dir=None,
    force_reload=False,
    verbose=True,
    reverse_edge=True,
    transform=None,
    reorder=False,
):
    """Get CiteseerGraphDataset

    Parameters
    -----------
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose : bool
        Whether to print out progress information. Default: True.
    reverse_edge : bool
        Whether to add reverse edges in graph. Default: True.
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access.
    reorder : bool
        Whether to reorder the graph using :func:`~dgl.reorder_graph`. Default: False.

    Returns
    -------
    CiteseerGraphDataset
    """
    return CiteseerGraphDataset(
        raw_dir=raw_dir,
        force_reload=force_reload,
        verbose=verbose,
        reverse_edge=reverse_edge,
        transform=transform,
        reorder=reorder,
    )

753
754
755
756
757
758
759
760

def load_pubmed(
    raw_dir=None,
    force_reload=False,
    verbose=True,
    reverse_edge=True,
    transform=None,
    reorder=False,
):
    """Get PubmedGraphDataset

    Parameters
    -----------
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose : bool
        Whether to print out progress information. Default: True.
    reverse_edge : bool
        Whether to add reverse edges in graph. Default: True.
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access.
    reorder : bool
        Whether to reorder the graph using :func:`~dgl.reorder_graph`. Default: False.

    Returns
    -------
    PubmedGraphDataset
    """
    return PubmedGraphDataset(
        raw_dir=raw_dir,
        force_reload=force_reload,
        verbose=verbose,
        reverse_edge=reverse_edge,
        transform=transform,
        reorder=reorder,
    )

788

789
class CoraBinary(DGLBuiltinDataset):
    """A mini-dataset for binary classification task using Cora.

    After loaded, it has following members:

    graphs : list of :class:`~dgl.DGLGraph`
    pmpds : list of :class:`scipy.sparse.coo_matrix`
    labels : list of :class:`numpy.ndarray`

    Parameters
    -----------
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose: bool
        Whether to print out progress information. Default: True.
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access.
    """

    def __init__(
        self, raw_dir=None, force_reload=False, verbose=True, transform=None
    ):
        name = "cora_binary"
        url = _get_dgl_url("dataset/cora_binary.zip")
        super(CoraBinary, self).__init__(
            name,
            url=url,
            raw_dir=raw_dir,
            force_reload=force_reload,
            verbose=verbose,
            transform=transform,
        )

    @staticmethod
    def _read_sections(path, parse_line):
        """Split a text file into sections separated by lines starting with ``graph``.

        Every non-header line is converted with ``parse_line``. Empty sections
        are dropped, so the result only contains non-empty lists.
        """
        sections = []
        cur = []
        with open(path, "r") as f:
            for line in f:
                if line.startswith("graph"):
                    if cur:
                        sections.append(cur)
                    cur = []
                else:
                    cur.append(parse_line(line))
        if cur:
            sections.append(cur)
        return sections

    def process(self):
        """Read graphs, pmpds and labels from the raw directory."""
        root = self.raw_path

        def _parse_edge(line):
            # Exactly two integer node ids per line; split() tolerates any
            # whitespace (including the trailing newline).
            u, v = line.split()
            return int(u), int(v)

        # load graphs: one edge list per "graph" section
        self.graphs = [
            dgl_graph(tuple(zip(*elist)))
            for elist in self._read_sections(
                "{}/graphs.txt".format(root), _parse_edge
            )
        ]
        # NOTE: pickle file -- only load it from a trusted source.
        with open("{}/pmpds.pkl".format(root), "rb") as f:
            self.pmpds = _pickle_load(f)
        # load labels: one label vector per "graph" section
        self.labels = [
            np.asarray(section)
            for section in self._read_sections(
                "{}/labels.txt".format(root), lambda line: int(line.strip())
            )
        ]
        # sanity check
        assert len(self.graphs) == len(self.pmpds)
        assert len(self.graphs) == len(self.labels)

    def has_cache(self):
        """Whether the cached graph file exists."""
        graph_path = os.path.join(self.save_path, self.save_name + ".bin")
        return os.path.exists(graph_path)

    def save(self):
        """save the graph list and the labels"""
        graph_path = os.path.join(self.save_path, self.save_name + ".bin")
        labels = {
            "{}".format(i): F.tensor(label)
            for i, label in enumerate(self.labels)
        }
        save_graphs(str(graph_path), self.graphs, labels)
        if self.verbose:
            print("Done saving data into cached files.")

    def load(self):
        """Load graphs and labels from the cache; pmpds come from ``raw_path``."""
        graph_path = os.path.join(self.save_path, self.save_name + ".bin")
        self.graphs, labels = load_graphs(str(graph_path))

        self.labels = [
            F.asnumpy(labels["{}".format(i)]) for i in range(len(labels))
        ]
        # load pmpds under self.raw_path (they are not part of the cache)
        with open("{}/pmpds.pkl".format(self.raw_path), "rb") as f:
            self.pmpds = _pickle_load(f)
        if self.verbose:
            print("Done loading data into cached files.")
        # sanity check
        assert len(self.graphs) == len(self.pmpds)
        assert len(self.graphs) == len(self.labels)

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, i):
        r"""Gets the i-th sample.

        Parameters
        -----------
        i : int
            The sample index.

        Returns
        -------
        (dgl.DGLGraph, scipy.sparse.coo_matrix, numpy.ndarray)
            The graph (transformed if a transform was given), its pmpd
            matrix and its label vector.
        """
        if self._transform is None:
            g = self.graphs[i]
        else:
            g = self._transform(self.graphs[i])
        return (g, self.pmpds[i], self.labels[i])

    @property
    def save_name(self):
        return self.name + "_dgl_graph"

    @staticmethod
    def collate_fn(cur):
        """Batch a list of ``(graph, pmpd, labels)`` samples.

        Graphs are batched with ``batch.batch``, pmpds are combined into one
        block-diagonal sparse matrix and labels are concatenated.
        """
        graphs, pmpds, labels = zip(*cur)
        batched_graphs = batch.batch(graphs)
        batched_pmpds = sp.block_diag(pmpds)
        batched_labels = np.concatenate(labels, axis=0)
        return batched_graphs, batched_pmpds, batched_labels
927

928

929
930
def _normalize(mx):
    """Row-normalize sparse matrix

    Each row is divided by its sum. All-zero rows stay all-zero without a
    divide-by-zero warning, and integer matrices are accepted.

    Returns
    -------
    scipy.sparse matrix
        The row-normalized matrix (still sparse).
    """
    rowsum = np.asarray(mx.sum(1))
    # Float exponent so integer input does not raise; zero rows give inf,
    # which is expected and zeroed out below.
    with np.errstate(divide="ignore"):
        r_inv = np.power(rowsum, -1.0).flatten()
    r_inv[np.isinf(r_inv)] = 0.0
    r_mat_inv = sp.diags(r_inv)
    mx = r_mat_inv.dot(mx)
    return mx

938

939
def _encode_onehot(labels):
    """One-hot encode a sequence of hashable, sortable labels.

    Classes are ordered by ``sorted(set(labels))``; row ``k`` of the result
    is the one-hot vector of ``labels[k]``.

    Returns
    -------
    numpy.ndarray
        ``int32`` array of shape ``(len(labels), num_classes)``.
    """
    classes = sorted(set(labels))
    # Build the identity matrix once rather than once per class.
    identity = np.identity(len(classes))
    classes_dict = {c: identity[i, :] for i, c in enumerate(classes)}
    labels_onehot = np.asarray(
        list(map(classes_dict.get, labels)), dtype=np.int32
    )
    return labels_onehot