"""Cora, citeseer, pubmed dataset.

(lingfan): following dataset loading and preprocessing code from tkipf/gcn
https://github.com/tkipf/gcn/blob/master/gcn/utils.py
"""
from __future__ import absolute_import

import numpy as np
import pickle as pkl
import networkx as nx
import scipy.sparse as sp
import os, sys

from .utils import save_graphs, load_graphs, save_info, load_info, makedirs, _get_dgl_url
from .utils import generate_mask_tensor
from .utils import deprecate_property, deprecate_function
from .dgl_dataset import DGLBuiltinDataset
from .. import convert
from .. import batch
from .. import backend as F
from ..convert import graph as dgl_graph
from ..convert import from_networkx, to_networkx

backend = os.environ.get('DGLBACKEND', 'pytorch')

def _pickle_load(pkl_file):
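    # The raw files were pickled with Python 2; 'latin1' lets Python 3 decode
    # the embedded numpy arrays.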
    if sys.version_info > (3, 0):
        return pkl.load(pkl_file, encoding='latin1')
    else:
        return pkl.load(pkl_file)

class CitationGraphDataset(DGLBuiltinDataset):
    r"""The citation graph dataset, including cora, citeseer and pubmed.
    Nodes mean papers and edges mean citation relationships.

    Parameters
    -----------
    name: str
      name can be 'cora', 'citeseer' or 'pubmed'.
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose: bool
        Whether to print out progress information. Default: True.
    reverse_edge: bool
        Whether to add reverse edges in graph. Default: True.
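
    Examples
    --------
    A minimal usage sketch (the concrete subclasses below are the usual entry
    points; ``name`` must be one of the three dataset names):

    >>> dataset = CitationGraphDataset(name='cora')
    >>> g = dataset[0]
    >>> feat, label = g.ndata['feat'], g.ndata['label']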
    """
    _urls = {
        'cora_v2' : 'dataset/cora_v2.zip',
        'citeseer' : 'dataset/citeseer.zip',
        'pubmed' : 'dataset/pubmed.zip',
    }

    def __init__(self, name, raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
        assert name.lower() in ['cora', 'citeseer', 'pubmed']

        # Previously we use the pre-processing in pygcn (https://github.com/tkipf/pygcn)
        # for Cora, which is slightly different from the one used in the GCN paper
        if name.lower() == 'cora':
            name = 'cora_v2'

        url = _get_dgl_url(self._urls[name])
        self._reverse_edge = reverse_edge

        super(CitationGraphDataset, self).__init__(name,
                                                   url=url,
                                                   raw_dir=raw_dir,
                                                   force_reload=force_reload,
                                                   verbose=verbose)

    def process(self):
        """Loads input data from data directory

        ind.name.x => the feature vectors of the training instances as scipy.sparse.csr.csr_matrix object;
        ind.name.tx => the feature vectors of the test instances as scipy.sparse.csr.csr_matrix object;
        ind.name.allx => the feature vectors of both labeled and unlabeled training instances
            (a superset of ind.name.x) as scipy.sparse.csr.csr_matrix object;
        ind.name.y => the one-hot labels of the labeled training instances as numpy.ndarray object;
        ind.name.ty => the one-hot labels of the test instances as numpy.ndarray object;
        ind.name.ally => the labels for instances in ind.name.allx as numpy.ndarray object;
        ind.name.graph => a dict in the format {index: [index_of_neighbor_nodes]} as collections.defaultdict
            object;
        ind.name.test.index => the indices of test instances in graph, for the inductive setting as list object.
        """
        root = self.raw_path
        objnames = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph']
        objects = []
        for i in range(len(objnames)):
            with open("{}/ind.{}.{}".format(root, self.name, objnames[i]), 'rb') as f:
                objects.append(_pickle_load(f))

        x, y, tx, ty, allx, ally, graph = tuple(objects)
        test_idx_reorder = _parse_index_file("{}/ind.{}.test.index".format(root, self.name))
        test_idx_range = np.sort(test_idx_reorder)

        if self.name == 'citeseer':
            # Fix citeseer dataset (there are some isolated nodes in the graph)
            # Find isolated nodes, add them as zero-vecs into the right position
            test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder)+1)
            tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
            tx_extended[test_idx_range-min(test_idx_range), :] = tx
            tx = tx_extended
            ty_extended = np.zeros((len(test_idx_range_full), y.shape[1]))
            ty_extended[test_idx_range-min(test_idx_range), :] = ty
            ty = ty_extended

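        # Stack the (labeled + unlabeled) training features with the test features,
        # then put the test rows back into graph-node order; ind.*.test.index lists
        # the test instances in shuffled order.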
        features = sp.vstack((allx, tx)).tolil()
        features[test_idx_reorder, :] = features[test_idx_range, :]

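        # from_dict_of_lists yields an undirected graph; wrapping it in DiGraph
        # materializes both directions of every edge, i.e. adds reverse edges.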
        if self.reverse_edge:
            graph = nx.DiGraph(nx.from_dict_of_lists(graph))
        else:
            graph = nx.Graph(nx.from_dict_of_lists(graph))

        onehot_labels = np.vstack((ally, ty))
        onehot_labels[test_idx_reorder, :] = onehot_labels[test_idx_range, :]
        labels = np.argmax(onehot_labels, 1)

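        # Standard Planetoid split: the first len(y) nodes are training, the next
        # 500 are validation, and the sorted test indices form the test set.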
        idx_test = test_idx_range.tolist()
        idx_train = range(len(y))
        idx_val = range(len(y), len(y)+500)

        train_mask = generate_mask_tensor(_sample_mask(idx_train, labels.shape[0]))
        val_mask = generate_mask_tensor(_sample_mask(idx_val, labels.shape[0]))
        test_mask = generate_mask_tensor(_sample_mask(idx_test, labels.shape[0]))

        self._graph = graph
        g = from_networkx(graph)

        g.ndata['train_mask'] = train_mask
        g.ndata['val_mask'] = val_mask
        g.ndata['test_mask'] = test_mask
        g.ndata['label'] = F.tensor(labels)
        g.ndata['feat'] = F.tensor(_preprocess_features(features), dtype=F.data_type_dict['float32'])
        self._num_classes = onehot_labels.shape[1]
        self._labels = labels
        self._g = g

        if self.verbose:
            print('Finished data loading and preprocessing.')
            print('  NumNodes: {}'.format(self._g.number_of_nodes()))
            print('  NumEdges: {}'.format(self._g.number_of_edges()))
            print('  NumFeats: {}'.format(self._g.ndata['feat'].shape[1]))
            print('  NumClasses: {}'.format(self.num_classes))
            print('  NumTrainingSamples: {}'.format(
                F.nonzero_1d(self._g.ndata['train_mask']).shape[0]))
            print('  NumValidationSamples: {}'.format(
                F.nonzero_1d(self._g.ndata['val_mask']).shape[0]))
            print('  NumTestSamples: {}'.format(
                F.nonzero_1d(self._g.ndata['test_mask']).shape[0]))

    def has_cache(self):
        graph_path = os.path.join(self.save_path,
                                  self.save_name + '.bin')
        info_path = os.path.join(self.save_path,
                                 self.save_name + '.pkl')
        if os.path.exists(graph_path) and \
            os.path.exists(info_path):
            return True

        return False

    def save(self):
        """save the graph list and the labels"""
        graph_path = os.path.join(self.save_path,
                                  self.save_name + '.bin')
        info_path = os.path.join(self.save_path,
                                 self.save_name + '.pkl')
        save_graphs(str(graph_path), self._g)
        save_info(str(info_path), {'num_classes': self.num_classes})

    def load(self):
        graph_path = os.path.join(self.save_path,
                                  self.save_name + '.bin')
        info_path = os.path.join(self.save_path,
                                 self.save_name + '.pkl')
        graphs, _ = load_graphs(str(graph_path))

        info = load_info(str(info_path))
        graph = graphs[0]
        self._g = graph
        # for compatibility
        graph = graph.clone()
        graph.ndata.pop('train_mask')
        graph.ndata.pop('val_mask')
        graph.ndata.pop('test_mask')
        graph.ndata.pop('feat')
        graph.ndata.pop('label')
        graph = to_networkx(graph)
        self._graph = nx.DiGraph(graph)

        self._num_classes = info['num_classes']
        # hack for mxnet compatibility: re-wrap the loaded masks with generate_mask_tensor
        self._g.ndata['train_mask'] = generate_mask_tensor(F.asnumpy(self._g.ndata['train_mask']))
        self._g.ndata['val_mask'] = generate_mask_tensor(F.asnumpy(self._g.ndata['val_mask']))
        self._g.ndata['test_mask'] = generate_mask_tensor(F.asnumpy(self._g.ndata['test_mask']))

        if self.verbose:
            print('  NumNodes: {}'.format(self._g.number_of_nodes()))
            print('  NumEdges: {}'.format(self._g.number_of_edges()))
            print('  NumFeats: {}'.format(self._g.ndata['feat'].shape[1]))
            print('  NumClasses: {}'.format(self.num_classes))
            print('  NumTrainingSamples: {}'.format(
                F.nonzero_1d(self._g.ndata['train_mask']).shape[0]))
            print('  NumValidationSamples: {}'.format(
                F.nonzero_1d(self._g.ndata['val_mask']).shape[0]))
            print('  NumTestSamples: {}'.format(
                F.nonzero_1d(self._g.ndata['test_mask']).shape[0]))

    def __getitem__(self, idx):
        assert idx == 0, "This dataset has only one graph"
        return self._g

    def __len__(self):
        return 1

    @property
    def save_name(self):
        return self.name + '_dgl_graph'

    @property
    def num_labels(self):
        deprecate_property('dataset.num_labels', 'dataset.num_classes')
        return self.num_classes

    @property
    def num_classes(self):
        return self._num_classes

    """ Citation graph is used in many examples
        We preserve these properties for compatibility.
    """
    @property
    def graph(self):
        deprecate_property('dataset.graph', 'dataset[0]')
        return self._graph

    @property
    def train_mask(self):
        deprecate_property('dataset.train_mask', 'g.ndata[\'train_mask\']')
        return F.asnumpy(self._g.ndata['train_mask'])

    @property
    def val_mask(self):
        deprecate_property('dataset.val_mask', 'g.ndata[\'val_mask\']')
        return F.asnumpy(self._g.ndata['val_mask'])

    @property
    def test_mask(self):
        deprecate_property('dataset.test_mask', 'g.ndata[\'test_mask\']')
        return F.asnumpy(self._g.ndata['test_mask'])

    @property
    def labels(self):
        deprecate_property('dataset.label', 'g.ndata[\'label\']')
        return F.asnumpy(self._g.ndata['label'])

    @property
    def features(self):
        deprecate_property('dataset.feat', 'g.ndata[\'feat\']')
        return self._g.ndata['feat']

    @property
    def reverse_edge(self):
        return self._reverse_edge
    

def _preprocess_features(features):
    """Row-normalize the feature matrix and return it as a dense numpy array."""
    rowsum = np.asarray(features.sum(1))
    r_inv = np.power(rowsum, -1).flatten()
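    # Rows that sum to zero produce inf here; zero them out so all-zero feature
    # rows stay zero instead of turning into NaN/inf after scaling.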
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    features = r_mat_inv.dot(features)
    return np.asarray(features.todense())

def _parse_index_file(filename):
    """Parse index file."""
    index = []
    for line in open(filename):
        index.append(int(line.strip()))
    return index

def _sample_mask(idx, l):
    """Create mask."""
    mask = np.zeros(l)
    mask[idx] = 1
    return mask

class CoraGraphDataset(CitationGraphDataset):
    r""" Cora citation network dataset.

    .. deprecated:: 0.5.0

        - ``graph`` is deprecated, it is replaced by:

            >>> dataset = CoraGraphDataset()
            >>> graph = dataset[0]

        - ``train_mask`` is deprecated, it is replaced by:

            >>> dataset = CoraGraphDataset()
            >>> graph = dataset[0]
            >>> train_mask = graph.ndata['train_mask']

        - ``val_mask`` is deprecated, it is replaced by:

            >>> dataset = CoraGraphDataset()
            >>> graph = dataset[0]
            >>> val_mask = graph.ndata['val_mask']

        - ``test_mask`` is deprecated, it is replaced by:

            >>> dataset = CoraGraphDataset()
            >>> graph = dataset[0]
            >>> test_mask = graph.ndata['test_mask']

        - ``labels`` is deprecated, it is replaced by:

            >>> dataset = CoraGraphDataset()
            >>> graph = dataset[0]
            >>> labels = graph.ndata['label']

        - ``feat`` is deprecated, it is replaced by:

            >>> dataset = CoraGraphDataset()
            >>> graph = dataset[0]
            >>> feat = graph.ndata['feat']

    Nodes mean papers and edges mean citation
    relationships. Each node has a predefined
    feature with 1433 dimensions. The dataset is
    designed for the node classification task.
    The task is to predict the category of
    a certain paper.

    Statistics:

    - Nodes: 2708
    - Edges: 10556
    - Number of Classes: 7
    - Label split:

        - Train: 140
        - Valid: 500
        - Test: 1000

    Parameters
    ----------
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose: bool
        Whether to print out progress information. Default: True.
    reverse_edge: bool
        Whether to add reverse edges in graph. Default: True.

    Attributes
    ----------
    num_classes: int
        Number of label classes
    graph: networkx.DiGraph
        Graph structure
    train_mask: numpy.ndarray
        Mask of training nodes
    val_mask: numpy.ndarray
        Mask of validation nodes
    test_mask: numpy.ndarray
        Mask of test nodes
    labels: numpy.ndarray
        Ground truth labels of each node
    features: Tensor
        Node features

    Notes
    -----
    The node feature is row-normalized.

    Examples
    --------
    >>> dataset = CoraGraphDataset()
    >>> g = dataset[0]
    >>> num_class = dataset.num_classes
    >>>
    >>> # get node feature
    >>> feat = g.ndata['feat']
    >>>
    >>> # get data split
    >>> train_mask = g.ndata['train_mask']
    >>> val_mask = g.ndata['val_mask']
    >>> test_mask = g.ndata['test_mask']
    >>>
    >>> # get labels
    >>> label = g.ndata['label']
    >>>
    >>> # Train, Validation and Test

    """
    def __init__(self, raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
        name = 'cora'

        super(CoraGraphDataset, self).__init__(name, raw_dir, force_reload, verbose, reverse_edge)

    def __getitem__(self, idx):
        r"""Gets the graph object

        Parameters
        -----------
        idx: int
            Item index, CoraGraphDataset has only one graph object

        Return
        ------
        :class:`dgl.DGLGraph`

            graph structure, node features and labels.

            - ``ndata['train_mask']``: mask for training node set
            - ``ndata['val_mask']``: mask for validation node set
            - ``ndata['test_mask']``: mask for test node set
            - ``ndata['feat']``: node feature
            - ``ndata['label']``: ground truth labels
        """
        return super(CoraGraphDataset, self).__getitem__(idx)

    def __len__(self):
        r"""The number of graphs in the dataset."""
        return super(CoraGraphDataset, self).__len__()

class CiteseerGraphDataset(CitationGraphDataset):
    r""" Citeseer citation network dataset.

    .. deprecated:: 0.5.0

        - ``graph`` is deprecated, it is replaced by:

            >>> dataset = CiteseerGraphDataset()
            >>> graph = dataset[0]

        - ``train_mask`` is deprecated, it is replaced by:

            >>> dataset = CiteseerGraphDataset()
            >>> graph = dataset[0]
            >>> train_mask = graph.ndata['train_mask']

        - ``val_mask`` is deprecated, it is replaced by:

            >>> dataset = CiteseerGraphDataset()
            >>> graph = dataset[0]
            >>> val_mask = graph.ndata['val_mask']

        - ``test_mask`` is deprecated, it is replaced by:

            >>> dataset = CiteseerGraphDataset()
            >>> graph = dataset[0]
            >>> test_mask = graph.ndata['test_mask']

        - ``labels`` is deprecated, it is replaced by:

            >>> dataset = CiteseerGraphDataset()
            >>> graph = dataset[0]
            >>> labels = graph.ndata['label']

        - ``feat`` is deprecated, it is replaced by:

            >>> dataset = CiteseerGraphDataset()
            >>> graph = dataset[0]
            >>> feat = graph.ndata['feat']

    Nodes mean scientific publications and edges
    mean citation relationships. Each node has a
    predefined feature with 3703 dimensions. The
    dataset is designed for the node classification
    task. The task is to predict the category of
    a certain publication.

    Statistics:

    - Nodes: 3327
    - Edges: 9228
    - Number of Classes: 6
    - Label Split:

        - Train: 120
        - Valid: 500
        - Test: 1000

    Parameters
    -----------
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose: bool
        Whether to print out progress information. Default: True.
    reverse_edge: bool
        Whether to add reverse edges in graph. Default: True.

    Attributes
    ----------
    num_classes: int
        Number of label classes
    graph: networkx.DiGraph
        Graph structure
    train_mask: numpy.ndarray
        Mask of training nodes
    val_mask: numpy.ndarray
        Mask of validation nodes
    test_mask: numpy.ndarray
        Mask of test nodes
    labels: numpy.ndarray
        Ground truth labels of each node
    features: Tensor
        Node features

    Notes
    -----
    The node feature is row-normalized.

    In citeseer dataset, there are some isolated nodes in the graph.
    These isolated nodes are added as zero-vecs into the right position.

    Examples
    --------
    >>> dataset = CiteseerGraphDataset()
    >>> g = dataset[0]
    >>> num_class = dataset.num_classes
    >>>
    >>> # get node feature
    >>> feat = g.ndata['feat']
    >>>
    >>> # get data split
    >>> train_mask = g.ndata['train_mask']
    >>> val_mask = g.ndata['val_mask']
    >>> test_mask = g.ndata['test_mask']
    >>>
    >>> # get labels
    >>> label = g.ndata['label']
    >>>
    >>> # Train, Validation and Test

    """
    def __init__(self, raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
        name = 'citeseer'

        super(CiteseerGraphDataset, self).__init__(name, raw_dir, force_reload, verbose, reverse_edge)

    def __getitem__(self, idx):
        r"""Gets the graph object

        Parameters
        -----------
        idx: int
            Item index, CiteseerGraphDataset has only one graph object

        Return
        ------
        :class:`dgl.DGLGraph`

            graph structure, node features and labels.

            - ``ndata['train_mask']``: mask for training node set
            - ``ndata['val_mask']``: mask for validation node set
            - ``ndata['test_mask']``: mask for test node set
            - ``ndata['feat']``: node feature
            - ``ndata['label']``: ground truth labels
        """
        return super(CiteseerGraphDataset, self).__getitem__(idx)

    def __len__(self):
        r"""The number of graphs in the dataset."""
        return super(CiteseerGraphDataset, self).__len__()

class PubmedGraphDataset(CitationGraphDataset):
    r""" Pubmed citation network dataset.

    .. deprecated:: 0.5.0

        - ``graph`` is deprecated, it is replaced by:

            >>> dataset = PubmedGraphDataset()
            >>> graph = dataset[0]

        - ``train_mask`` is deprecated, it is replaced by:

            >>> dataset = PubmedGraphDataset()
            >>> graph = dataset[0]
            >>> train_mask = graph.ndata['train_mask']

        - ``val_mask`` is deprecated, it is replaced by:

            >>> dataset = PubmedGraphDataset()
            >>> graph = dataset[0]
            >>> val_mask = graph.ndata['val_mask']

        - ``test_mask`` is deprecated, it is replaced by:

            >>> dataset = PubmedGraphDataset()
            >>> graph = dataset[0]
            >>> test_mask = graph.ndata['test_mask']

        - ``labels`` is deprecated, it is replaced by:

            >>> dataset = PubmedGraphDataset()
            >>> graph = dataset[0]
            >>> labels = graph.ndata['label']

        - ``feat`` is deprecated, it is replaced by:

            >>> dataset = PubmedGraphDataset()
            >>> graph = dataset[0]
            >>> feat = graph.ndata['feat']

    Nodes mean scientific publications and edges
    mean citation relationships. Each node has a
    predefined feature with 500 dimensions. The
    dataset is designed for the node classification
    task. The task is to predict the category of
    a certain publication.

    Statistics:

    - Nodes: 19717
    - Edges: 88651
    - Number of Classes: 3
    - Label Split:

        - Train: 60
        - Valid: 500
        - Test: 1000

    Parameters
    -----------
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose: bool
        Whether to print out progress information. Default: True.
    reverse_edge: bool
        Whether to add reverse edges in graph. Default: True.

    Attributes
    ----------
    num_classes: int
        Number of label classes
    graph: networkx.DiGraph
        Graph structure
    train_mask: numpy.ndarray
        Mask of training nodes
    val_mask: numpy.ndarray
        Mask of validation nodes
    test_mask: numpy.ndarray
        Mask of test nodes
    labels: numpy.ndarray
        Ground truth labels of each node
    features: Tensor
        Node features

    Notes
    -----
    The node feature is row-normalized.

    Examples
    --------
    >>> dataset = PubmedGraphDataset()
    >>> g = dataset[0]
    >>> num_class = dataset.num_classes
    >>>
    >>> # get node feature
    >>> feat = g.ndata['feat']
    >>>
    >>> # get data split
    >>> train_mask = g.ndata['train_mask']
    >>> val_mask = g.ndata['val_mask']
    >>> test_mask = g.ndata['test_mask']
    >>>
    >>> # get labels
    >>> label = g.ndata['label']
    >>>
    >>> # Train, Validation and Test

    """
    def __init__(self, raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
        name = 'pubmed'

        super(PubmedGraphDataset, self).__init__(name, raw_dir, force_reload, verbose, reverse_edge)

    def __getitem__(self, idx):
        r"""Gets the graph object

        Parameters
        -----------
        idx: int
            Item index, PubmedGraphDataset has only one graph object

        Return
        ------
        :class:`dgl.DGLGraph`

            graph structure, node features and labels.

            - ``ndata['train_mask']``: mask for training node set
            - ``ndata['val_mask']``: mask for validation node set
            - ``ndata['test_mask']``: mask for test node set
            - ``ndata['feat']``: node feature
            - ``ndata['label']``: ground truth labels
        """
        return super(PubmedGraphDataset, self).__getitem__(idx)

    def __len__(self):
        r"""The number of graphs in the dataset."""
        return super(PubmedGraphDataset, self).__len__()

def load_cora(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
    """Get CoraGraphDataset

    Parameters
    -----------
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose: bool
        Whether to print out progress information. Default: True.
    reverse_edge: bool
        Whether to add reverse edges in graph. Default: True.

    Return
    -------
    CoraGraphDataset
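
    Examples
    --------
    A minimal usage sketch:

    >>> dataset = load_cora()
    >>> g = dataset[0]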
    """
    data = CoraGraphDataset(raw_dir, force_reload, verbose, reverse_edge)
    return data

def load_citeseer(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
    """Get CiteseerGraphDataset

    Parameters
    -----------
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose: bool
        Whether to print out progress information. Default: True.
    reverse_edge: bool
        Whether to add reverse edges in graph. Default: True.

    Return
    -------
    CiteseerGraphDataset
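
    Examples
    --------
    A minimal usage sketch:

    >>> dataset = load_citeseer()
    >>> g = dataset[0]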
    """
    data = CiteseerGraphDataset(raw_dir, force_reload, verbose, reverse_edge)
    return data

def load_pubmed(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
    """Get PubmedGraphDataset

    Parameters
    -----------
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose: bool
        Whether to print out progress information. Default: True.
    reverse_edge: bool
        Whether to add reverse edges in graph. Default: True.

    Return
    -------
    PubmedGraphDataset
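
    Examples
    --------
    A minimal usage sketch:

    >>> dataset = load_pubmed()
    >>> g = dataset[0]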
    """
    data = PubmedGraphDataset(raw_dir, force_reload, verbose, reverse_edge)
    return data

class CoraBinary(DGLBuiltinDataset):
    """A mini-dataset for binary classification task using Cora.

    Once loaded, it has the following members:

    graphs : list of :class:`~dgl.DGLGraph`
    pmpds : list of :class:`scipy.sparse.coo_matrix`
    labels : list of :class:`numpy.ndarray`

    Parameters
    -----------
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose: bool
        Whether to print out progress information. Default: True.
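
    Examples
    --------
    A minimal usage sketch; batching via ``collate_fn`` assumes a PyTorch-style
    ``DataLoader`` (any loader that accepts a collate function works similarly):

    >>> dataset = CoraBinary()
    >>> g, pmpd, label = dataset[0]
    >>>
    >>> from torch.utils.data import DataLoader
    >>> loader = DataLoader(dataset, batch_size=16, collate_fn=CoraBinary.collate_fn)
    >>> bg, bpmpd, blabel = next(iter(loader))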
    """
    def __init__(self, raw_dir=None, force_reload=False, verbose=True):
        name = 'cora_binary'
        url = _get_dgl_url('dataset/cora_binary.zip')
        super(CoraBinary, self).__init__(name,
                                         url=url,
                                         raw_dir=raw_dir,
                                         force_reload=force_reload,
                                         verbose=verbose)

    def process(self):
        root = self.raw_path
        # load graphs
        self.graphs = []
        with open("{}/graphs.txt".format(root), 'r') as f:
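            # graphs.txt stores one edge list per graph; a line starting with
            # 'graph' marks the start of a new block.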
            elist = []
            for line in f.readlines():
                if line.startswith('graph'):
                    if len(elist) != 0:
                        self.graphs.append(dgl_graph(tuple(zip(*elist))))
                    elist = []
                else:
                    u, v = line.strip().split(' ')
                    elist.append((int(u), int(v)))
            if len(elist) != 0:
                self.graphs.append(dgl_graph(tuple(zip(*elist))))
        with open("{}/pmpds.pkl".format(root), 'rb') as f:
            self.pmpds = _pickle_load(f)
        self.labels = []
        with open("{}/labels.txt".format(root), 'r') as f:
            cur = []
            for line in f.readlines():
                if line.startswith('graph'):
                    if len(cur) != 0:
                        self.labels.append(np.asarray(cur))
                    cur = []
                else:
                    cur.append(int(line.strip()))
            if len(cur) != 0:
                self.labels.append(np.asarray(cur))
        # sanity check
        assert len(self.graphs) == len(self.pmpds)
        assert len(self.graphs) == len(self.labels)

    def has_cache(self):
        graph_path = os.path.join(self.save_path,
                                  self.save_name + '.bin')
        if os.path.exists(graph_path):
            return True

        return False

    def save(self):
        """save the graph list and the labels"""
        graph_path = os.path.join(self.save_path,
                                  self.save_name + '.bin')
        labels = {}
        for i, label in enumerate(self.labels):
            labels['{}'.format(i)] = F.tensor(label)
        save_graphs(str(graph_path), self.graphs, labels)
        if self.verbose:
            print('Done saving data into cached files.')

    def load(self):
        graph_path = os.path.join(self.save_path,
                                  self.save_name + '.bin')
        self.graphs, labels = load_graphs(str(graph_path))

        self.labels = []
        for i in range(len(labels)):
            self.labels.append(F.asnumpy(labels['{}'.format(i)]))
        # load pmpds under self.raw_path
        with open("{}/pmpds.pkl".format(self.raw_path), 'rb') as f:
            self.pmpds = _pickle_load(f)
        if self.verbose:
            print('Done loading data from cached files.')
        # sanity check
        assert len(self.graphs) == len(self.pmpds)
        assert len(self.graphs) == len(self.labels)

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, i):
        r"""Gets the i-th sample.

        Parameters
        -----------
        i : int
            The sample index.

        Returns
        -------
        (dgl.DGLGraph, scipy.sparse.coo_matrix, numpy.ndarray)
            The graph, scipy sparse coo_matrix and its node labels.
        """
        return (self.graphs[i], self.pmpds[i], self.labels[i])

    @property
    def save_name(self):
        return self.name + '_dgl_graph'

    @staticmethod
    def collate_fn(cur):
        graphs, pmpds, labels = zip(*cur)
        batched_graphs = batch.batch(graphs)
        batched_pmpds = sp.block_diag(pmpds)
        batched_labels = np.concatenate(labels, axis=0)
        return batched_graphs, batched_pmpds, batched_labels

def _normalize(mx):
    """Row-normalize sparse matrix"""
    rowsum = np.asarray(mx.sum(1))
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    mx = r_mat_inv.dot(mx)
    return mx

def _encode_onehot(labels):
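    """One-hot encode labels, with classes ordered by sorted unique value."""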
    classes = list(sorted(set(labels)))
    classes_dict = {c: np.identity(len(classes))[i, :] for i, c in
                    enumerate(classes)}
    labels_onehot = np.asarray(list(map(classes_dict.get, labels)),
                               dtype=np.int32)
    return labels_onehot