"src/kernel/vscode:/vscode.git/clone" did not exist on "bcd33e0ac0d76cb5e5fa289eebec1e21efde88a3"
citation_graph.py 28.5 KB
Newer Older
Minjie Wang's avatar
Minjie Wang committed
1
2
3
4
5
"""Cora, citeseer, pubmed dataset.

(lingfan): following dataset loading and preprocessing code from tkipf/gcn
https://github.com/tkipf/gcn/blob/master/gcn/utils.py
"""
6
7
from __future__ import absolute_import

Minjie Wang's avatar
Minjie Wang committed
8
9
10
11
12
13
import numpy as np
import pickle as pkl
import networkx as nx
import scipy.sparse as sp
import os, sys

14
15
16
17
from .utils import save_graphs, load_graphs, save_info, load_info, makedirs, _get_dgl_url
from .utils import generate_mask_tensor
from .utils import deprecate_property, deprecate_function
from .dgl_dataset import DGLBuiltinDataset
18
19
20
from .. import convert
from .. import batch
from .. import backend as F
21
22
from ..convert import graph as dgl_graph
from ..convert import to_networkx
Minjie Wang's avatar
Minjie Wang committed
23

24
# Active DGL backend name, selected via the DGLBACKEND env var ('pytorch' default).
backend = os.environ.get('DGLBACKEND', 'pytorch')
Minjie Wang's avatar
Minjie Wang committed
25

HQ's avatar
HQ committed
26
27
28
29
30
31
def _pickle_load(pkl_file):
    if sys.version_info > (3, 0):
        return pkl.load(pkl_file, encoding='latin1')
    else:
        return pkl.load(pkl_file)

