Unverified Commit 967ecb80 authored by Tong He's avatar Tong He Committed by GitHub

[Dataset] Fix the docstring format for dgl.data section (#1941)



* PPIDataset

* Revert "PPIDataset"

This reverts commit 264bd0c960cfa698a7bb946dad132bf52c2d0c8a.

* update data rst

* update data doc and docstring

* API doc rst for dataset

* docstring

* update api doc

* add url format

* update docstring

* update citation graph

* update knowledge graph

* update gc datasets

* fix index

* Rst fix (#3)

* Fix syntax

* syntax

* update docstring

* update doc (#4)

* final update

* fix rdflib

* fix rdf
Co-authored-by: HuXiangkun <huxk_hit@qq.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-51-214.ec2.internal>
Co-authored-by: xiang song(charlie.song) <classicxsong@gmail.com>
parent 3fa8d755
@@ -160,3 +160,4 @@ cscope.*
config.cmake
.ycm_extra_conf.py
**.png
@@ -5,89 +5,125 @@ dgl.data

.. currentmodule:: dgl.data

Dataset Classes
---------------

DGL dataset
```````````

.. autoclass:: DGLDataset
    :members: download, save, load, process, has_cache, __getitem__, __len__

DGL builtin dataset
```````````````````

.. autoclass:: DGLBuiltinDataset
    :members: download

Stanford sentiment treebank dataset
```````````````````````````````````

For more information about the dataset, see `Sentiment Analysis <https://nlp.stanford.edu/sentiment/index.html>`__.

.. autoclass:: SSTDataset
    :members: __getitem__, __len__

Karate club dataset
```````````````````

.. autoclass:: KarateClubDataset
    :members: __getitem__, __len__

Citation network dataset
````````````````````````

.. autoclass:: CoraGraphDataset
    :members: __getitem__, __len__

.. autoclass:: CiteseerGraphDataset
    :members: __getitem__, __len__

.. autoclass:: PubmedGraphDataset
    :members: __getitem__, __len__

Knowledge graph dataset
```````````````````````

.. autoclass:: FB15k237Dataset
    :members: __getitem__, __len__

.. autoclass:: FB15kDataset
    :members: __getitem__, __len__

.. autoclass:: WN18Dataset
    :members: __getitem__, __len__

RDF datasets
````````````

.. autoclass:: AIFBDataset
    :members: __getitem__, __len__

.. autoclass:: MUTAGDataset
    :members: __getitem__, __len__

.. autoclass:: BGSDataset
    :members: __getitem__, __len__

.. autoclass:: AMDataset
    :members: __getitem__, __len__

CoraFull dataset
````````````````

.. autoclass:: CoraFullDataset
    :members: __getitem__, __len__

Amazon Co-Purchase dataset
``````````````````````````

.. autoclass:: AmazonCoBuyComputerDataset
    :members: __getitem__, __len__

.. autoclass:: AmazonCoBuyPhotoDataset
    :members: __getitem__, __len__

Coauthor dataset
````````````````

.. autoclass:: CoauthorCSDataset
    :members: __getitem__, __len__

.. autoclass:: CoauthorPhysicsDataset
    :members: __getitem__, __len__

BitcoinOTC dataset
``````````````````

.. autoclass:: BitcoinOTCDataset
    :members: __getitem__, __len__

ICEWS18 dataset
```````````````

.. autoclass:: ICEWS18Dataset
    :members: __getitem__, __len__

QM7b dataset
````````````

.. autoclass:: QM7bDataset
    :members: __getitem__, __len__

@@ -95,7 +131,7 @@ QM7b dataset

GDELT dataset
`````````````

.. autoclass:: GDELTDataset
    :members: __getitem__, __len__

@@ -103,17 +139,16 @@ Mini graph classification dataset
`````````````````````````````````

.. autoclass:: MiniGCDataset
    :members: __getitem__, __len__

TU dataset
``````````

.. autoclass:: TUDataset
    :members: __getitem__, __len__

.. autoclass:: LegacyTUDataset
    :members: __getitem__, __len__

Graph isomorphism network dataset
`````````````````````````````````

@@ -129,3 +164,36 @@ Protein-Protein Interaction dataset

.. autoclass:: PPIDataset
    :members: __getitem__, __len__

Reddit dataset
``````````````

.. autoclass:: RedditDataset
    :members: __getitem__, __len__

Symmetric Stochastic Block Model Mixture dataset
````````````````````````````````````````````````

.. autoclass:: SBMMixtureDataset
    :members: __getitem__, __len__, collate_fn

Utils
-----

.. autosummary::
    :toctree: ../../generated/

    utils.get_download_dir
    utils.download
    utils.check_sha1
    utils.extract_archive
    utils.split_dataset
    utils.save_graphs
    utils.load_graphs
    utils.load_labels

.. autoclass:: dgl.data.utils.Subset
    :members: __getitem__, __len__
@@ -20,6 +20,8 @@ from .icews18 import ICEWS18, ICEWS18Dataset
from .qm7b import QM7b, QM7bDataset
from .dgl_dataset import DGLDataset, DGLBuiltinDataset
from .citation_graph import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
from .knowledge_graph import FB15k237Dataset, FB15kDataset, WN18Dataset
from .rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset

def register_data_args(parser):
......
@@ -18,13 +18,15 @@ class BitcoinOTCDataset(DGLBuiltinDataset):
    a platform called Bitcoin OTC. Since Bitcoin users are anonymous,
    there is a need to maintain a record of users' reputation to prevent
    transactions with fraudulent and risky users.

    Official website: `<https://snap.stanford.edu/data/soc-sign-bitcoin-otc.html>`_

    Bitcoin OTC dataset statistics:

    - Nodes: 5,881
    - Edges: 35,592
    - Range of edge weight: -10 to +10
    - Percentage of positive edges: 89%

    Parameters
    ----------
@@ -117,7 +119,12 @@ class BitcoinOTCDataset(DGLBuiltinDataset):
        return self._graphs

    def __len__(self):
        r"""Number of graphs in the dataset.

        Returns
        -------
        int
        """
        return len(self.graphs)

    def __getitem__(self, item):
@@ -130,9 +137,11 @@ class BitcoinOTCDataset(DGLBuiltinDataset):
        Returns
        -------
        :class:`dgl.DGLGraph`
            The graph contains:

            - ``edata['h']``: edge weights
        """
        return self.graphs[item]
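The numpydoc convention applied throughout this PR can be sketched with a minimal, hypothetical dataset class (not part of DGL) whose `__len__`/`__getitem__` carry the fixed docstring shape:

```python
class ToyGraphDataset:
    """Hypothetical in-memory dataset illustrating the numpydoc
    docstring format this PR applies (section header, dashed underline,
    type on one line, description indented below)."""

    def __init__(self, graphs):
        self._graphs = graphs

    def __len__(self):
        r"""Number of graphs in the dataset.

        Returns
        -------
        int
        """
        return len(self._graphs)

    def __getitem__(self, item):
        r"""Get the item-th graph.

        Parameters
        ----------
        item : int
            The graph index.

        Returns
        -------
        object
            The stored graph object.
        """
        return self._graphs[item]
```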
......
@@ -273,26 +273,38 @@ class CoraGraphDataset(CitationGraphDataset):
    r"""Cora citation network dataset.

    .. deprecated:: 0.5.0

        - ``graph`` is deprecated, it is replaced by:

            >>> dataset = CoraGraphDataset()
            >>> graph = dataset[0]

        - ``train_mask`` is deprecated, it is replaced by:

            >>> dataset = CoraGraphDataset()
            >>> graph = dataset[0]
            >>> train_mask = graph.ndata['train_mask']

        - ``val_mask`` is deprecated, it is replaced by:

            >>> dataset = CoraGraphDataset()
            >>> graph = dataset[0]
            >>> val_mask = graph.ndata['val_mask']

        - ``test_mask`` is deprecated, it is replaced by:

            >>> dataset = CoraGraphDataset()
            >>> graph = dataset[0]
            >>> test_mask = graph.ndata['test_mask']

        - ``labels`` is deprecated, it is replaced by:

            >>> dataset = CoraGraphDataset()
            >>> graph = dataset[0]
            >>> labels = graph.ndata['label']

        - ``feat`` is deprecated, it is replaced by:

            >>> dataset = CoraGraphDataset()
            >>> graph = dataset[0]
            >>> feat = graph.ndata['feat']
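Such migrations (keeping `train_mask` and friends working while steering users toward `graph.ndata`) are typically wired up as deprecated properties. A plain-Python sketch with hypothetical names, not the actual DGL code:

```python
import warnings

class _FakeGraph:
    """Stand-in for a DGLGraph: only carries an ndata dict."""
    def __init__(self):
        self.ndata = {'train_mask': [True, True, False]}

class LegacyStyleDataset:
    """Hypothetical dataset showing how a deprecated attribute can
    forward to the new graph.ndata access path while warning."""

    def __getitem__(self, idx):
        return _FakeGraph()

    @property
    def train_mask(self):
        # Warn once per call, then delegate to the new access path.
        warnings.warn(
            "train_mask is deprecated; use dataset[0].ndata['train_mask']",
            DeprecationWarning)
        return self[0].ndata['train_mask']
```

Old code keeps working through the property, while new code reads `dataset[0].ndata['train_mask']` directly.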
@@ -304,12 +316,16 @@ class CoraGraphDataset(CitationGraphDataset):
    The task is to predict the category of
    a certain paper.

    Statistics:

    - Nodes: 2708
    - Edges: 10556
    - Number of Classes: 7
    - Label split:

      - Train: 140
      - Valid: 500
      - Test: 1000

    Parameters
    ----------
@@ -319,7 +335,7 @@ class CoraGraphDataset(CitationGraphDataset):
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose : bool
        Whether to print out progress information. Default: True.

    Attributes
    ----------
@@ -327,13 +343,13 @@ class CoraGraphDataset(CitationGraphDataset):
        Number of label classes
    graph : networkx.DiGraph
        Graph structure
    train_mask : numpy.ndarray
        Mask of training nodes
    val_mask : numpy.ndarray
        Mask of validation nodes
    test_mask : numpy.ndarray
        Mask of test nodes
    labels : numpy.ndarray
        Ground truth labels of each node
    features : Tensor
        Node features
@@ -377,13 +393,15 @@ class CoraGraphDataset(CitationGraphDataset):
        Returns
        -------
        :class:`dgl.DGLGraph`
            graph structure, node features and labels.

            - ``ndata['train_mask']``: mask for training node set
            - ``ndata['val_mask']``: mask for validation node set
            - ``ndata['test_mask']``: mask for test node set
            - ``ndata['feat']``: node feature
            - ``ndata['label']``: ground truth labels
        """
        return super(CoraGraphDataset, self).__getitem__(idx)
@@ -395,26 +413,38 @@ class CiteseerGraphDataset(CitationGraphDataset):
    r"""Citeseer citation network dataset.

    .. deprecated:: 0.5.0

        ``graph`` is deprecated, it is replaced by:

            >>> dataset = CiteseerGraphDataset()
            >>> graph = dataset[0]

        ``train_mask`` is deprecated, it is replaced by:

            >>> dataset = CiteseerGraphDataset()
            >>> graph = dataset[0]
            >>> train_mask = graph.ndata['train_mask']

        ``val_mask`` is deprecated, it is replaced by:

            >>> dataset = CiteseerGraphDataset()
            >>> graph = dataset[0]
            >>> val_mask = graph.ndata['val_mask']

        ``test_mask`` is deprecated, it is replaced by:

            >>> dataset = CiteseerGraphDataset()
            >>> graph = dataset[0]
            >>> test_mask = graph.ndata['test_mask']

        ``labels`` is deprecated, it is replaced by:

            >>> dataset = CiteseerGraphDataset()
            >>> graph = dataset[0]
            >>> labels = graph.ndata['label']

        ``feat`` is deprecated, it is replaced by:

            >>> dataset = CiteseerGraphDataset()
            >>> graph = dataset[0]
            >>> feat = graph.ndata['feat']

@@ -426,12 +456,16 @@ class CiteseerGraphDataset(CitationGraphDataset):
    task. The task is to predict the category of
    a certain publication.

    Statistics:

    - Nodes: 3327
    - Edges: 9228
    - Number of Classes: 6
    - Label Split:

      - Train: 120
      - Valid: 500
      - Test: 1000

    Parameters
    ----------
@@ -441,7 +475,7 @@ class CiteseerGraphDataset(CitationGraphDataset):
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose : bool
        Whether to print out progress information. Default: True.

    Attributes
    ----------
@@ -449,13 +483,13 @@ class CiteseerGraphDataset(CitationGraphDataset):
        Number of label classes
    graph : networkx.DiGraph
        Graph structure
    train_mask : numpy.ndarray
        Mask of training nodes
    val_mask : numpy.ndarray
        Mask of validation nodes
    test_mask : numpy.ndarray
        Mask of test nodes
    labels : numpy.ndarray
        Ground truth labels of each node
    features : Tensor
        Node features
@@ -502,13 +536,15 @@ class CiteseerGraphDataset(CitationGraphDataset):
        Returns
        -------
        :class:`dgl.DGLGraph`
            graph structure, node features and labels.

            - ``ndata['train_mask']``: mask for training node set
            - ``ndata['val_mask']``: mask for validation node set
            - ``ndata['test_mask']``: mask for test node set
            - ``ndata['feat']``: node feature
            - ``ndata['label']``: ground truth labels
        """
        return super(CiteseerGraphDataset, self).__getitem__(idx)
@@ -520,26 +556,38 @@ class PubmedGraphDataset(CitationGraphDataset):
    r"""Pubmed citation network dataset.

    .. deprecated:: 0.5.0

        ``graph`` is deprecated, it is replaced by:

            >>> dataset = PubmedGraphDataset()
            >>> graph = dataset[0]

        ``train_mask`` is deprecated, it is replaced by:

            >>> dataset = PubmedGraphDataset()
            >>> graph = dataset[0]
            >>> train_mask = graph.ndata['train_mask']

        ``val_mask`` is deprecated, it is replaced by:

            >>> dataset = PubmedGraphDataset()
            >>> graph = dataset[0]
            >>> val_mask = graph.ndata['val_mask']

        ``test_mask`` is deprecated, it is replaced by:

            >>> dataset = PubmedGraphDataset()
            >>> graph = dataset[0]
            >>> test_mask = graph.ndata['test_mask']

        ``labels`` is deprecated, it is replaced by:

            >>> dataset = PubmedGraphDataset()
            >>> graph = dataset[0]
            >>> labels = graph.ndata['label']

        ``feat`` is deprecated, it is replaced by:

            >>> dataset = PubmedGraphDataset()
            >>> graph = dataset[0]
            >>> feat = graph.ndata['feat']

@@ -551,12 +599,16 @@ class PubmedGraphDataset(CitationGraphDataset):
    task. The task is to predict the category of
    a certain publication.

    Statistics:

    - Nodes: 19717
    - Edges: 88651
    - Number of Classes: 3
    - Label Split:

      - Train: 60
      - Valid: 500
      - Test: 1000

    Parameters
    ----------
@@ -566,7 +618,7 @@ class PubmedGraphDataset(CitationGraphDataset):
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose : bool
        Whether to print out progress information. Default: True.

    Attributes
    ----------
@@ -574,13 +626,13 @@ class PubmedGraphDataset(CitationGraphDataset):
        Number of label classes
    graph : networkx.DiGraph
        Graph structure
    train_mask : numpy.ndarray
        Mask of training nodes
    val_mask : numpy.ndarray
        Mask of validation nodes
    test_mask : numpy.ndarray
        Mask of test nodes
    labels : numpy.ndarray
        Ground truth labels of each node
    features : Tensor
        Node features
@@ -624,13 +676,15 @@ class PubmedGraphDataset(CitationGraphDataset):
        Returns
        -------
        :class:`dgl.DGLGraph`
            graph structure, node features and labels.

            - ``ndata['train_mask']``: mask for training node set
            - ``ndata['val_mask']``: mask for validation node set
            - ``ndata['test_mask']``: mask for test node set
            - ``ndata['feat']``: node feature
            - ``ndata['label']``: ground truth labels
        """
        return super(PubmedGraphDataset, self).__getitem__(idx)
......
@@ -14,13 +14,13 @@ class DGLDataset(object):
    The following steps are executed automatically:

    1. Check whether there is a dataset cache on disk
       (already processed and stored on the disk) by
       invoking ``has_cache()``. If true, goto 5.
    2. Call ``download()`` to download the data.
    3. Call ``process()`` to process the data.
    4. Call ``save()`` to save the processed dataset on disk and goto 6.
    5. Call ``load()`` to load the processed dataset from disk.
    6. Done.

    Users can overwrite these functions with their
    own data processing logic.
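The numbered steps form a template-method pipeline; a plain-Python schematic (placeholder hooks, not the actual DGL implementation):

```python
class DatasetPipeline:
    """Schematic of the DGLDataset loading flow: reuse the cache when
    one exists, otherwise download, process, and save."""

    def __init__(self, force_reload=False):
        self.calls = []  # record which hooks ran, for illustration
        if self.has_cache() and not force_reload:
            self.load()        # step 5: cache hit, load from disk
        else:
            self.download()    # step 2: fetch raw data
            self.process()     # step 3: build graphs/features
            self.save()        # step 4: cache the processed result

    # Placeholder hooks; a real subclass overrides these.
    def has_cache(self):
        self.calls.append('has_cache')
        return False

    def download(self):
        self.calls.append('download')

    def process(self):
        self.calls.append('process')

    def save(self):
        self.calls.append('save')

    def load(self):
        self.calls.append('load')
```

On a first run (no cache), the hooks fire in the order `has_cache`, `download`, `process`, `save`; on later runs `has_cache` succeeds and only `load` runs.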
@@ -43,11 +43,31 @@ class DGLDataset(object):
        A tuple of values as the input for the hash function.
        Users can distinguish instances (and their caches on the disk)
        from the same dataset class by comparing the hash values.
        Default: (), the corresponding hash value is ``'f9065fa7'``.
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose : bool
        Whether to print out progress information

    Attributes
    ----------
    url : str
        The URL to download the dataset
    name : str
        The dataset name
    raw_dir : str
        Raw file directory containing the input data folder
    raw_path : str
        Directory containing the input data files.
        Default: ``os.path.join(self.raw_dir, self.name)``
    save_dir : str
        Directory to save the processed dataset
    save_path : str
        File path to save the processed dataset
    verbose : bool
        Whether to print information
    hash : str
        Hash value for the dataset and the setting.
    """
    def __init__(self, name, url=None, raw_dir=None, save_dir=None,
                 hash_key=(), force_reload=False, verbose=False):
@@ -75,7 +95,7 @@ class DGLDataset(object):
        It is recommended to download to the :obj:`self.raw_dir`
        folder. Can be ignored if the dataset is
        already in :obj:`self.raw_dir`.
        """
        pass
@@ -83,9 +103,9 @@ class DGLDataset(object):
        r"""Overwrite to realize your own logic of
        saving the processed dataset into files.

        It is recommended to use ``dgl.data.utils.save_graphs``
        to save dgl graphs into files and use
        ``dgl.data.utils.save_info`` to save extra
        information into files.
        """
        pass
@@ -94,9 +114,9 @@ class DGLDataset(object):
        r"""Overwrite to realize your own logic of
        loading the saved dataset from files.

        It is recommended to use ``dgl.data.utils.load_graphs``
        to load dgl graphs from files and use
        ``dgl.data.utils.load_info`` to load extra information
        into a Python dict object.
        """
        pass
@@ -116,9 +136,9 @@ class DGLDataset(object):
    @retry_method_with_fix(download)
    def _download(self):
        r"""Download the dataset by calling ``self.download()`` if it does not exist under ``self.raw_path``.

        By default ``self.raw_path = os.path.join(self.raw_dir, self.name)``.
        One can overwrite ``raw_path()`` to change the path.
        """
        if os.path.exists(self.raw_path):  # pragma: no cover
            return
@@ -185,13 +205,13 @@ class DGLDataset(object):
    @property
    def raw_dir(self):
        r"""Raw file directory containing the input data folder.
        """
        return self._raw_dir

    @property
    def raw_path(self):
        r"""Directory containing the input data files.

        By default ``raw_path = os.path.join(self.raw_dir, self.name)``.
        """
        return os.path.join(self.raw_dir, self.name)
@@ -216,7 +236,7 @@ class DGLDataset(object):
    @property
    def hash(self):
        r"""Hash value for the dataset and the setting.
        """
        return self._hash
@@ -235,6 +255,7 @@ class DGLBuiltinDataset(DGLDataset):
    r"""The Basic DGL Builtin Dataset.

    Parameters
    ----------
    name : str
        Name of the dataset.
    url : str
......
@@ -18,16 +18,15 @@ class GDELTDataset(DGLBuiltinDataset):
    (15 minutes time granularity).

    Reference:

    - `Recurrent Event Network for Reasoning over Temporal Knowledge Graphs <https://arxiv.org/abs/1904.05530>`_
    - `The Global Database of Events, Language, and Tone (GDELT) <https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/28075>`_

    Statistics:

    - Train examples: 2,304
    - Valid examples: 288
    - Test examples: 384

    Parameters
    ----------
@@ -135,9 +134,11 @@ class GDELTDataset(DGLBuiltinDataset):
        Returns
        -------
        :class:`dgl.DGLGraph`
            The graph contains:

            - ``edata['rel_type']``: edge type
        """
        if t >= len(self) or t < 0:
            raise IndexError("Index out of range")
@@ -150,7 +151,12 @@ class GDELTDataset(DGLBuiltinDataset):
        return g

    def __len__(self):
        r"""Number of graphs in the dataset.

        Returns
        -------
        int
        """
        return self._end_time - self._start_time + 1

    @property
......
...@@ -17,21 +17,19 @@ from ..convert import graph as dgl_graph ...@@ -17,21 +17,19 @@ from ..convert import graph as dgl_graph
class GINDataset(DGLBuiltinDataset): class GINDataset(DGLBuiltinDataset):
"""Datasets for Graph Isomorphism Network (GIN) """Datasets for Graph Isomorphism Network (GIN)
Adapted from https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip. Adapted from `<https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip>`_.
The dataset contains the compact format of popular graph kernel datasets, which includes: The dataset contains the compact format of popular graph kernel datasets.
MUTAG, COLLAB, IMDBBINARY, IMDBMULTI, NCI1, PROTEINS, PTC, REDDITBINARY, REDDITMULTI5K For more graph kernel datasets, see :class:`TUDataset`.
This datset class processes all data sets listed above. For more graph kernel datasets,
see :class:`TUDataset`.
Paramters Parameters
--------- ---------
name: str name: str
dataset name, one of below - dataset name, one of
('MUTAG', 'COLLAB', \ (``'MUTAG'``, ``'COLLAB'``, \
'IMDBBINARY', 'IMDBMULTI', \ ``'IMDBBINARY'``, ``'IMDBMULTI'``, \
'NCI1', 'PROTEINS', 'PTC', \ ``'NCI1'``, ``'PROTEINS'``, ``'PTC'``, \
'REDDITBINARY', 'REDDITMULTI5K') ``'REDDITBINARY'``, ``'REDDITMULTI5K'``)
self_loop: bool self_loop: bool
add self to self edge if true add self to self edge if true
degree_as_nlabel: bool degree_as_nlabel: bool
...@@ -41,7 +39,7 @@ class GINDataset(DGLBuiltinDataset): ...@@ -41,7 +39,7 @@ class GINDataset(DGLBuiltinDataset):
-------- --------
>>> data = GINDataset(name='MUTAG', self_loop=False) >>> data = GINDataset(name='MUTAG', self_loop=False)
**The dataset instance is an iterable** The dataset instance is an iterable
>>> len(data) >>> len(data)
188 188
...@@ -53,7 +51,7 @@ class GINDataset(DGLBuiltinDataset): ...@@ -53,7 +51,7 @@ class GINDataset(DGLBuiltinDataset):
>>> label >>> label
tensor(1) tensor(1)
**Batch the graphs and labels for mini-batch training** Batch the graphs and labels for mini-batch training
>>> graphs, labels = zip(*[data[i] for i in range(16)]) >>> graphs, labels = zip(*[data[i] for i in range(16)])
>>> batched_graphs = dgl.batch(graphs) >>> batched_graphs = dgl.batch(graphs)
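The ``zip(*...)`` idiom above transposes a list of ``(graph, label)`` pairs into parallel tuples before batching; a pure-Python sketch with placeholder values standing in for real ``DGLGraph`` objects and tensor labels:

```python
# Placeholder samples: each item mimics dataset[i] returning (graph, label).
samples = [("g0", 0), ("g1", 1), ("g2", 0)]

# zip(*pairs) transposes the pairs into two parallel tuples, ready for
# dgl.batch(graphs) and label stacking in a real training loop.
graphs, labels = zip(*samples)
```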
...@@ -118,14 +116,14 @@ class GINDataset(DGLBuiltinDataset): ...@@ -118,14 +116,14 @@ class GINDataset(DGLBuiltinDataset):
def __getitem__(self, idx): def __getitem__(self, idx):
"""Get the idx-th sample. """Get the idx-th sample.
Paramters Parameters
--------- ---------
idx : int idx : int
The sample index. The sample index.
Returns Returns
------- -------
(dgl.Graph, int) (:class:`dgl.Graph`, Tensor)
The graph and its label. The graph and its label.
""" """
return self.graphs[idx], self.labels[idx] return self.graphs[idx], self.labels[idx]
......
...@@ -118,10 +118,12 @@ class GNNBenchmarkDataset(DGLBuiltinDataset): ...@@ -118,10 +118,12 @@ class GNNBenchmarkDataset(DGLBuiltinDataset):
Returns Returns
------- -------
dgl.DGLGraph :class:`dgl.DGLGraph`
graph structure, node features and node labels
- ndata['feat']: node features The graph contains:
- ndata['label']: node labels
- ``ndata['feat']``: node features
- ``ndata['label']``: node labels
""" """
assert idx == 0, "This dataset has only one graph" assert idx == 0, "This dataset has only one graph"
return self._graph return self._graph
...@@ -135,7 +137,9 @@ class CoraFullDataset(GNNBenchmarkDataset): ...@@ -135,7 +137,9 @@ class CoraFullDataset(GNNBenchmarkDataset):
r"""CORA-Full dataset for node classification task. r"""CORA-Full dataset for node classification task.
.. deprecated:: 0.5.0 .. deprecated:: 0.5.0
`data` is deprecated, it is repalced by:
- ``data`` is deprecated, it is replaced by:
>>> dataset = CoraFullDataset() >>> dataset = CoraFullDataset()
>>> graph = dataset[0] >>> graph = dataset[0]
...@@ -143,14 +147,14 @@ class CoraFullDataset(GNNBenchmarkDataset): ...@@ -143,14 +147,14 @@ class CoraFullDataset(GNNBenchmarkDataset):
Unsupervised Inductive Learning via Ranking`. Unsupervised Inductive Learning via Ranking`.
Nodes represent paper and edges represent citations. Nodes represent paper and edges represent citations.
Reference: https://github.com/shchur/gnn-benchmark#datasets Reference: `<https://github.com/shchur/gnn-benchmark#datasets>`_
Statistics Statistics:
----------
Nodes: 19,793 - Nodes: 19,793
Edges: 130,622 - Edges: 130,622
Number of Classes: 70 - Number of Classes: 70
Node feature size: 8,710 - Node feature size: 8,710
Parameters Parameters
---------- ----------
...@@ -185,7 +189,12 @@ class CoraFullDataset(GNNBenchmarkDataset): ...@@ -185,7 +189,12 @@ class CoraFullDataset(GNNBenchmarkDataset):
@property @property
def num_classes(self): def num_classes(self):
"""Number of classes.""" """Number of classes.
Returns
-------
int
"""
return 70 return 70
...@@ -193,7 +202,9 @@ class CoauthorCSDataset(GNNBenchmarkDataset): ...@@ -193,7 +202,9 @@ class CoauthorCSDataset(GNNBenchmarkDataset):
r""" 'Computer Science (CS)' part of the Coauthor dataset for node classification task. r""" 'Computer Science (CS)' part of the Coauthor dataset for node classification task.
.. deprecated:: 0.5.0 .. deprecated:: 0.5.0
`data` is deprecated, it is repalced by:
- ``data`` is deprecated, it is replaced by:
>>> dataset = CoauthorCSDataset() >>> dataset = CoauthorCSDataset()
>>> graph = dataset[0] >>> graph = dataset[0]
...@@ -202,14 +213,14 @@ class CoauthorCSDataset(GNNBenchmarkDataset): ...@@ -202,14 +213,14 @@ class CoauthorCSDataset(GNNBenchmarkDataset):
co-authored a paper; node features represent paper keywords for each author’s papers, and class co-authored a paper; node features represent paper keywords for each author’s papers, and class
labels indicate most active fields of study for each author. labels indicate most active fields of study for each author.
Reference: https://github.com/shchur/gnn-benchmark#datasets Reference: `<https://github.com/shchur/gnn-benchmark#datasets>`_
Statistics Statistics:
----------
Nodes: 18,333 - Nodes: 18,333
Edges: 327,576 - Edges: 327,576
Number of classes: 15 - Number of classes: 15
Node feature size: 6,805 - Node feature size: 6,805
Parameters Parameters
---------- ----------
...@@ -244,7 +255,12 @@ class CoauthorCSDataset(GNNBenchmarkDataset): ...@@ -244,7 +255,12 @@ class CoauthorCSDataset(GNNBenchmarkDataset):
@property @property
def num_classes(self): def num_classes(self):
"""Number of classes.""" """Number of classes.
Returns
-------
int
"""
return 15 return 15
...@@ -252,7 +268,9 @@ class CoauthorPhysicsDataset(GNNBenchmarkDataset): ...@@ -252,7 +268,9 @@ class CoauthorPhysicsDataset(GNNBenchmarkDataset):
r""" 'Physics' part of the Coauthor dataset for node classification task. r""" 'Physics' part of the Coauthor dataset for node classification task.
.. deprecated:: 0.5.0 .. deprecated:: 0.5.0
`data` is deprecated, it is repalced by:
- ``data`` is deprecated, it is replaced by:
>>> dataset = CoauthorPhysicsDataset() >>> dataset = CoauthorPhysicsDataset()
>>> graph = dataset[0] >>> graph = dataset[0]
...@@ -261,14 +279,14 @@ class CoauthorPhysicsDataset(GNNBenchmarkDataset): ...@@ -261,14 +279,14 @@ class CoauthorPhysicsDataset(GNNBenchmarkDataset):
co-authored a paper; node features represent paper keywords for each author’s papers, and class co-authored a paper; node features represent paper keywords for each author’s papers, and class
labels indicate most active fields of study for each author. labels indicate most active fields of study for each author.
Reference: https://github.com/shchur/gnn-benchmark#datasets Reference: `<https://github.com/shchur/gnn-benchmark#datasets>`_
Statistics Statistics:
----------
Nodes: 34,493 - Nodes: 34,493
Edges: 991,848 - Edges: 991,848
Number of classes: 5 - Number of classes: 5
Node feature size: 8,415 - Node feature size: 8,415
Parameters Parameters
---------- ----------
...@@ -303,7 +321,12 @@ class CoauthorPhysicsDataset(GNNBenchmarkDataset): ...@@ -303,7 +321,12 @@ class CoauthorPhysicsDataset(GNNBenchmarkDataset):
@property @property
def num_classes(self): def num_classes(self):
"""Number of classes.""" """Number of classes.
Returns
-------
int
"""
return 5 return 5
...@@ -311,7 +334,9 @@ class AmazonCoBuyComputerDataset(GNNBenchmarkDataset): ...@@ -311,7 +334,9 @@ class AmazonCoBuyComputerDataset(GNNBenchmarkDataset):
r""" 'Computer' part of the AmazonCoBuy dataset for node classification task. r""" 'Computer' part of the AmazonCoBuy dataset for node classification task.
.. deprecated:: 0.5.0 .. deprecated:: 0.5.0
`data` is deprecated, it is repalced by:
- ``data`` is deprecated, it is replaced by:
>>> dataset = AmazonCoBuyComputerDataset() >>> dataset = AmazonCoBuyComputerDataset()
>>> graph = dataset[0] >>> graph = dataset[0]
...@@ -319,14 +344,14 @@ class AmazonCoBuyComputerDataset(GNNBenchmarkDataset): ...@@ -319,14 +344,14 @@ class AmazonCoBuyComputerDataset(GNNBenchmarkDataset):
where nodes represent goods, edges indicate that two goods are frequently bought together, node where nodes represent goods, edges indicate that two goods are frequently bought together, node
features are bag-of-words encoded product reviews, and class labels are given by the product category. features are bag-of-words encoded product reviews, and class labels are given by the product category.
Reference: https://github.com/shchur/gnn-benchmark#datasets Reference: `<https://github.com/shchur/gnn-benchmark#datasets>`_
Statistics Statistics:
----------
Nodes: 13,752 - Nodes: 13,752
Edges: 574,418 - Edges: 574,418
Number of classes: 5 - Number of classes: 5
Node feature size: 767 - Node feature size: 767
Parameters Parameters
---------- ----------
...@@ -361,7 +386,12 @@ class AmazonCoBuyComputerDataset(GNNBenchmarkDataset): ...@@ -361,7 +386,12 @@ class AmazonCoBuyComputerDataset(GNNBenchmarkDataset):
@property @property
def num_classes(self): def num_classes(self):
"""Number of classes.""" """Number of classes.
Returns
-------
int
"""
return 5 return 5
...@@ -369,7 +399,9 @@ class AmazonCoBuyPhotoDataset(GNNBenchmarkDataset): ...@@ -369,7 +399,9 @@ class AmazonCoBuyPhotoDataset(GNNBenchmarkDataset):
r"""AmazonCoBuy dataset for node classification task. r"""AmazonCoBuy dataset for node classification task.
.. deprecated:: 0.5.0 .. deprecated:: 0.5.0
`data` is deprecated, it is repalced by:
- ``data`` is deprecated, it is replaced by:
>>> dataset = AmazonCoBuyPhotoDataset() >>> dataset = AmazonCoBuyPhotoDataset()
>>> graph = dataset[0] >>> graph = dataset[0]
...@@ -377,14 +409,14 @@ class AmazonCoBuyPhotoDataset(GNNBenchmarkDataset): ...@@ -377,14 +409,14 @@ class AmazonCoBuyPhotoDataset(GNNBenchmarkDataset):
where nodes represent goods, edges indicate that two goods are frequently bought together, node where nodes represent goods, edges indicate that two goods are frequently bought together, node
features are bag-of-words encoded product reviews, and class labels are given by the product category. features are bag-of-words encoded product reviews, and class labels are given by the product category.
Reference: https://github.com/shchur/gnn-benchmark#datasets Reference: `<https://github.com/shchur/gnn-benchmark#datasets>`_
Statistics Statistics:
----------
Nodes: 7,650 - Nodes: 7,650
Edges: 287,326 - Edges: 287,326
Number of classes: 5 - Number of classes: 5
Node feature size: 745 - Node feature size: 745
Parameters Parameters
---------- ----------
...@@ -419,7 +451,12 @@ class AmazonCoBuyPhotoDataset(GNNBenchmarkDataset): ...@@ -419,7 +451,12 @@ class AmazonCoBuyPhotoDataset(GNNBenchmarkDataset):
@property @property
def num_classes(self): def num_classes(self):
"""Number of classes.""" """Number of classes.
Returns
-------
int
"""
return 5 return 5
......
...@@ -12,22 +12,23 @@ class ICEWS18Dataset(DGLBuiltinDataset): ...@@ -12,22 +12,23 @@ class ICEWS18Dataset(DGLBuiltinDataset):
r""" ICEWS18 dataset for temporal graph r""" ICEWS18 dataset for temporal graph
Integrated Crisis Early Warning System (ICEWS18) Integrated Crisis Early Warning System (ICEWS18)
Event data consists of coded interactions between socio-political Event data consists of coded interactions between socio-political
actors (i.e., cooperative or hostile actions between individuals, actors (i.e., cooperative or hostile actions between individuals,
groups, sectors and nation states). This Dataset consists of events groups, sectors and nation states). This dataset consists of events
from 1/1/2018 to 10/31/2018 (24 hours time granularity). from 1/1/2018 to 10/31/2018 (24-hour time granularity).
Reference: Reference:
- `Recurrent Event Network for Reasoning over Temporal
Knowledge Graphs <https://arxiv.org/abs/1904.05530>`_
- `ICEWS Coded Event Data
<https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/28075>`_
Statistics - `Recurrent Event Network for Reasoning over Temporal Knowledge Graphs <https://arxiv.org/abs/1904.05530>`_
---------- - `ICEWS Coded Event Data <https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/28075>`_
Train examples: 240
Valid examples: 30 Statistics:
Test examples: 34
Nodes per graph: 23033 - Train examples: 240
- Valid examples: 30
- Test examples: 34
- Nodes per graph: 23033
Parameters Parameters
---------- ----------
...@@ -111,14 +112,21 @@ class ICEWS18Dataset(DGLBuiltinDataset): ...@@ -111,14 +112,21 @@ class ICEWS18Dataset(DGLBuiltinDataset):
Returns Returns
------- -------
dgl.DGLGraph :class:`dgl.DGLGraph`
graph structure and edge feature
- edata['rel_type']: edge type The graph contains:
- ``edata['rel_type']``: edge type
""" """
return self._graphs[idx] return self._graphs[idx]
def __len__(self): def __len__(self):
r"""Number of graphs in the dataset""" r"""Number of graphs in the dataset.
Returns
-------
int
"""
return len(self._graphs) return len(self._graphs)
@property @property
......
...@@ -15,7 +15,9 @@ class KarateClubDataset(DGLDataset): ...@@ -15,7 +15,9 @@ class KarateClubDataset(DGLDataset):
r""" Karate Club dataset for Node Classification r""" Karate Club dataset for Node Classification
.. deprecated:: 0.5.0 .. deprecated:: 0.5.0
`data` is deprecated, it is replaced by:
``data`` is deprecated, it is replaced by:
>>> dataset = KarateClubDataset() >>> dataset = KarateClubDataset()
>>> g = dataset[0] >>> g = dataset[0]
...@@ -24,19 +26,20 @@ class KarateClubDataset(DGLDataset): ...@@ -24,19 +26,20 @@ class KarateClubDataset(DGLDataset):
Model for Conflict and Fission in Small Groups" by Wayne W. Zachary. Model for Conflict and Fission in Small Groups" by Wayne W. Zachary.
The network became a popular example of community structure in The network became a popular example of community structure in
networks after its use by Michelle Girvan and Mark Newman in 2002. networks after its use by Michelle Girvan and Mark Newman in 2002.
Official website: http://konect.cc/networks/ucidata-zachary/ Official website: `<http://konect.cc/networks/ucidata-zachary/>`_
Karate Club dataset statistics: Karate Club dataset statistics:
Nodes: 34
Edges: 156 - Nodes: 34
Number of Classes: 2 - Edges: 156
- Number of Classes: 2
Attributes Attributes
---------- ----------
num_classes : int num_classes : int
Number of node classes Number of node classes
data : list data : list
A list of DGLGraph objects A list of :class:`dgl.DGLGraph` objects
Examples Examples
-------- --------
...@@ -78,9 +81,11 @@ class KarateClubDataset(DGLDataset): ...@@ -78,9 +81,11 @@ class KarateClubDataset(DGLDataset):
Returns Returns
------- -------
dgl.DGLGraph :class:`dgl.DGLGraph`
graph structure and labels. graph structure and labels.
- ndata['label']: ground truth labels
- ``ndata['label']``: ground truth labels
""" """
assert idx == 0, "This dataset has only one graph" assert idx == 0, "This dataset has only one graph"
return self._graph return self._graph
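Single-graph datasets such as the one above guard ``__getitem__`` with an index check; a minimal sketch of the pattern (a hypothetical stand-in class, not the real ``dgl.data`` implementation):

```python
class SingleGraphDataset:
    """Sketch of the single-graph dataset pattern: index 0 is the only valid key."""

    def __init__(self, graph):
        self._graph = graph

    def __getitem__(self, idx):
        # Mirrors the assertion used by the builtin single-graph datasets.
        assert idx == 0, "This dataset has only one graph"
        return self._graph

    def __len__(self):
        return 1
```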
......
...@@ -336,21 +336,27 @@ class FB15k237Dataset(KnowledgeGraphDataset): ...@@ -336,21 +336,27 @@ class FB15k237Dataset(KnowledgeGraphDataset):
r"""FB15k237 link prediction dataset. r"""FB15k237 link prediction dataset.
.. deprecated:: 0.5.0 .. deprecated:: 0.5.0
`train` is deprecated, it is replaced by:
- ``train`` is deprecated, it is replaced by:
>>> dataset = FB15k237Dataset() >>> dataset = FB15k237Dataset()
>>> graph = dataset[0] >>> graph = dataset[0]
>>> train_mask = graph.edata['train_mask'] >>> train_mask = graph.edata['train_mask']
>>> train_idx = th.nonzero(train_mask).squeeze() >>> train_idx = th.nonzero(train_mask).squeeze()
>>> src, dst = graph.edges(train_idx) >>> src, dst = graph.edges(train_idx)
>>> rel = graph.edata['etype'][train_idx] >>> rel = graph.edata['etype'][train_idx]
`valid` is deprecated, it is replaced by:
- ``valid`` is deprecated, it is replaced by:
>>> dataset = FB15k237Dataset() >>> dataset = FB15k237Dataset()
>>> graph = dataset[0] >>> graph = dataset[0]
>>> val_mask = graph.edata['val_mask'] >>> val_mask = graph.edata['val_mask']
>>> val_idx = th.nonzero(val_mask).squeeze() >>> val_idx = th.nonzero(val_mask).squeeze()
>>> src, dst = graph.edges(val_idx) >>> src, dst = graph.edges(val_idx)
>>> rel = graph.edata['etype'][val_idx] >>> rel = graph.edata['etype'][val_idx]
`test` is deprecated, it is replaced by:
- ``test`` is deprecated, it is replaced by:
>>> dataset = FB15k237Dataset() >>> dataset = FB15k237Dataset()
>>> graph = dataset[0] >>> graph = dataset[0]
>>> test_mask = graph.edata['test_mask'] >>> test_mask = graph.edata['test_mask']
...@@ -364,10 +370,15 @@ class FB15k237Dataset(KnowledgeGraphDataset): ...@@ -364,10 +370,15 @@ class FB15k237Dataset(KnowledgeGraphDataset):
created for each edge by default. created for each edge by default.
FB15k237 dataset statistics: FB15k237 dataset statistics:
Nodes: 14541
Number of relation types: 237 - Nodes: 14541
Number of reversed relation types: 237 - Number of relation types: 237
Label Split: Train: 272115 ,Valid: 17535, Test: 20466 - Number of reversed relation types: 237
- Label Split:
- Train: 272115
- Valid: 17535
- Test: 20466
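The deprecation examples above recover per-split edge indices from a boolean mask via ``th.nonzero(mask).squeeze()``; the same mask-to-index step in plain Python (toy mask, no torch required):

```python
# Toy boolean mask over five edges; True marks a training edge.
train_mask = [True, False, True, True, False]

# Equivalent of th.nonzero(train_mask).squeeze(): indices where the mask is True.
train_idx = [i for i, m in enumerate(train_mask) if m]
```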
Parameters Parameters
---------- ----------
...@@ -387,11 +398,11 @@ class FB15k237Dataset(KnowledgeGraphDataset): ...@@ -387,11 +398,11 @@ class FB15k237Dataset(KnowledgeGraphDataset):
Number of nodes Number of nodes
num_rels: int num_rels: int
Number of relation types Number of relation types
train: numpy array train: numpy.ndarray
A numpy array of triplets (src, rel, dst) for the training graph A numpy array of triplets (src, rel, dst) for the training graph
valid: numpy array valid: numpy.ndarray
A numpy array of triplets (src, rel, dst) for the validation graph A numpy array of triplets (src, rel, dst) for the validation graph
test: numpy array test: numpy.ndarray
A numpy array of triplets (src, rel, dst) for the test graph A numpy array of triplets (src, rel, dst) for the test graph
Examples Examples
...@@ -421,7 +432,6 @@ class FB15k237Dataset(KnowledgeGraphDataset): ...@@ -421,7 +432,6 @@ class FB15k237Dataset(KnowledgeGraphDataset):
>>> val_g.edata['e_type'] = e_type[val_edges]; >>> val_g.edata['e_type'] = e_type[val_edges];
>>> >>>
>>> # Train, Validation and Test >>> # Train, Validation and Test
>>>
""" """
def __init__(self, reverse=True, raw_dir=None, force_reload=False, verbose=True): def __init__(self, reverse=True, raw_dir=None, force_reload=False, verbose=True):
name = 'FB15k-237' name = 'FB15k-237'
...@@ -437,16 +447,18 @@ class FB15k237Dataset(KnowledgeGraphDataset): ...@@ -437,16 +447,18 @@ class FB15k237Dataset(KnowledgeGraphDataset):
Return Returns
------- -------
dgl.DGLGraph :class:`dgl.DGLGraph`
The graph contain
- edata['e_type']: edge relation type The graph contains
- edata['train_edge_mask']: positive training edge mask
- edata['val_edge_mask']: positive validation edge mask - ``edata['e_type']``: edge relation type
- edata['test_edge_mask']: positive testing edge mask - ``edata['train_edge_mask']``: positive training edge mask
- edata['train_mask']: training edge set mask (include reversed training edges) - ``edata['train_mask']``: training edge set mask (includes reversed training edges)
- edata['val_mask']: validation edge set mask (include reversed validation edges) - ``edata['val_mask']``: validation edge set mask (includes reversed validation edges)
- edata['test_mask']: testing edge set mask (include reversed testing edges) - ``edata['test_mask']``: testing edge set mask (includes reversed testing edges)
- ndata['ntype']: node type. All 0 in this dataset - ``edata['val_mask']``: validation edge set mask (include reversed validation edges)
- ``edata['test_mask']``: testing edge set mask (include reversed testing edges)
- ``ndata['ntype']``: node type. All 0 in this dataset
""" """
return super(FB15k237Dataset, self).__getitem__(idx) return super(FB15k237Dataset, self).__getitem__(idx)
...@@ -458,21 +470,27 @@ class FB15kDataset(KnowledgeGraphDataset): ...@@ -458,21 +470,27 @@ class FB15kDataset(KnowledgeGraphDataset):
r"""FB15k link prediction dataset. r"""FB15k link prediction dataset.
.. deprecated:: 0.5.0 .. deprecated:: 0.5.0
`train` is deprecated, it is replaced by:
- ``train`` is deprecated, it is replaced by:
>>> dataset = FB15kDataset() >>> dataset = FB15kDataset()
>>> graph = dataset[0] >>> graph = dataset[0]
>>> train_mask = graph.edata['train_mask'] >>> train_mask = graph.edata['train_mask']
>>> train_idx = th.nonzero(train_mask).squeeze() >>> train_idx = th.nonzero(train_mask).squeeze()
>>> src, dst = graph.edges(train_idx) >>> src, dst = graph.edges(train_idx)
>>> rel = graph.edata['etype'][train_idx] >>> rel = graph.edata['etype'][train_idx]
`valid` is deprecated, it is replaced by:
- ``valid`` is deprecated, it is replaced by:
>>> dataset = FB15kDataset() >>> dataset = FB15kDataset()
>>> graph = dataset[0] >>> graph = dataset[0]
>>> val_mask = graph.edata['val_mask'] >>> val_mask = graph.edata['val_mask']
>>> val_idx = th.nonzero(val_mask).squeeze() >>> val_idx = th.nonzero(val_mask).squeeze()
>>> src, dst = graph.edges(val_idx) >>> src, dst = graph.edges(val_idx)
>>> rel = graph.edata['etype'][val_idx] >>> rel = graph.edata['etype'][val_idx]
`test` is deprecated, it is replaced by:
- ``test`` is deprecated, it is replaced by:
>>> dataset = FB15kDataset() >>> dataset = FB15kDataset()
>>> graph = dataset[0] >>> graph = dataset[0]
>>> test_mask = graph.edata['test_mask'] >>> test_mask = graph.edata['test_mask']
...@@ -480,7 +498,8 @@ class FB15kDataset(KnowledgeGraphDataset): ...@@ -480,7 +498,8 @@ class FB15kDataset(KnowledgeGraphDataset):
>>> src, dst = graph.edges(test_idx) >>> src, dst = graph.edges(test_idx)
>>> rel = graph.edata['etype'][test_idx] >>> rel = graph.edata['etype'][test_idx]
The FB15K dataset was introduced in http://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data.pdf, The FB15K dataset was introduced in `Translating Embeddings for Modeling
Multi-relational Data <http://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data.pdf>`_.
It is a subset of Freebase which contains about It is a subset of Freebase which contains about
14,951 entities with 1,345 different relations. 14,951 entities with 1,345 different relations.
When creating the dataset, a reverse edge with When creating the dataset, a reverse edge with
...@@ -488,10 +507,15 @@ class FB15kDataset(KnowledgeGraphDataset): ...@@ -488,10 +507,15 @@ class FB15kDataset(KnowledgeGraphDataset):
by default. by default.
FB15k dataset statistics: FB15k dataset statistics:
Nodes: 14,951
Number of relation types: 1,345 - Nodes: 14,951
Number of reversed relation types: 1,345 - Number of relation types: 1,345
Label Split: Train: 483142 ,Valid: 50000, Test: 59071 - Number of reversed relation types: 1,345
- Label Split:
- Train: 483142
- Valid: 50000
- Test: 59071
Parameters Parameters
---------- ----------
...@@ -511,11 +535,11 @@ class FB15kDataset(KnowledgeGraphDataset): ...@@ -511,11 +535,11 @@ class FB15kDataset(KnowledgeGraphDataset):
Number of nodes Number of nodes
num_rels: int num_rels: int
Number of relation types Number of relation types
train: numpy array train: numpy.ndarray
A numpy array of triplets (src, rel, dst) for the training graph A numpy array of triplets (src, rel, dst) for the training graph
valid: numpy array valid: numpy.ndarray
A numpy array of triplets (src, rel, dst) for the validation graph A numpy array of triplets (src, rel, dst) for the validation graph
test: numpy array test: numpy.ndarray
A numpy array of triplets (src, rel, dst) for the test graph A numpy array of triplets (src, rel, dst) for the test graph
Examples Examples
...@@ -560,16 +584,18 @@ class FB15kDataset(KnowledgeGraphDataset): ...@@ -560,16 +584,18 @@ class FB15kDataset(KnowledgeGraphDataset):
Return Returns
------- -------
dgl.DGLGraph :class:`dgl.DGLGraph`
The graph contain
- edata['e_type']: edge relation type The graph contains
- edata['train_edge_mask']: positive training edge mask
- edata['val_edge_mask']: positive validation edge mask - ``edata['e_type']``: edge relation type
- edata['test_edge_mask']: positive testing edge mask - ``edata['train_edge_mask']``: positive training edge mask
- edata['train_mask']: training edge set mask (include reversed training edges) - ``edata['train_mask']``: training edge set mask (includes reversed training edges)
- edata['val_mask']: validation edge set mask (include reversed validation edges) - ``edata['val_mask']``: validation edge set mask (includes reversed validation edges)
- edata['test_mask']: testing edge set mask (include reversed testing edges) - ``edata['test_mask']``: testing edge set mask (includes reversed testing edges)
- ndata['ntype']: node type. All 0 in this dataset - ``edata['val_mask']``: validation edge set mask (include reversed validation edges)
- ``edata['test_mask']``: testing edge set mask (include reversed testing edges)
- ``ndata['ntype']``: node type. All 0 in this dataset
""" """
return super(FB15kDataset, self).__getitem__(idx) return super(FB15kDataset, self).__getitem__(idx)
...@@ -581,21 +607,27 @@ class WN18Dataset(KnowledgeGraphDataset): ...@@ -581,21 +607,27 @@ class WN18Dataset(KnowledgeGraphDataset):
r""" WN18 link prediction dataset. r""" WN18 link prediction dataset.
.. deprecated:: 0.5.0 .. deprecated:: 0.5.0
`train` is deprecated, it is replaced by:
- ``train`` is deprecated, it is replaced by:
>>> dataset = WN18Dataset() >>> dataset = WN18Dataset()
>>> graph = dataset[0] >>> graph = dataset[0]
>>> train_mask = graph.edata['train_mask'] >>> train_mask = graph.edata['train_mask']
>>> train_idx = th.nonzero(train_mask).squeeze() >>> train_idx = th.nonzero(train_mask).squeeze()
>>> src, dst = graph.edges(train_idx) >>> src, dst = graph.edges(train_idx)
>>> rel = graph.edata['etype'][train_idx] >>> rel = graph.edata['etype'][train_idx]
`valid` is deprecated, it is replaced by:
- ``valid`` is deprecated, it is replaced by:
>>> dataset = WN18Dataset() >>> dataset = WN18Dataset()
>>> graph = dataset[0] >>> graph = dataset[0]
>>> val_mask = graph.edata['val_mask'] >>> val_mask = graph.edata['val_mask']
>>> val_idx = th.nonzero(val_mask).squeeze() >>> val_idx = th.nonzero(val_mask).squeeze()
>>> src, dst = graph.edges(val_idx) >>> src, dst = graph.edges(val_idx)
>>> rel = graph.edata['etype'][val_idx] >>> rel = graph.edata['etype'][val_idx]
`test` is deprecated, it is replaced by:
- ``test`` is deprecated, it is replaced by:
>>> dataset = WN18Dataset() >>> dataset = WN18Dataset()
>>> graph = dataset[0] >>> graph = dataset[0]
>>> test_mask = graph.edata['test_mask'] >>> test_mask = graph.edata['test_mask']
...@@ -603,17 +635,23 @@ class WN18Dataset(KnowledgeGraphDataset): ...@@ -603,17 +635,23 @@ class WN18Dataset(KnowledgeGraphDataset):
>>> src, dst = graph.edges(test_idx) >>> src, dst = graph.edges(test_idx)
>>> rel = graph.edata['etype'][test_idx] >>> rel = graph.edata['etype'][test_idx]
The WN18 dataset was introduced in http://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data.pdf, The WN18 dataset was introduced in `Translating Embeddings for Modeling
Multi-relational Data <http://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data.pdf>`_.
It included the full 18 relations scraped from It included the full 18 relations scraped from
WordNet for roughly 41,000 synsets. When creating WordNet for roughly 41,000 synsets. When creating
the dataset, a reverse edge with reversed relation the dataset, a reverse edge with reversed relation
types are created for each edge by default. types are created for each edge by default.
WN18 dataset tatistics: WN18 dataset statistics:
Nodes: 40943
Number of relation types: 18 - Nodes: 40943
Number of reversed relation types: 18 - Number of relation types: 18
Label Split: Train: 141442 ,Valid: 5000, Test: 5000 - Number of reversed relation types: 18
- Label Split:
- Train: 141442
- Valid: 5000
- Test: 5000
Parameters Parameters
---------- ----------
...@@ -633,11 +671,11 @@ class WN18Dataset(KnowledgeGraphDataset): ...@@ -633,11 +671,11 @@ class WN18Dataset(KnowledgeGraphDataset):
Number of nodes Number of nodes
num_rels: int num_rels: int
Number of relation types Number of relation types
train: numpy array train: numpy.ndarray
A numpy array of triplets (src, rel, dst) for the training graph A numpy array of triplets (src, rel, dst) for the training graph
valid: numpy array valid: numpy.ndarray
A numpy array of triplets (src, rel, dst) for the validation graph A numpy array of triplets (src, rel, dst) for the validation graph
test: numpy array test: numpy.ndarray
A numpy array of triplets (src, rel, dst) for the test graph A numpy array of triplets (src, rel, dst) for the test graph
Examples Examples
...@@ -682,16 +720,18 @@ class WN18Dataset(KnowledgeGraphDataset): ...@@ -682,16 +720,18 @@ class WN18Dataset(KnowledgeGraphDataset):
Return Returns
------- -------
dgl.DGLGraph :class:`dgl.DGLGraph`
The graph contain
- edata['e_type']: edge relation type The graph contains
- edata['train_edge_mask']: positive training edge mask
- edata['val_edge_mask']: positive validation edge mask - ``edata['e_type']``: edge relation type
- edata['test_edge_mask']: positive testing edge mask - ``edata['train_edge_mask']``: positive training edge mask
- edata['train_mask']: training edge set mask (include reversed training edges) - ``edata['train_mask']``: training edge set mask (includes reversed training edges)
- edata['val_mask']: validation edge set mask (include reversed validation edges) - ``edata['val_mask']``: validation edge set mask (includes reversed validation edges)
- edata['test_mask']: testing edge set mask (include reversed testing edges) - ``edata['test_mask']``: testing edge set mask (includes reversed testing edges)
- ndata['ntype']: node type. All 0 in this dataset - ``edata['val_mask']``: validation edge set mask (include reversed validation edges)
- ``edata['test_mask']``: testing edge set mask (include reversed testing edges)
- ``ndata['ntype']``: node type. All 0 in this dataset
""" """
return super(WN18Dataset, self).__getitem__(idx) return super(WN18Dataset, self).__getitem__(idx)
......
...@@ -12,17 +12,18 @@ from ..transform import add_self_loop ...@@ -12,17 +12,18 @@ from ..transform import add_self_loop
__all__ = ['MiniGCDataset'] __all__ = ['MiniGCDataset']
class MiniGCDataset(DGLDataset): class MiniGCDataset(DGLDataset):
"""The dataset class. """The synthetic graph classification dataset class.
The datset contains 8 different types of graphs. The dataset contains 8 different types of graphs.
* class 0 : cycle graph
* class 1 : star graph - class 0 : cycle graph
* class 2 : wheel graph - class 1 : star graph
* class 3 : lollipop graph - class 2 : wheel graph
* class 4 : hypercube graph - class 3 : lollipop graph
* class 5 : grid graph - class 4 : hypercube graph
* class 6 : clique graph - class 5 : grid graph
* class 7 : circular ladder graph - class 6 : clique graph
- class 7 : circular ladder graph
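The class list above maps label indices 0-7 to graph families; a small sketch of that mapping as a lookup table (the constant name is illustrative, not an API of ``dgl.data``):

```python
# Label index -> graph family, in the order listed in the docstring.
MINIGC_CLASSES = [
    "cycle", "star", "wheel", "lollipop",
    "hypercube", "grid", "clique", "circular ladder",
]
```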
Parameters Parameters
---------- ----------
...@@ -50,7 +51,7 @@ class MiniGCDataset(DGLDataset): ...@@ -50,7 +51,7 @@ class MiniGCDataset(DGLDataset):
-------- --------
>>> data = MiniGCDataset(100, 16, 32, seed=0) >>> data = MiniGCDataset(100, 16, 32, seed=0)
**The dataset instance is an iterable** The dataset instance is an iterable
>>> len(data) >>> len(data)
100 100
...@@ -62,7 +63,7 @@ class MiniGCDataset(DGLDataset): ...@@ -62,7 +63,7 @@ class MiniGCDataset(DGLDataset):
>>> label >>> label
tensor(5) tensor(5)
**Batch the graphs and labels for mini-batch training** Batch the graphs and labels for mini-batch training
>>> graphs, labels = zip(*[data[i] for i in range(16)]) >>> graphs, labels = zip(*[data[i] for i in range(16)])
>>> batched_graphs = dgl.batch(graphs) >>> batched_graphs = dgl.batch(graphs)
...@@ -96,13 +97,15 @@ class MiniGCDataset(DGLDataset): ...@@ -96,13 +97,15 @@ class MiniGCDataset(DGLDataset):
def __getitem__(self, idx): def __getitem__(self, idx):
"""Get the idx-th sample. """Get the idx-th sample.
Paramters
Parameters
--------- ---------
idx : int idx : int
The sample index. The sample index.
Returns Returns
------- -------
(dgl.Graph, int) (:class:`dgl.Graph`, Tensor)
The graph and its label. The graph and its label.
""" """
return self.graphs[idx], self.labels[idx] return self.graphs[idx], self.labels[idx]
...@@ -143,7 +146,7 @@ class MiniGCDataset(DGLDataset): ...@@ -143,7 +146,7 @@ class MiniGCDataset(DGLDataset):
self._gen_circular_ladder(self.num_graphs - len(self.graphs)) self._gen_circular_ladder(self.num_graphs - len(self.graphs))
# preprocess # preprocess
for i in range(self.num_graphs): for i in range(self.num_graphs):
# convert to Graph, and add self loops # convert to DGLGraph, and add self loops
self.graphs[i] = add_self_loop(dgl_graph(self.graphs[i])) self.graphs[i] = add_self_loop(dgl_graph(self.graphs[i]))
self.labels = F.tensor(np.array(self.labels).astype(np.int)) self.labels = F.tensor(np.array(self.labels).astype(np.int))
......
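The docstring's mini-batch example relies on the zip-unzip idiom (`zip(*[data[i] for i in range(16)])`). A stdlib-only sketch of that collate step, with placeholder dicts standing in for real DGL graphs (in actual use, `dgl.batch` would then merge the graph list):

```python
def collate(samples):
    """Split a list of (graph, label) pairs into parallel graph/label lists."""
    graphs, labels = zip(*samples)
    return list(graphs), list(labels)

# Placeholder samples: dicts stand in for dgl graphs, ints for labels.
data = [({"num_nodes": n}, n % 8) for n in range(10, 16)]
graphs, labels = collate(data)
print(len(graphs), labels)  # -> 6 [2, 3, 4, 5, 6, 7]
```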
@@ -15,13 +15,17 @@ class PPIDataset(DGLBuiltinDataset):
    r""" Protein-Protein Interaction dataset for inductive node classification
    .. deprecated:: 0.5.0
-        `lables` is deprecated, it is replaced by:
+        - ``lables`` is deprecated, it is replaced by:
        >>> dataset = PPIDataset()
        >>> for g in dataset:
        ....     labels = g.ndata['label']
        ....
        >>>
-        `features` is deprecated, it is replaced by:
+        - ``features`` is deprecated, it is replaced by:
        >>> dataset = PPIDataset()
        >>> for g in dataset:
        ....     features = g.ndata['feat']

@@ -33,12 +37,13 @@ class PPIDataset(DGLBuiltinDataset):
    50 features and 121 labels. 20 graphs for training, 2 for validation
    and 2 for testing.
-    Reference: http://snap.stanford.edu/graphsage/
+    Reference: `<http://snap.stanford.edu/graphsage/>`_
-    PPI dataset statistics:
-    Train examples: 20
-    Valid examples: 2
-    Test examples: 2
+    Statistics:
+
+    - Train examples: 20
+    - Valid examples: 2
+    - Test examples: 2

    Parameters
    ----------

@@ -167,10 +172,11 @@ class PPIDataset(DGLBuiltinDataset):
        Returns
        -------
-        dgl.DGLGraph
+        :class:`dgl.DGLGraph`
            graph structure, node features and node labels.
-            - ndata['feat']: node features
-            - ndata['label']: nodel labels
+            - ``ndata['feat']``: node features
+            - ``ndata['label']``: node labels
        """
        return self.graphs[item]
...
@@ -16,15 +16,16 @@ class QM7bDataset(DGLDataset):
    This dataset consists of 7,211 molecules with 14 regression targets.
    Nodes means atoms and edges means bonds. Edge data 'h' means
    the entry of Coulomb matrix.
-    Reference: http://quantum-machine.org/datasets/
-    Statistics
-    ----------
-    Number of graphs: 7,211
-    Number of regression targets: 14
-    Average number of nodes: 15
-    Average number of edges: 245
-    Edge feature size: 1
+    Reference: `<http://quantum-machine.org/datasets/>`_
+
+    Statistics:
+
+    - Number of graphs: 7,211
+    - Number of regression targets: 14
+    - Average number of nodes: 15
+    - Average number of edges: 245
+    - Edge feature size: 1

    Parameters
    ----------

@@ -73,10 +74,6 @@ class QM7bDataset(DGLDataset):
    def process(self):
        mat_path = self.raw_path + '.mat'
-        if not check_sha1(mat_path, self._sha1_str):
-            raise UserWarning('File {} is downloaded but the content hash does not match.'
-                              'The repo may be outdated or download may be incomplete. '
-                              'Otherwise you can create an issue for it.'.format(self.name))
        self.graphs, self.label = self._load_graph(mat_path)

    def _load_graph(self, filename):

@@ -110,6 +107,10 @@ class QM7bDataset(DGLDataset):
    def download(self):
        file_path = os.path.join(self.raw_dir, self.name + '.mat')
        download(self.url, path=file_path)
+        if not check_sha1(file_path, self._sha1_str):
+            raise UserWarning('File {} is downloaded but the content hash does not match.'
+                              'The repo may be outdated or download may be incomplete. '
+                              'Otherwise you can create an issue for it.'.format(self.name))

    @property
    def num_labels(self):

@@ -126,12 +127,17 @@ class QM7bDataset(DGLDataset):
        Returns
        -------
-        (dgl.DGLGraph, Tensor)
+        (:class:`dgl.DGLGraph`, Tensor)
        """
        return self.graphs[idx], self.label[idx]

    def __len__(self):
-        r"""Number of graphs in the dataset"""
+        r"""Number of graphs in the dataset.
+
+        Return
+        -------
+        int
+        """
        return len(self.graphs)
...
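This hunk moves the `check_sha1` call from `process()` into `download()`, so a corrupt archive fails immediately after download rather than later at processing time. A minimal sketch of that verify-after-download pattern using only the standard library (the file and payload here are stand-ins, not a real DGL download):

```python
import hashlib
import os
import tempfile

def check_sha1(filename, sha1_hash):
    """Return True if the file's SHA-1 hex digest equals sha1_hash."""
    sha1 = hashlib.sha1()
    with open(filename, "rb") as f:
        # Read in 1 MiB chunks so large archives don't need to fit in memory.
        for chunk in iter(lambda: f.read(1 << 20), b""):
            sha1.update(chunk)
    return sha1.hexdigest() == sha1_hash

# Simulate a finished download, then verify it right away -- the same
# order the patched download() uses (verify before process() ever runs).
fd, path = tempfile.mkstemp()
with os.fdopen(fd, "wb") as f:
    f.write(b"fake dataset payload")
expected = hashlib.sha1(b"fake dataset payload").hexdigest()
if not check_sha1(path, expected):
    raise UserWarning("File %s is downloaded but the content hash does not match." % path)
os.remove(path)
```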
@@ -8,7 +8,10 @@ from collections import OrderedDict
import itertools
import abc
import re
-import rdflib as rdf
+try:
+    import rdflib as rdf
+except ImportError:
+    pass
import networkx as nx
import numpy as np

@@ -112,6 +115,7 @@ class RDFGraphDataset(DGLBuiltinDataset):
                 raw_dir=None,
                 force_reload=False,
                 verbose=True):
+        import rdflib as rdf
        self._insert_reverse = insert_reverse
        self._print_every = print_every
        self._predict_category = predict_category
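The module-level `try/except ImportError` plus the re-import inside `__init__` implements an optional-dependency pattern: importing `dgl.data` no longer fails when `rdflib` is absent, while constructing an RDF dataset still raises a clear `ImportError`. A stdlib-only sketch of the same idea, with a hypothetical module name standing in for `rdflib`:

```python
# Optional-dependency pattern: tolerate a missing package at import time,
# fail loudly only when the feature that needs it is actually used.
# "some_optional_package" is a hypothetical module name for illustration.
try:
    import some_optional_package
except ImportError:
    some_optional_package = None

class NeedsOptionalPackage:
    """Stand-in for RDFGraphDataset: requires the optional package."""
    def __init__(self):
        # Re-importing here raises a clear ImportError at construction
        # time when the dependency is absent.
        import some_optional_package
        self._backend = some_optional_package
```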
@@ -542,15 +546,21 @@ class AIFBDataset(RDFGraphDataset):
    r"""AIFB dataset for node classification task
    .. deprecated:: 0.5.0
-        `graph` is deprecated, it is replaced by:
+        - ``graph`` is deprecated, it is replaced by:
        >>> dataset = AIFBDataset()
        >>> graph = dataset[0]
-        `train_idx` is deprecated, it can be replaced by:
+        - ``train_idx`` is deprecated, it can be replaced by:
        >>> dataset = AIFBDataset()
        >>> graph = dataset[0]
        >>> train_mask = graph.nodes[dataset.category].data['train_mask']
        >>> train_idx = th.nonzero(train_mask).squeeze()
-        `test_idx` is deprecated, it can be replaced by:
+        - ``test_idx`` is deprecated, it can be replaced by:
        >>> dataset = AIFBDataset()
        >>> graph = dataset[0]
        >>> test_mask = graph.nodes[dataset.category].data['test_mask']

@@ -561,11 +571,15 @@ class AIFBDataset(RDFGraphDataset):
    University of Karlsruhe.
    AIFB dataset statistics:
-    Nodes: 7262
-    Edges: 48810 (including reverse edges)
-    Target Category: Personen
-    Number of Classes: 4
-    Label Split: Train: 140, Test: 36
+
+    - Nodes: 7262
+    - Edges: 48810 (including reverse edges)
+    - Target Category: Personen
+    - Number of Classes: 4
+    - Label Split:
+
+      - Train: 140
+      - Test: 36

    Parameters
    -----------

@@ -589,7 +603,7 @@ class AIFBDataset(RDFGraphDataset):
        The entity category (node type) that has labels for prediction
    labels : Tensor
        All the labels of the entities in ``predict_category``
-    graph : dgl.DGLGraph
+    graph : :class:`dgl.DGLGraph`
        Graph structure
    train_idx : Tensor
        Entity IDs for training. All IDs are local IDs w.r.t. to ``predict_category``.

@@ -608,8 +622,6 @@ class AIFBDataset(RDFGraphDataset):
    >>> labels = g.nodes[category].data.pop('labels')
    """
-    employs = rdf.term.URIRef("http://swrc.ontoware.org/ontology#employs")
-    affiliation = rdf.term.URIRef("http://swrc.ontoware.org/ontology#affiliation")
    entity_prefix = 'http://www.aifb.uni-karlsruhe.de/'
    relation_prefix = 'http://swrc.ontoware.org/'

@@ -619,6 +631,9 @@ class AIFBDataset(RDFGraphDataset):
                 raw_dir=None,
                 force_reload=False,
                 verbose=True):
+        import rdflib as rdf
+        self.employs = rdf.term.URIRef("http://swrc.ontoware.org/ontology#employs")
+        self.affiliation = rdf.term.URIRef("http://swrc.ontoware.org/ontology#affiliation")
        url = _get_dgl_url('dataset/rdf/aifb-hetero.zip')
        name = 'aifb-hetero'
        predict_category = 'Personen'

@@ -639,16 +654,23 @@ class AIFBDataset(RDFGraphDataset):
        Return
        -------
-        dgl.DGLGraph
-            graph structure, node features and labels.
-            - ndata['train_mask']: mask for training node set
-            - ndata['test_mask']: mask for testing node set
-            - ndata['labels']: mask for labels
+        :class:`dgl.DGLGraph`
+
+            The graph contains:
+
+            - ``ndata['train_mask']``: mask for training node set
+            - ``ndata['test_mask']``: mask for testing node set
+            - ``ndata['labels']``: mask for labels
        """
        return super(AIFBDataset, self).__getitem__(idx)

    def __len__(self):
-        r"""The number of graphs in the dataset."""
+        r"""The number of graphs in the dataset.
+
+        Return
+        -------
+        int
+        """
        return super(AIFBDataset, self).__len__()

    def parse_entity(self, term):
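The deprecation notes repeatedly convert a boolean mask back to index form with `th.nonzero(train_mask).squeeze()`. The same conversion, sketched without torch (a pure-Python stand-in for readers without a tensor library at hand):

```python
def mask_to_idx(mask):
    """Indices where a boolean mask is True -- torch-free equivalent of
    th.nonzero(mask).squeeze() for a 1-D mask."""
    return [i for i, m in enumerate(mask) if m]

train_mask = [True, True, False, True, False]
print(mask_to_idx(train_mask))  # -> [0, 1, 3]
```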
@@ -703,26 +725,36 @@ class MUTAGDataset(RDFGraphDataset):
    r"""MUTAG dataset for node classification task
    .. deprecated:: 0.5.0
-        `graph` is deprecated, it is replaced by:
+        - ``graph`` is deprecated, it is replaced by:
        >>> dataset = MUTAGDataset()
        >>> graph = dataset[0]
-        `train_idx` is deprecated, it can be replaced by:
+        - ``train_idx`` is deprecated, it can be replaced by:
        >>> dataset = MUTAGDataset()
        >>> graph = dataset[0]
        >>> train_mask = graph.nodes[dataset.category].data['train_mask']
        >>> train_idx = th.nonzero(train_mask).squeeze()
-        `test_idx` is deprecated, it can be replaced by:
+        - ``test_idx`` is deprecated, it can be replaced by:
        >>> dataset = MUTAGDataset()
        >>> graph = dataset[0]
        >>> test_mask = graph.nodes[dataset.category].data['test_mask']
        >>> test_idx = th.nonzero(test_mask).squeeze()
    Mutag dataset statistics:
-    Nodes: 27163
-    Edges: 148100 (including reverse edges)
-    Target Category: d
-    Number of Classes: 2
-    Label Split: Train: 272, Test: 68
+
+    - Nodes: 27163
+    - Edges: 148100 (including reverse edges)
+    - Target Category: d
+    - Number of Classes: 2
+    - Label Split:
+
+      - Train: 272
+      - Test: 68

    Parameters
    -----------

@@ -746,7 +778,7 @@ class MUTAGDataset(RDFGraphDataset):
        The entity category (node type) that has labels for prediction
    labels : Tensor
        All the labels of the entities in ``predict_category``
-    graph : dgl.DGLGraph
+    graph : :class:`dgl.DGLGraph`
        Graph structure
    train_idx : Tensor
        Entity IDs for training. All IDs are local IDs w.r.t. to ``predict_category``.

@@ -768,11 +800,6 @@ class MUTAGDataset(RDFGraphDataset):
    d_entity = re.compile("d[0-9]")
    bond_entity = re.compile("bond[0-9]")
-    is_mutagenic = rdf.term.URIRef("http://dl-learner.org/carcinogenesis#isMutagenic")
-    rdf_type = rdf.term.URIRef("http://www.w3.org/1999/02/22-rdf-syntax-ns#type")
-    rdf_subclassof = rdf.term.URIRef("http://www.w3.org/2000/01/rdf-schema#subClassOf")
-    rdf_domain = rdf.term.URIRef("http://www.w3.org/2000/01/rdf-schema#domain")
    entity_prefix = 'http://dl-learner.org/carcinogenesis#'
    relation_prefix = entity_prefix

@@ -782,6 +809,12 @@ class MUTAGDataset(RDFGraphDataset):
                 raw_dir=None,
                 force_reload=False,
                 verbose=True):
+        import rdflib as rdf
+        self.is_mutagenic = rdf.term.URIRef("http://dl-learner.org/carcinogenesis#isMutagenic")
+        self.rdf_type = rdf.term.URIRef("http://www.w3.org/1999/02/22-rdf-syntax-ns#type")
+        self.rdf_subclassof = rdf.term.URIRef("http://www.w3.org/2000/01/rdf-schema#subClassOf")
+        self.rdf_domain = rdf.term.URIRef("http://www.w3.org/2000/01/rdf-schema#domain")
        url = _get_dgl_url('dataset/rdf/mutag-hetero.zip')
        name = 'mutag-hetero'
        predict_category = 'd'

@@ -802,16 +835,23 @@ class MUTAGDataset(RDFGraphDataset):
        Return
        -------
-        dgl.DGLGraph
-            graph structure, node features and labels.
-            - ndata['train_mask']: mask for training node set
-            - ndata['test_mask']: mask for testing node set
-            - ndata['labels']: mask for labels
+        :class:`dgl.DGLGraph`
+
+            The graph contains:
+
+            - ``ndata['train_mask']``: mask for training node set
+            - ``ndata['test_mask']``: mask for testing node set
+            - ``ndata['labels']``: mask for labels
        """
        return super(MUTAGDataset, self).__getitem__(idx)

    def __len__(self):
-        r"""The number of graphs in the dataset."""
+        r"""The number of graphs in the dataset.
+
+        Return
+        -------
+        int
+        """
        return super(MUTAGDataset, self).__len__()

    def parse_entity(self, term):
@@ -882,32 +922,42 @@ class BGSDataset(RDFGraphDataset):
    r"""BGS dataset for node classification task
    .. deprecated:: 0.5.0
-        `graph` is deprecated, it is replaced by:
+        - ``graph`` is deprecated, it is replaced by:
        >>> dataset = BGSDataset()
        >>> graph = dataset[0]
-        `train_idx` is deprecated, it can be replaced by:
+        - ``train_idx`` is deprecated, it can be replaced by:
        >>> dataset = BGSDataset()
        >>> graph = dataset[0]
        >>> train_mask = graph.nodes[dataset.category].data['train_mask']
        >>> train_idx = th.nonzero(train_mask).squeeze()
-        `test_idx` is deprecated, it can be replaced by:
+        - ``test_idx`` is deprecated, it can be replaced by:
        >>> dataset = BGSDataset()
        >>> graph = dataset[0]
        >>> test_mask = graph.nodes[dataset.category].data['test_mask']
        >>> test_idx = th.nonzero(test_mask).squeeze()
    BGS namespace convention:
-    http://data.bgs.ac.uk/(ref|id)/<Major Concept>/<Sub Concept>/INSTANCE
+    ``http://data.bgs.ac.uk/(ref|id)/<Major Concept>/<Sub Concept>/INSTANCE``.
    We ignored all literal nodes and the relations connecting them in the
    output graph. We also ignored the relation used to mark whether a
    term is CURRENT or DEPRECATED.
    BGS dataset statistics:
-    Nodes: 94806
-    Edges: 672884 (including reverse edges)
-    Target Category: Lexicon/NamedRockUnit
-    Number of Classes: 2
-    Label Split: Train: 117, Test: 29
+
+    - Nodes: 94806
+    - Edges: 672884 (including reverse edges)
+    - Target Category: Lexicon/NamedRockUnit
+    - Number of Classes: 2
+    - Label Split:
+
+      - Train: 117
+      - Test: 29

    Parameters
    -----------

@@ -931,7 +981,7 @@ class BGSDataset(RDFGraphDataset):
        The entity category (node type) that has labels for prediction
    labels : Tensor
        All the labels of the entities in ``predict_category``
-    graph : dgl.DGLGraph
+    graph : :class:`dgl.DGLGraph`
        Graph structure
    train_idx : Tensor
        Entity IDs for training. All IDs are local IDs w.r.t. to ``predict_category``.

@@ -950,7 +1000,6 @@ class BGSDataset(RDFGraphDataset):
    >>> labels = g.nodes[category].data.pop('labels')
    """
-    lith = rdf.term.URIRef("http://data.bgs.ac.uk/ref/Lexicon/hasLithogenesis")
    entity_prefix = 'http://data.bgs.ac.uk/'
    status_prefix = 'http://data.bgs.ac.uk/ref/CurrentStatus'
    relation_prefix = 'http://data.bgs.ac.uk/ref'

@@ -961,9 +1010,11 @@ class BGSDataset(RDFGraphDataset):
                 raw_dir=None,
                 force_reload=False,
                 verbose=True):
+        import rdflib as rdf
        url = _get_dgl_url('dataset/rdf/bgs-hetero.zip')
        name = 'bgs-hetero'
        predict_category = 'Lexicon/NamedRockUnit'
+        self.lith = rdf.term.URIRef("http://data.bgs.ac.uk/ref/Lexicon/hasLithogenesis")
        super(BGSDataset, self).__init__(name, url, predict_category,
                                         print_every=print_every,
                                         insert_reverse=insert_reverse,

@@ -981,16 +1032,23 @@ class BGSDataset(RDFGraphDataset):
        Return
        -------
-        dgl.DGLGraph
-            graph structure, node features and labels.
-            - ndata['train_mask']: mask for training node set
-            - ndata['test_mask']: mask for testing node set
-            - ndata['labels']: mask for labels
+        :class:`dgl.DGLGraph`
+
+            The graph contains:
+
+            - ``ndata['train_mask']``: mask for training node set
+            - ``ndata['test_mask']``: mask for testing node set
+            - ``ndata['labels']``: mask for labels
        """
        return super(BGSDataset, self).__getitem__(idx)

    def __len__(self):
-        r"""The number of graphs in the dataset."""
+        r"""The number of graphs in the dataset.
+
+        Return
+        -------
+        int
+        """
        return super(BGSDataset, self).__len__()

    def parse_entity(self, term):
@@ -1057,32 +1115,44 @@ class AMDataset(RDFGraphDataset):
    """AM dataset. for node classification task
    .. deprecated:: 0.5.0
-        `graph` is deprecated, it is replaced by:
+        - ``graph`` is deprecated, it is replaced by:
        >>> dataset = AMDataset()
        >>> graph = dataset[0]
-        `train_idx` is deprecated, it can be replaced by:
+        - ``train_idx`` is deprecated, it can be replaced by:
        >>> dataset = AMDataset()
        >>> graph = dataset[0]
        >>> train_mask = graph.nodes[dataset.category].data['train_mask']
        >>> train_idx = th.nonzero(train_mask).squeeze()
-        `test_idx` is deprecated, it can be replaced by:
+        - ``test_idx`` is deprecated, it can be replaced by:
        >>> dataset = AMDataset()
        >>> graph = dataset[0]
        >>> test_mask = graph.nodes[dataset.category].data['test_mask']
        >>> test_idx = th.nonzero(test_mask).squeeze()
    Namespace convention:
-    Instance: http://purl.org/collections/nl/am/<type>-<id>
-    Relation: http://purl.org/collections/nl/am/<name>
+
+    - Instance: ``http://purl.org/collections/nl/am/<type>-<id>``
+    - Relation: ``http://purl.org/collections/nl/am/<name>``
    We ignored all literal nodes and the relations connecting them in the
    output graph.
    AM dataset statistics:
-    Nodes: 881680
-    Edges: 5668682 (including reverse edges)
-    Target Category: proxy
-    Number of Classes: 11
-    Label Split: Train: 802, Test: 198
+
+    - Nodes: 881680
+    - Edges: 5668682 (including reverse edges)
+    - Target Category: proxy
+    - Number of Classes: 11
+    - Label Split:
+
+      - Train: 802
+      - Test: 198

    Parameters
    -----------

@@ -1106,7 +1176,7 @@ class AMDataset(RDFGraphDataset):
        The entity category (node type) that has labels for prediction
    labels : Tensor
        All the labels of the entities in ``predict_category``
-    graph : dgl.DGLGraph
+    graph : :class:`dgl.DGLGraph`
        Graph structure
    train_idx : Tensor
        Entity IDs for training. All IDs are local IDs w.r.t. to ``predict_category``.

@@ -1125,8 +1195,6 @@ class AMDataset(RDFGraphDataset):
    >>> labels = g.nodes[category].data.pop('labels')
    """
-    objectCategory = rdf.term.URIRef("http://purl.org/collections/nl/am/objectCategory")
-    material = rdf.term.URIRef("http://purl.org/collections/nl/am/material")
    entity_prefix = 'http://purl.org/collections/nl/am/'
    relation_prefix = entity_prefix

@@ -1136,6 +1204,9 @@ class AMDataset(RDFGraphDataset):
                 raw_dir=None,
                 force_reload=False,
                 verbose=True):
+        import rdflib as rdf
+        self.objectCategory = rdf.term.URIRef("http://purl.org/collections/nl/am/objectCategory")
+        self.material = rdf.term.URIRef("http://purl.org/collections/nl/am/material")
        url = _get_dgl_url('dataset/rdf/am-hetero.zip')
        name = 'am-hetero'
        predict_category = 'proxy'

@@ -1156,16 +1227,23 @@ class AMDataset(RDFGraphDataset):
        Return
        -------
-        dgl.DGLGraph
-            graph structure, node features and labels.
-            - ndata['train_mask']: mask for training node set
-            - ndata['test_mask']: mask for testing node set
-            - ndata['labels']: mask for labels
+        :class:`dgl.DGLGraph`
+
+            The graph contains:
+
+            - ``ndata['train_mask']``: mask for training node set
+            - ``ndata['test_mask']``: mask for testing node set
+            - ``ndata['labels']``: mask for labels
        """
        return super(AMDataset, self).__getitem__(idx)

    def __len__(self):
-        r"""The number of graphs in the dataset."""
+        r"""The number of graphs in the dataset.
+
+        Return
+        -------
+        int
+        """
        return super(AMDataset, self).__len__()

    def parse_entity(self, term):
...
@@ -15,29 +15,43 @@ class RedditDataset(DGLBuiltinDataset):
    r""" Reddit dataset for community detection (node classification)
    .. deprecated:: 0.5.0
-        `graph` is deprecated, it is replaced by:
+        - ``graph`` is deprecated, it is replaced by:
        >>> dataset = RedditDataset()
        >>> graph = dataset[0]
-        `num_labels` is deprecated, it is replaced by:
+        - ``num_labels`` is deprecated, it is replaced by:
        >>> dataset = RedditDataset()
        >>> num_classes = dataset.num_classes
-        `train_mask` is deprecated, it is replaced by:
+        - ``train_mask`` is deprecated, it is replaced by:
        >>> dataset = RedditDataset()
        >>> graph = dataset[0]
        >>> train_mask = graph.ndata['train_mask']
-        `val_mask` is deprecated, it is replaced by:
+        - ``val_mask`` is deprecated, it is replaced by:
        >>> dataset = RedditDataset()
        >>> graph = dataset[0]
        >>> val_mask = graph.ndata['val_mask']
-        `test_mask` is deprecated, it is replaced by:
+        - ``test_mask`` is deprecated, it is replaced by:
        >>> dataset = RedditDataset()
        >>> graph = dataset[0]
        >>> test_mask = graph.ndata['test_mask']
-        `features` is deprecated, it is replaced by:
+        - ``features`` is deprecated, it is replaced by:
        >>> dataset = RedditDataset()
        >>> graph = dataset[0]
        >>> features = graph.ndata['feat']
-        `labels` is deprecated, it is replaced by:
+        - ``labels`` is deprecated, it is replaced by:
        >>> dataset = RedditDataset()
        >>> graph = dataset[0]
        >>> labels = graph.ndata['label']

@@ -49,16 +63,16 @@ class RedditDataset(DGLBuiltinDataset):
    posts with an average degree of 492. We use the first 20 days for training and the
    remaining days for testing (with 30% used for validation).
-    Reference: http://snap.stanford.edu/graphsage/
+    Reference: `<http://snap.stanford.edu/graphsage/>`_
    Statistics
-    ----------
-    Nodes: 232,965
-    Edges: 114,615,892
-    Node feature size: 602
-    Number of training samples: 153,431
-    Number of validation samples: 23,831
-    Number of test samples: 55,703
+
+    - Nodes: 232,965
+    - Edges: 114,615,892
+    - Node feature size: 602
+    - Number of training samples: 153,431
+    - Number of validation samples: 23,831
+    - Number of test samples: 55,703

    Parameters
    ----------

@@ -76,7 +90,7 @@ class RedditDataset(DGLBuiltinDataset):
    ----------
    num_classes : int
        Number of classes for each node
-    graph : dgl.DGLGraph
+    graph : :class:`dgl.DGLGraph`
        Graph of the dataset
    num_labels : int
        Number of classes for each node

@@ -220,13 +234,14 @@ class RedditDataset(DGLBuiltinDataset):
        Returns
        -------
-        dgl.DGLGraph
-            graph structure, node labels, node features and splitting masks
-            - ndata['label']: node label
-            - ndata['feat']: node feature
-            - ndata['train_mask']: mask for training node set
-            - ndata['val_mask']: mask for validation node set
-            - ndata['test_mask']: mask for test node set
+        :class:`dgl.DGLGraph`
+            graph structure, node labels, node features and splitting masks:
+
+            - ``ndata['label']``: node label
+            - ``ndata['feat']``: node feature
+            - ``ndata['train_mask']``: mask for training node set
+            - ``ndata['val_mask']``: mask for validation node set
+            - ``ndata['test_mask']``: mask for test node set
        """
        assert idx == 0, "Reddit Dataset only has one graph"
        return self._graph
...
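Reddit ships with fixed `train_mask`/`val_mask`/`test_mask` node splits. A sketch of how such disjoint boolean masks can be built from split sizes (illustrative only: the real dataset derives its split from posting dates, first 20 days for training, not from an index cut):

```python
def make_split_masks(num_nodes, num_train, num_val):
    """Build disjoint boolean train/val/test masks over num_nodes nodes."""
    train_mask = [i < num_train for i in range(num_nodes)]
    val_mask = [num_train <= i < num_train + num_val for i in range(num_nodes)]
    test_mask = [i >= num_train + num_val for i in range(num_nodes)]
    return train_mask, val_mask, test_mask

train_mask, val_mask, test_mask = make_split_masks(10, 6, 2)
print(sum(train_mask), sum(val_mask), sum(test_mask))  # -> 6 2 2
```

For the real dataset the three mask sums would be 153,431 / 23,831 / 55,703 over 232,965 nodes.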
...@@ -59,8 +59,7 @@ def sbm(n_blocks, block_size, p, q, rng=None): ...@@ -59,8 +59,7 @@ def sbm(n_blocks, block_size, p, q, rng=None):
class SBMMixtureDataset(DGLDataset): class SBMMixtureDataset(DGLDataset):
r""" Symmetric Stochastic Block Model Mixture r""" Symmetric Stochastic Block Model Mixture
Reference: Appendix C of "Supervised Community Detection with Hierarchical Reference: Appendix C of `Supervised Community Detection with Hierarchical Graph Neural Networks <https://arxiv.org/abs/1705.08415>`_
Graph Neural Networks" (https://arxiv.org/abs/1705.08415).
Parameters Parameters
---------- ----------
...@@ -175,15 +174,15 @@ class SBMMixtureDataset(DGLDataset): ...@@ -175,15 +174,15 @@ class SBMMixtureDataset(DGLDataset):
Returns Returns
------- -------
graph : dgl.DGLGraph graph: :class:`dgl.DGLGraph`
The original graph The original graph
line_graph : dgl.DGLGraph line_graph: :class:`dgl.DGLGraph`
The line graph of `graph` The line graph of `graph`
graph_degree : numpy.ndarray graph_degree: numpy.ndarray
In degrees for each node in `graph` In degrees for each node in `graph`
line_graph_degree : numpy.ndarray line_graph_degree: numpy.ndarray
In degrees for each node in `line_graph` In degrees for each node in `line_graph`
pm_pd : numpy.ndarray pm_pd: numpy.ndarray
Edge indicator matrices Pm and Pd Edge indicator matrices Pm and Pd
""" """
return self._graphs[idx], self._line_graphs[idx], \ return self._graphs[idx], self._line_graphs[idx], \
...@@ -203,29 +202,30 @@ class SBMMixtureDataset(DGLDataset):
        Parameters
        ----------
        x : tuple
            a batch of data that contains:

            - graph : :class:`dgl.DGLGraph`
                The original graph
            - line_graph : :class:`dgl.DGLGraph`
                The line graph of `graph`
            - graph_degree : numpy.ndarray
                In degrees for each node in `graph`
            - line_graph_degree : numpy.ndarray
                In degrees for each node in `line_graph`
            - pm_pd : numpy.ndarray
                Edge indicator matrices Pm and Pd

        Returns
        -------
        g_batch : :class:`dgl.DGLGraph`
            Batched graphs
        lg_batch : :class:`dgl.DGLGraph`
            Batched line graphs
        degg_batch : numpy.ndarray
            A batch of in degrees for each node in `g_batch`
        deglg_batch : numpy.ndarray
            A batch of in degrees for each node in `lg_batch`
        pm_pd_batch : numpy.ndarray
            A batch of edge indicator matrices Pm and Pd
        """
        g, lg, deg_g, deg_lg, pm_pd = zip(*x)
...
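The collate behavior documented above boils down to transposing a list of 5-tuples with ``zip(*x)``. The following is a minimal, dependency-free sketch of that pattern; the sample contents are placeholders, and a real implementation would additionally call ``dgl.batch`` on the graph fields:

```python
# Sketch of the collate pattern from SBMMixtureDataset.collate_fn:
# a batch is a list of 5-tuples, and zip(*batch) transposes it into
# five per-field sequences. Sample contents here are placeholders.
def collate_fn(batch):
    g, lg, deg_g, deg_lg, pm_pd = zip(*batch)
    # A real DGL collate would return dgl.batch(g), dgl.batch(lg), ...
    return list(g), list(lg), list(deg_g), list(deg_lg), list(pm_pd)

samples = [("g%d" % i, "lg%d" % i, i, i + 1, [i]) for i in range(3)]
g_batch, lg_batch, degg_batch, deglg_batch, pm_pd_batch = collate_fn(samples)
```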
...@@ -23,13 +23,14 @@ class SSTDataset(DGLBuiltinDataset):
    r"""Stanford Sentiment Treebank dataset.

    .. deprecated:: 0.5.0

        - ``trees`` is deprecated, it is replaced by:

            >>> dataset = SSTDataset()
            >>> for tree in dataset:
            ....     # your code here
            ....

        - ``num_vocabs`` is deprecated, it is replaced by ``vocab_size``.

    Each sample is the constituency tree of a sentence. The leaf nodes
    represent words. The word is an int value stored in the ``x`` feature field.
...@@ -37,14 +38,14 @@ class SSTDataset(DGLBuiltinDataset):
    Each node also has a sentiment annotation: 5 classes (very negative,
    negative, neutral, positive and very positive). The sentiment label is an
    int value stored in the ``y`` feature field.

    Official site: `<http://nlp.stanford.edu/sentiment/index.html>`_

    Statistics:

    - Train examples: 8,544
    - Dev examples: 1,101
    - Test examples: 2,210
    - Number of classes for each node: 5

    Parameters
    ----------
...@@ -100,16 +101,14 @@ class SSTDataset(DGLBuiltinDataset):
    >>> train_data.vocab_size
    19536
    >>> train_data[0]
    Graph(num_nodes=71, num_edges=70,
          ndata_schemes={'x': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(), dtype=torch.int64), 'mask': Scheme(shape=(), dtype=torch.int64)}
          edata_schemes={})
    >>> for tree in train_data:
    ...     input_ids = tree.ndata['x']
    ...     labels = tree.ndata['y']
    ...     mask = tree.ndata['mask']
    ...     # your code here
    """

    PAD_WORD = -1  # special pad word id
...@@ -247,11 +246,13 @@ class SSTDataset(DGLBuiltinDataset):
        Returns
        -------
        :class:`dgl.DGLGraph`
            Graph structure, word id for each node, node labels and masks.

            - ``ndata['x']``: word id of the node
            - ``ndata['y']``: label of the node
            - ``ndata['mask']``: 1 if the node is a leaf, otherwise 0
        """
        return self._trees[idx]
...
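``SSTDataset`` exposes the standard ``__getitem__``/``__len__`` protocol, which is what makes ``for tree in dataset`` work in the examples above. A minimal sketch of that protocol follows; the ``TreeDataset`` class and its string "trees" are illustrative stand-ins, not DGL internals:

```python
# Minimal sketch of the __getitem__/__len__ protocol SSTDataset follows.
# Python's iteration falls back to __getitem__ with increasing indices
# until IndexError, so the instance is iterable without defining __iter__.
class TreeDataset:
    PAD_WORD = -1  # special pad word id, as in SSTDataset

    def __init__(self, trees):
        self._trees = trees

    def __getitem__(self, idx):
        # Return the idx-th tree (a placeholder string here).
        return self._trees[idx]

    def __len__(self):
        return len(self._trees)

data = TreeDataset(["tree0", "tree1", "tree2"])
visited = [tree for tree in data]  # iterable via the sequence protocol
```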
...@@ -16,17 +16,18 @@ class LegacyTUDataset(DGLBuiltinDataset):
    Parameters
    ----------
    name : str
        Dataset name, such as ``ENZYMES``, ``DD``, ``COLLAB``, ``MUTAG``; can be any
        dataset name on `<https://chrsmrrs.github.io/datasets/docs/datasets/>`_.
    use_pandas : bool
        Numpy's file read function has performance issues when the file is large;
        using pandas can be faster.
        Default: False
    hidden_size : int
        Some datasets don't contain features.
        Use constant node feature initialization instead, with hidden size ``hidden_size``.
        Default: 10
    max_allow_node : int
        Remove graphs that contain more nodes than ``max_allow_node``.
        Default: None

    Attributes
...@@ -40,7 +41,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
    --------
    >>> data = LegacyTUDataset('DD')

    The dataset instance is an iterable

    >>> len(data)
    1178
...@@ -52,7 +53,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
    >>> label
    tensor(1)

    Batch the graphs and labels for mini-batch training

    >>> graphs, labels = zip(*[data[i] for i in range(16)])
    >>> batched_graphs = dgl.batch(graphs)
...@@ -190,20 +191,23 @@ class LegacyTUDataset(DGLBuiltinDataset):
    def __getitem__(self, idx):
        """Get the idx-th sample.

        Parameters
        ----------
        idx : int
            The sample index.

        Returns
        -------
        (:class:`dgl.DGLGraph`, Tensor)
            The graph, with node features stored in the ``feat`` field and node
            labels in ``node_label`` if available, and its label.
        """
        g = self.graph_lists[idx]
        return g, self.graph_labels[idx]

    def __len__(self):
        """Return the number of graphs in the dataset."""
        return len(self.graph_lists)

    def _file_path(self, category):
...@@ -234,8 +238,8 @@ class TUDataset(DGLBuiltinDataset):
    Parameters
    ----------
    name : str
        Dataset name, such as ``ENZYMES``, ``DD``, ``COLLAB``, ``MUTAG``; can be any
        dataset name on `<https://chrsmrrs.github.io/datasets/docs/datasets/>`_.

    Attributes
    ----------
...@@ -248,7 +252,7 @@ class TUDataset(DGLBuiltinDataset):
    --------
    >>> data = TUDataset('DD')

    The dataset instance is an iterable

    >>> len(data)
    188
...@@ -260,7 +264,7 @@ class TUDataset(DGLBuiltinDataset):
    >>> label
    tensor([1])

    Batch the graphs and labels for mini-batch training

    >>> graphs, labels = zip(*[data[i] for i in range(16)])
    >>> batched_graphs = dgl.batch(graphs)
...@@ -355,20 +359,23 @@ class TUDataset(DGLBuiltinDataset):
    def __getitem__(self, idx):
        """Get the idx-th sample.

        Parameters
        ----------
        idx : int
            The sample index.

        Returns
        -------
        (:class:`dgl.DGLGraph`, Tensor)
            The graph, with node features stored in the ``feat`` field and node
            labels in ``node_label`` if available, and its label.
        """
        g = self.graph_lists[idx]
        return g, self.graph_labels[idx]

    def __len__(self):
        """Return the number of graphs in the dataset."""
        return len(self.graph_lists)

    def _file_path(self, category):
...
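The ``zip(*[data[i] for i in range(...)])`` idiom in the TUDataset examples transposes ``(graph, label)`` pairs into separate sequences before batching. A dependency-free sketch of the idiom; the dict-based graphs and node counts are made up, and a real pipeline would pass ``graphs`` to ``dgl.batch``:

```python
# Sketch of the mini-batch idiom from the TUDataset examples: collect
# (graph, label) pairs, then transpose with zip(*...). The dicts stand
# in for DGLGraph objects; real code would call dgl.batch(graphs).
dataset = [({"num_nodes": n}, n % 2) for n in (5, 8, 3, 6)]

graphs, labels = zip(*[dataset[i] for i in range(len(dataset))])
total_nodes = sum(g["num_nodes"] for g in graphs)
```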