32
class CitationGraphDataset(DGLBuiltinDataset):
    r"""The citation graph dataset, including cora, citeseer and pubmeb.

    Nodes mean authors and edges mean citation relationships.

    Parameters
    -----------
    name: str
      name can be 'cora', 'citeseer' or 'pubmed'.
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose: bool
      Whether to print out progress information. Default: True.
    """
    # Download locations relative to the DGL dataset mirror.
    _urls = {
        'cora_v2' : 'dataset/cora_v2.zip',
        'citeseer' : 'dataset/citeseer.zip',
        'pubmed' : 'dataset/pubmed.zip',
    }

    def __init__(self, name, raw_dir=None, force_reload=False, verbose=True):
        assert name.lower() in ['cora', 'citeseer', 'pubmed']

        # Previously we use the pre-processing in pygcn (https://github.com/tkipf/pygcn)
        # for Cora, which is slightly different from the one used in the GCN paper
        if name.lower() == 'cora':
            name = 'cora_v2'

        url = _get_dgl_url(self._urls[name])
        super(CitationGraphDataset, self).__init__(name,
                                                   url=url,
                                                   raw_dir=raw_dir,
                                                   force_reload=force_reload,
                                                   verbose=verbose)

    def process(self):
        """Loads input data from data directory

        ind.name.x => the feature vectors of the training instances as scipy.sparse.csr.csr_matrix object;
        ind.name.tx => the feature vectors of the test instances as scipy.sparse.csr.csr_matrix object;
        ind.name.allx => the feature vectors of both labeled and unlabeled training instances
            (a superset of ind.name.x) as scipy.sparse.csr.csr_matrix object;
        ind.name.y => the one-hot labels of the labeled training instances as numpy.ndarray object;
        ind.name.ty => the one-hot labels of the test instances as numpy.ndarray object;
        ind.name.ally => the labels for instances in ind.name.allx as numpy.ndarray object;
        ind.name.graph => a dict in the format {index: [index_of_neighbor_nodes]} as collections.defaultdict
            object;
        ind.name.test.index => the indices of test instances in graph, for the inductive setting as list object.
        """
        root = self.raw_path
        objnames = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph']
        objects = []
        for i in range(len(objnames)):
            with open("{}/ind.{}.{}".format(root, self.name, objnames[i]), 'rb') as f:
                objects.append(_pickle_load(f))

        x, y, tx, ty, allx, ally, graph = tuple(objects)
        test_idx_reorder = _parse_index_file("{}/ind.{}.test.index".format(root, self.name))
        test_idx_range = np.sort(test_idx_reorder)

        if self.name == 'citeseer':
            # Fix citeseer dataset (there are some isolated nodes in the graph)
            # Find isolated nodes, add them as zero-vecs into the right position
            test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder)+1)
            tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
            tx_extended[test_idx_range-min(test_idx_range), :] = tx
            tx = tx_extended
            ty_extended = np.zeros((len(test_idx_range_full), y.shape[1]))
            ty_extended[test_idx_range-min(test_idx_range), :] = ty
            ty = ty_extended

        # Reorder features/labels so that test nodes sit at their graph indices.
        features = sp.vstack((allx, tx)).tolil()
        features[test_idx_reorder, :] = features[test_idx_range, :]
        graph = nx.DiGraph(nx.from_dict_of_lists(graph))

        onehot_labels = np.vstack((ally, ty))
        onehot_labels[test_idx_reorder, :] = onehot_labels[test_idx_range, :]
        labels = np.argmax(onehot_labels, 1)

        # Standard Planetoid split: first len(y) nodes train, next 500 validation.
        idx_test = test_idx_range.tolist()
        idx_train = range(len(y))
        idx_val = range(len(y), len(y)+500)

        train_mask = _sample_mask(idx_train, labels.shape[0])
        val_mask = _sample_mask(idx_val, labels.shape[0])
        test_mask = _sample_mask(idx_test, labels.shape[0])

        # Keep the networkx view for the deprecated ``graph`` property.
        self._graph = graph
        g = dgl_graph(graph)

        g.ndata['train_mask'] = generate_mask_tensor(train_mask)
        g.ndata['val_mask'] = generate_mask_tensor(val_mask)
        g.ndata['test_mask'] = generate_mask_tensor(test_mask)
        g.ndata['label'] = F.tensor(labels)
        g.ndata['feat'] = F.tensor(_preprocess_features(features), dtype=F.data_type_dict['float32'])
        self._num_labels = onehot_labels.shape[1]
        self._labels = labels
        self._g = g

        if self.verbose:
            print('Finished data loading and preprocessing.')
            print('  NumNodes: {}'.format(self._g.number_of_nodes()))
            print('  NumEdges: {}'.format(self._g.number_of_edges()))
            print('  NumFeats: {}'.format(self._g.ndata['feat'].shape[1]))
            print('  NumClasses: {}'.format(self.num_labels))
            print('  NumTrainingSamples: {}'.format(
                F.nonzero_1d(self._g.ndata['train_mask']).shape[0]))
            print('  NumValidationSamples: {}'.format(
                F.nonzero_1d(self._g.ndata['val_mask']).shape[0]))
            print('  NumTestSamples: {}'.format(
                F.nonzero_1d(self._g.ndata['test_mask']).shape[0]))

    def has_cache(self):
        """Return True when both the cached graph and info files exist."""
        graph_path = os.path.join(self.save_path,
                                  self.save_name + '.bin')
        info_path = os.path.join(self.save_path,
                                 self.save_name + '.pkl')
        if os.path.exists(graph_path) and \
            os.path.exists(info_path):
            return True

        return False

    def save(self):
        """save the graph list and the labels"""
        graph_path = os.path.join(self.save_path,
                                  self.save_name + '.bin')
        info_path = os.path.join(self.save_path,
                                 self.save_name + '.pkl')
        save_graphs(str(graph_path), self._g)
        save_info(str(info_path), {'num_labels': self.num_labels})

    def load(self):
        """Load the cached graph and dataset info from ``self.save_path``."""
        graph_path = os.path.join(self.save_path,
                                  self.save_name + '.bin')
        info_path = os.path.join(self.save_path,
                                 self.save_name + '.pkl')
        graphs, _ = load_graphs(str(graph_path))

        info = load_info(str(info_path))
        self._g = graphs[0]
        # Rebuild the deprecated networkx view from a *clone* so that popping
        # node data does not mutate ``self._g``.
        # BUGFIX: the previous code referenced an unbound ``graph`` variable and
        # called ``pop`` on the DGLGraph itself instead of on ``graph.ndata``.
        graph = self._g.clone()
        graph.ndata.pop('train_mask')
        graph.ndata.pop('val_mask')
        graph.ndata.pop('test_mask')
        graph.ndata.pop('feat')
        graph.ndata.pop('label')
        graph = to_networkx(graph)
        self._graph = nx.DiGraph(graph)

        self._num_labels = info['num_labels']
        # Re-generate the masks with the backend-agnostic conversion
        # (``.numpy()`` is torch-only); hack for mxnet compatability
        self._g.ndata['train_mask'] = generate_mask_tensor(F.asnumpy(self._g.ndata['train_mask']))
        self._g.ndata['val_mask'] = generate_mask_tensor(F.asnumpy(self._g.ndata['val_mask']))
        self._g.ndata['test_mask'] = generate_mask_tensor(F.asnumpy(self._g.ndata['test_mask']))

        if self.verbose:
            print('  NumNodes: {}'.format(self._g.number_of_nodes()))
            print('  NumEdges: {}'.format(self._g.number_of_edges()))
            print('  NumFeats: {}'.format(self._g.ndata['feat'].shape[1]))
            print('  NumClasses: {}'.format(self.num_labels))
            print('  NumTrainingSamples: {}'.format(
                F.nonzero_1d(self._g.ndata['train_mask']).shape[0]))
            print('  NumValidationSamples: {}'.format(
                F.nonzero_1d(self._g.ndata['val_mask']).shape[0]))
            print('  NumTestSamples: {}'.format(
                F.nonzero_1d(self._g.ndata['test_mask']).shape[0]))

    def __getitem__(self, idx):
        """Return the single graph of the dataset (``idx`` must be 0)."""
        assert idx == 0, "This dataset has only one graph"
        return self._g

    def __len__(self):
        """Number of graphs in the dataset (always 1)."""
        return 1

    @property
    def save_name(self):
        # Base name of the cache files written under ``save_path``.
        return self.name + '_dgl_graph'

    @property
    def num_labels(self):
        # Number of label classes, taken from the one-hot label width.
        return self._num_labels

    """ Citation graph is used in many examples
        We preserve these properties for compatability.
    """
    @property
    def graph(self):
        deprecate_property('dataset.graph', 'dataset.g')
        return self._graph

    @property
    def train_mask(self):
        deprecate_property('dataset.train_mask', 'g.ndata[\'train_mask\']')
        return F.asnumpy(self._g.ndata['train_mask'])

    @property
    def val_mask(self):
        deprecate_property('dataset.val_mask', 'g.ndata[\'val_mask\']')
        return F.asnumpy(self._g.ndata['val_mask'])

    @property
    def test_mask(self):
        deprecate_property('dataset.test_mask', 'g.ndata[\'test_mask\']')
        return F.asnumpy(self._g.ndata['test_mask'])

    @property
    def labels(self):
        deprecate_property('dataset.label', 'g.ndata[\'label\']')
        return F.asnumpy(self._g.ndata['label'])

    @property
    def features(self):
        deprecate_property('dataset.feat', 'g.ndata[\'feat\']')
        return self._g.ndata['feat']

Minjie Wang's avatar
Minjie Wang committed
250
251
def _preprocess_features(features):
    """Row-normalize feature matrix and convert to tuple representation"""
252
    rowsum = np.asarray(features.sum(1))
Minjie Wang's avatar
Minjie Wang committed
253
254
255
256
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    features = r_mat_inv.dot(features)
257
    return np.asarray(features.todense())
Minjie Wang's avatar
Minjie Wang committed
258
259
260
261
262
263
264
265
266
267
268
269
270
271

def _parse_index_file(filename):
    """Parse index file."""
    index = []
    for line in open(filename):
        index.append(int(line.strip()))
    return index

def _sample_mask(idx, l):
    """Create mask."""
    mask = np.zeros(l)
    mask[idx] = 1
    return mask

272
273
274
275
class CoraGraphDataset(CitationGraphDataset):
    r"""Cora citation network dataset.

    Nodes mean paper and edges mean citation relationships. Each node has a
    predefined feature with 1433 dimensions. The dataset is designed for the
    node classification task: predict the category of a certain paper.

    .. deprecated:: 0.5.0
        The ``graph``, ``train_mask``, ``val_mask``, ``test_mask``, ``labels``
        and ``features`` attributes are deprecated. Read the corresponding
        fields from the graph instead::

            >>> dataset = CoraGraphDataset()
            >>> graph = dataset[0]
            >>> train_mask = graph.ndata['train_mask']
            >>> feat = graph.ndata['feat']

    Statistics:

    - Nodes: 2708
    - Edges: 10556
    - Number of Classes: 7
    - Label split:

        - Train: 140
        - Valid: 500
        - Test: 1000

    Parameters
    ----------
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose : bool
        Whether to print out progress information. Default: True.

    Notes
    -----
    The node feature is row-normalized.

    Examples
    --------
    >>> dataset = CoraGraphDataset()
    >>> g = dataset[0]
    >>> num_class = dataset.num_labels
    >>>
    >>> # get node feature
    >>> feat = g.ndata['feat']
    >>>
    >>> # get data split
    >>> train_mask = g.ndata['train_mask']
    >>> val_mask = g.ndata['val_mask']
    >>> test_mask = g.ndata['test_mask']
    >>>
    >>> # get labels
    >>> label = g.ndata['label']
    """
    def __init__(self, raw_dir=None, force_reload=False, verbose=True):
        super(CoraGraphDataset, self).__init__('cora', raw_dir,
                                               force_reload, verbose)

    def __getitem__(self, idx):
        r"""Gets the graph object.

        Parameters
        -----------
        idx: int
            Item index, CoraGraphDataset has only one graph object

        Return
        ------
        :class:`dgl.DGLGraph`
            graph structure, node features and labels.

            - ``ndata['train_mask']``: mask for training node set
            - ``ndata['val_mask']``: mask for validation node set
            - ``ndata['test_mask']``: mask for test node set
            - ``ndata['feat']``: node feature
            - ``ndata['label']``: ground truth labels
        """
        return super(CoraGraphDataset, self).__getitem__(idx)

    def __len__(self):
        r"""The number of graphs in the dataset."""
        return super(CoraGraphDataset, self).__len__()

class CiteseerGraphDataset(CitationGraphDataset):
    r"""Citeseer citation network dataset.

    Nodes mean scientific publications and edges mean citation relationships.
    Each node has a predefined feature with 3703 dimensions. The dataset is
    designed for the node classification task: predict the category of a
    certain publication.

    .. deprecated:: 0.5.0
        The ``graph``, ``train_mask``, ``val_mask``, ``test_mask``, ``labels``
        and ``features`` attributes are deprecated. Read the corresponding
        fields from the graph instead::

            >>> dataset = CiteseerGraphDataset()
            >>> graph = dataset[0]
            >>> train_mask = graph.ndata['train_mask']
            >>> feat = graph.ndata['feat']

    Statistics:

    - Nodes: 3327
    - Edges: 9228
    - Number of Classes: 6
    - Label Split:

        - Train: 120
        - Valid: 500
        - Test: 1000

    Parameters
    -----------
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose: bool
        Whether to print out progress information. Default: True.

    Notes
    -----
    The node feature is row-normalized.

    In citeseer dataset, there are some isolated nodes in the graph.
    These isolated nodes are added as zero-vecs into the right position.

    Examples
    --------
    >>> dataset = CiteseerGraphDataset()
    >>> g = dataset[0]
    >>> num_class = dataset.num_labels
    >>>
    >>> # get node feature
    >>> feat = g.ndata['feat']
    >>>
    >>> # get data split
    >>> train_mask = g.ndata['train_mask']
    >>> val_mask = g.ndata['val_mask']
    >>> test_mask = g.ndata['test_mask']
    >>>
    >>> # get labels
    >>> label = g.ndata['label']
    """
    def __init__(self, raw_dir=None, force_reload=False, verbose=True):
        super(CiteseerGraphDataset, self).__init__('citeseer', raw_dir,
                                                   force_reload, verbose)

    def __getitem__(self, idx):
        r"""Gets the graph object.

        Parameters
        -----------
        idx: int
            Item index, CiteseerGraphDataset has only one graph object

        Return
        ------
        :class:`dgl.DGLGraph`
            graph structure, node features and labels.

            - ``ndata['train_mask']``: mask for training node set
            - ``ndata['val_mask']``: mask for validation node set
            - ``ndata['test_mask']``: mask for test node set
            - ``ndata['feat']``: node feature
            - ``ndata['label']``: ground truth labels
        """
        return super(CiteseerGraphDataset, self).__getitem__(idx)

    def __len__(self):
        r"""The number of graphs in the dataset."""
        return super(CiteseerGraphDataset, self).__len__()

class PubmedGraphDataset(CitationGraphDataset):
    r"""Pubmed citation network dataset.

    Nodes mean scientific publications and edges mean citation relationships.
    Each node has a predefined feature with 500 dimensions. The dataset is
    designed for the node classification task: predict the category of a
    certain publication.

    .. deprecated:: 0.5.0
        The ``graph``, ``train_mask``, ``val_mask``, ``test_mask``, ``labels``
        and ``features`` attributes are deprecated. Read the corresponding
        fields from the graph instead::

            >>> dataset = PubmedGraphDataset()
            >>> graph = dataset[0]
            >>> train_mask = graph.ndata['train_mask']
            >>> feat = graph.ndata['feat']

    Statistics:

    - Nodes: 19717
    - Edges: 88651
    - Number of Classes: 3
    - Label Split:

        - Train: 60
        - Valid: 500
        - Test: 1000

    Parameters
    -----------
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose: bool
        Whether to print out progress information. Default: True.

    Notes
    -----
    The node feature is row-normalized.

    Examples
    --------
    >>> dataset = PubmedGraphDataset()
    >>> g = dataset[0]
    >>> num_class = dataset.num_labels
    >>>
    >>> # get node feature
    >>> feat = g.ndata['feat']
    >>>
    >>> # get data split
    >>> train_mask = g.ndata['train_mask']
    >>> val_mask = g.ndata['val_mask']
    >>> test_mask = g.ndata['test_mask']
    >>>
    >>> # get labels
    >>> label = g.ndata['label']
    """
    def __init__(self, raw_dir=None, force_reload=False, verbose=True):
        super(PubmedGraphDataset, self).__init__('pubmed', raw_dir,
                                                 force_reload, verbose)

    def __getitem__(self, idx):
        r"""Gets the graph object.

        Parameters
        -----------
        idx: int
            Item index, PubmedGraphDataset has only one graph object

        Return
        ------
        :class:`dgl.DGLGraph`
            graph structure, node features and labels.

            - ``ndata['train_mask']``: mask for training node set
            - ``ndata['val_mask']``: mask for validation node set
            - ``ndata['test_mask']``: mask for test node set
            - ``ndata['feat']``: node feature
            - ``ndata['label']``: ground truth labels
        """
        return super(PubmedGraphDataset, self).__getitem__(idx)

    def __len__(self):
        r"""The number of graphs in the dataset."""
        return super(PubmedGraphDataset, self).__len__()

def load_cora(raw_dir=None, force_reload=False, verbose=True):
    """Get CoraGraphDataset

    Parameters
    -----------
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose: bool
        Whether to print out progress information. Default: True.

    Return
    -------
    CoraGraphDataset
    """
    return CoraGraphDataset(raw_dir, force_reload, verbose)

def load_citeseer(raw_dir=None, force_reload=False, verbose=True):
    """Get CiteseerGraphDataset

    Parameters
    -----------
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose: bool
        Whether to print out progress information. Default: True.

    Return
    -------
    CiteseerGraphDataset
    """
    return CiteseerGraphDataset(raw_dir, force_reload, verbose)

def load_pubmed(raw_dir=None, force_reload=False, verbose=True):
    """Get PubmedGraphDataset

    Parameters
    -----------
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose: bool
        Whether to print out progress information. Default: True.

    Return
    -------
    PubmedGraphDataset
    """
    return PubmedGraphDataset(raw_dir, force_reload, verbose)

class CoraBinary(DGLBuiltinDataset):
    """A mini-dataset for binary classification task using Cora.

    After loaded, it has following members:

    graphs : list of :class:`~dgl.DGLGraph`
    pmpds : list of :class:`scipy.sparse.coo_matrix`
    labels : list of :class:`numpy.ndarray`

    Parameters
    -----------
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose: bool
        Whether to print out progress information. Default: True.
    """
    def __init__(self, raw_dir=None, force_reload=False, verbose=True):
        name = 'cora_binary'
        url = _get_dgl_url('dataset/cora_binary.zip')
        super(CoraBinary, self).__init__(name,
                                         url=url,
                                         raw_dir=raw_dir,
                                         force_reload=force_reload,
                                         verbose=verbose)

    def process(self):
        """Parse graphs.txt / pmpds.pkl / labels.txt under ``raw_path``."""
        root = self.raw_path
        # load graphs: records are separated by lines starting with 'graph';
        # every other line is an edge "u v".
        self.graphs = []
        with open("{}/graphs.txt".format(root), 'r') as f:
            elist = []
            for line in f.readlines():
                if line.startswith('graph'):
                    if len(elist) != 0:
                        self.graphs.append(dgl_graph(elist))
                    elist = []
                else:
                    u, v = line.strip().split(' ')
                    elist.append((int(u), int(v)))
            # flush the trailing record (no 'graph' separator after the last one)
            if len(elist) != 0:
                self.graphs.append(dgl_graph(elist))
        with open("{}/pmpds.pkl".format(root), 'rb') as f:
            self.pmpds = _pickle_load(f)
        # load labels with the same record layout as graphs.txt
        self.labels = []
        with open("{}/labels.txt".format(root), 'r') as f:
            cur = []
            for line in f.readlines():
                if line.startswith('graph'):
                    if len(cur) != 0:
                        self.labels.append(np.asarray(cur))
                    cur = []
                else:
                    cur.append(int(line.strip()))
            if len(cur) != 0:
                self.labels.append(np.asarray(cur))
        # sanity check
        assert len(self.graphs) == len(self.pmpds)
        assert len(self.graphs) == len(self.labels)

    def has_cache(self):
        """Return True when the cached graph file exists."""
        graph_path = os.path.join(self.save_path,
                                  self.save_name + '.bin')
        if os.path.exists(graph_path):
            return True

        return False

    def save(self):
        """save the graph list and the labels"""
        graph_path = os.path.join(self.save_path,
                                  self.save_name + '.bin')
        labels = {}
        for i, label in enumerate(self.labels):
            labels['{}'.format(i)] = F.tensor(label)
        save_graphs(str(graph_path), self.graphs, labels)
        if self.verbose:
            print('Done saving data into cached files.')

    def load(self):
        """Load the cached graphs/labels and re-read pmpds from raw_path."""
        graph_path = os.path.join(self.save_path,
                                  self.save_name + '.bin')
        self.graphs, labels = load_graphs(str(graph_path))

        self.labels = []
        for i in range(len(labels)):
            # BUGFIX: was ``range(len(lables))`` (NameError).  Also use the
            # backend-agnostic F.asnumpy instead of the mxnet-only .asnumpy().
            self.labels.append(F.asnumpy(labels['{}'.format(i)]))
        # load pmpds under self.raw_path
        with open("{}/pmpds.pkl".format(self.raw_path), 'rb') as f:
            self.pmpds = _pickle_load(f)
        if self.verbose:
            print('Done loading data into cached files.')
        # sanity check
        assert len(self.graphs) == len(self.pmpds)
        assert len(self.graphs) == len(self.labels)

    def __len__(self):
        """Number of (graph, pmpd, label) samples."""
        return len(self.graphs)

    def __getitem__(self, i):
        r"""Gets the i-th sample.

        Parameters
        -----------
        i : int
            The sample index.

        Returns
        -------
        (dgl.DGLGraph, scipy.sparse.coo_matrix, numpy.ndarray)
            The graph, scipy sparse coo_matrix and its label.
        """
        return (self.graphs[i], self.pmpds[i], self.labels[i])

    @property
    def save_name(self):
        # Base name of the cache file written under ``save_path``.
        return self.name + '_dgl_graph'

    @staticmethod
    def collate_fn(cur):
        """Collate a list of samples into batched graph/pmpd/label tensors."""
        graphs, pmpds, labels = zip(*cur)
        batched_graphs = batch.batch(graphs)
        batched_pmpds = sp.block_diag(pmpds)
        batched_labels = np.concatenate(labels, axis=0)
        return batched_graphs, batched_pmpds, batched_labels
882
883
884

def _normalize(mx):
    """Row-normalize sparse matrix"""
885
    rowsum = np.asarray(mx.sum(1))
886
    r_inv = np.power(rowsum, -1).flatten()
Zhengwei's avatar
Zhengwei committed
887
    r_inv[np.isinf(r_inv)] = 0.
888
889
890
891
892
    r_mat_inv = sp.diags(r_inv)
    mx = r_mat_inv.dot(mx)
    return mx

def _encode_onehot(labels):
893
    classes = list(sorted(set(labels)))
894
895
    classes_dict = {c: np.identity(len(classes))[i, :] for i, c in
                    enumerate(classes)}
896
897
    labels_onehot = np.asarray(list(map(classes_dict.get, labels)),
                               dtype=np.int32)
898
    return labels_onehot