Unverified Commit 967ecb80 authored by Tong He's avatar Tong He Committed by GitHub
Browse files

[Dataset] Fix the docstring format for dgl.data section (#1941)



* PPIDataset

* Revert "PPIDataset"

This reverts commit 264bd0c960cfa698a7bb946dad132bf52c2d0c8a.

* update data rst

* update data doc and docstring

* API doc rst for dataset

* docstring

* update api doc

* add url format

* update docstring

* update citation graph

* update knowledge graph

* update gc datasets

* fix index

* Rst fix (#3)

* Fix syntax

* syntax

* update docstring

* update doc (#4)

* final update

* fix rdflib

* fix rdf
Co-authored-by: default avatarHuXiangkun <huxk_hit@qq.com>
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-51-214.ec2.internal>
Co-authored-by: default avatarxiang song(charlie.song) <classicxsong@gmail.com>
parent 3fa8d755
......@@ -160,3 +160,4 @@ cscope.*
config.cmake
.ycm_extra_conf.py
**.png
......@@ -5,89 +5,125 @@ dgl.data
.. currentmodule:: dgl.data
Utils
-----
Dataset Classes
---------------
.. autosummary::
:toctree: ../../generated/
DGL dataset
```````````
utils.get_download_dir
utils.download
utils.check_sha1
utils.extract_archive
utils.split_dataset
utils.save_graphs
utils.load_graphs
utils.load_labels
.. autoclass:: DGLDataset
:members: download, save, load, process, has_cache, __getitem__, __len__
.. autoclass:: dgl.data.utils.Subset
:members: __getitem__, __len__
DGL builtin dataset
```````````````````
Dataset Classes
---------------
.. autoclass:: DGLBuiltinDataset
:members: download
Stanford sentiment treebank dataset
```````````````````````````````````
For more information about the dataset, see `Sentiment Analysis <https://nlp.stanford.edu/sentiment/index.html>`__.
.. autoclass:: SST
.. autoclass:: SSTDataset
:members: __getitem__, __len__
Karate Club dataset
Karate club dataset
```````````````````````````````````
.. autoclass:: KarateClub
.. autoclass:: KarateClubDataset
:members: __getitem__, __len__
Citation Network dataset
Citation network dataset
```````````````````````````````````
.. autoclass:: CitationGraphDataset
.. autoclass:: CoraGraphDataset
:members: __getitem__, __len__
.. autoclass:: CiteseerGraphDataset
:members: __getitem__, __len__
.. autoclass:: PubmedGraphDataset
:members: __getitem__, __len__
Knowledge graph dataset
```````````````````````````````````
.. autoclass:: FB15k237Dataset
:members: __getitem__, __len__
.. autoclass:: FB15kDataset
:members: __getitem__, __len__
.. autoclass:: WN18Dataset
:members: __getitem__, __len__
RDF datasets
```````````````````````````````````
.. autoclass:: AIFBDataset
:members: __getitem__, __len__
.. autoclass:: MUTAGDataset
:members: __getitem__, __len__
.. autoclass:: BGSDataset
:members: __getitem__, __len__
.. autoclass:: AMDataset
:members: __getitem__, __len__
CoraFull dataset
```````````````````````````````````
.. autoclass:: CoraFull
.. autoclass:: CoraFullDataset
:members: __getitem__, __len__
Amazon Co-Purchase dataset
```````````````````````````````````
.. autoclass:: AmazonCoBuy
.. autoclass:: AmazonCoBuyComputerDataset
:members: __getitem__, __len__
.. autoclass:: AmazonCoBuyPhotoDataset
:members: __getitem__, __len__
Coauthor dataset
```````````````````````````````````
.. autoclass:: Coauthor
.. autoclass:: CoauthorCSDataset
:members: __getitem__, __len__
.. autoclass:: CoauthorPhysicsDataset
:members: __getitem__, __len__
BitcoinOTC dataset
```````````````````````````````````
.. autoclass:: BitcoinOTC
.. autoclass:: BitcoinOTCDataset
:members: __getitem__, __len__
ICEWS18 dataset
```````````````````````````````````
.. autoclass:: ICEWS18
.. autoclass:: ICEWS18Dataset
:members: __getitem__, __len__
QM7b dataset
```````````````````````````````````
.. autoclass:: QM7b
.. autoclass:: QM7bDataset
:members: __getitem__, __len__
......@@ -95,7 +131,7 @@ QM7b dataset
GDELT dataset
```````````````````````````````````
.. autoclass:: GDELT
.. autoclass:: GDELTDataset
:members: __getitem__, __len__
......@@ -103,17 +139,16 @@ Mini graph classification dataset
`````````````````````````````````
.. autoclass:: MiniGCDataset
:members: __getitem__, __len__, num_classes
Graph kernel dataset
````````````````````
:members: __getitem__, __len__
For more information about the dataset, see `Benchmark Data Sets for Graph Kernels <https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets>`__.
TU dataset
``````````
.. autoclass:: TUDataset
:members: __getitem__, __len__
.. autoclass:: LegacyTUDataset
:members: __getitem__, __len__
Graph isomorphism network dataset
```````````````````````````````````
......@@ -129,3 +164,36 @@ Protein-Protein Interaction dataset
.. autoclass:: PPIDataset
:members: __getitem__, __len__
Reddit dataset
``````````````
.. autoclass:: RedditDataset
:members: __getitem__, __len__
Symmetric Stochastic Block Model Mixture dataset
````````````````````````````````````````````````
.. autoclass:: SBMMixtureDataset
:members: __getitem__, __len__, collate_fn
Utils
-----
.. autosummary::
:toctree: ../../generated/
utils.get_download_dir
utils.download
utils.check_sha1
utils.extract_archive
utils.split_dataset
utils.save_graphs
utils.load_graphs
utils.load_labels
.. autoclass:: dgl.data.utils.Subset
:members: __getitem__, __len__
......@@ -20,6 +20,8 @@ from .icews18 import ICEWS18, ICEWS18Dataset
from .qm7b import QM7b, QM7bDataset
from .dgl_dataset import DGLDataset, DGLBuiltinDataset
from .citation_graph import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
from .knowledge_graph import FB15k237Dataset, FB15kDataset, WN18Dataset
from .rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset
def register_data_args(parser):
......
......@@ -18,13 +18,15 @@ class BitcoinOTCDataset(DGLBuiltinDataset):
a platform called Bitcoin OTC. Since Bitcoin users are anonymous,
there is a need to maintain a record of users' reputation to prevent
transactions with fraudulent and risky users.
Offical website: https://snap.stanford.edu/data/soc-sign-bitcoin-otc.html
Official website: `<https://snap.stanford.edu/data/soc-sign-bitcoin-otc.html>`_
Bitcoin OTC dataset statistics:
Nodes: 5,881
Edges: 35,592
Range of edge weight: -10 to +10
Percentage of positive edges: 89%
- Nodes: 5,881
- Edges: 35,592
- Range of edge weight: -10 to +10
- Percentage of positive edges: 89%
Parameters
----------
......@@ -117,7 +119,12 @@ class BitcoinOTCDataset(DGLBuiltinDataset):
return self._graphs
def __len__(self):
r""" Number of graphs in the dataset """
r""" Number of graphs in the dataset.
Returns
-------
int
"""
return len(self.graphs)
def __getitem__(self, item):
......@@ -130,9 +137,11 @@ class BitcoinOTCDataset(DGLBuiltinDataset):
Returns
-------
dgl.DGLGraph
The graph contains the graph structure and edge weights
- edata['h'] : edge weights
:class:`dgl.DGLGraph`
The graph contains:
- ``edata['h']`` : edge weights
"""
return self.graphs[item]
......
......@@ -273,26 +273,38 @@ class CoraGraphDataset(CitationGraphDataset):
r""" Cora citation network dataset.
.. deprecated:: 0.5.0
`graph` is deprecated, it is replaced by:
- ``graph`` is deprecated, it is replaced by:
>>> dataset = CoraGraphDataset()
>>> graph = dataset[0]
`train_mask` is deprecated, it is replaced by:
- ``train_mask`` is deprecated, it is replaced by:
>>> dataset = CoraGraphDataset()
>>> graph = dataset[0]
>>> train_mask = graph.ndata['train_mask']
`val_mask` is deprecated, it is replaced by:
- ``val_mask`` is deprecated, it is replaced by:
>>> dataset = CoraGraphDataset()
>>> graph = dataset[0]
>>> val_mask = graph.ndata['val_mask']
`test_mask` is deprecated, it is replaced by:
- ``test_mask`` is deprecated, it is replaced by:
>>> dataset = CoraGraphDataset()
>>> graph = dataset[0]
>>> test_mask = graph.ndata['test_mask']
`labels` is deprecated, it is replaced by:
- ``labels`` is deprecated, it is replaced by:
>>> dataset = CoraGraphDataset()
>>> graph = dataset[0]
>>> labels = graph.ndata['label']
`feat` is deprecated, it is replaced by:
- ``feat`` is deprecated, it is replaced by:
>>> dataset = CoraGraphDataset()
>>> graph = dataset[0]
>>> feat = graph.ndata['feat']
......@@ -304,12 +316,16 @@ class CoraGraphDataset(CitationGraphDataset):
The task is to predict the category of
certain paper.
Statistics
----------
Nodes: 2708
Edges: 10556
Number of Classes: 7
Label Split: Train: 140 ,Valid: 500, Test: 1000
Statistics:
- Nodes: 2708
- Edges: 10556
- Number of Classes: 7
- Label split:
- Train: 140
- Valid: 500
- Test: 1000
Parameters
----------
......@@ -319,7 +335,7 @@ class CoraGraphDataset(CitationGraphDataset):
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
Whether to print out progress information. Default: True.
Attributes
----------
......@@ -327,13 +343,13 @@ class CoraGraphDataset(CitationGraphDataset):
Number of label classes
graph: networkx.DiGraph
Graph structure
train_mask: Numpy array
train_mask: numpy.ndarray
Mask of training nodes
val_mask: Numpy array
val_mask: numpy.ndarray
Mask of validation nodes
test_mask: Numpy array
test_mask: numpy.ndarray
Mask of test nodes
labels: Numpy array
labels: numpy.ndarray
Ground truth labels of each node
features: Tensor
Node features
......@@ -377,13 +393,15 @@ class CoraGraphDataset(CitationGraphDataset):
Return
------
dgl.DGLGraph
:class:`dgl.DGLGraph`
graph structure, node features and labels.
- ndata['train_mask']: mask for training node set
- ndata['val_mask']: mask for validation node set
- ndata['test_mask']: mask for test node set
- ndata['feat']: node feature
- ndata['label']: ground truth labels
- ``ndata['train_mask']``: mask for training node set
- ``ndata['val_mask']``: mask for validation node set
- ``ndata['test_mask']``: mask for test node set
- ``ndata['feat']``: node feature
- ``ndata['label']``: ground truth labels
"""
return super(CoraGraphDataset, self).__getitem__(idx)
......@@ -395,26 +413,38 @@ class CiteseerGraphDataset(CitationGraphDataset):
r""" Citeseer citation network dataset.
.. deprecated:: 0.5.0
`graph` is deprecated, it is replaced by:
``graph`` is deprecated, it is replaced by:
>>> dataset = CiteseerGraphDataset()
>>> graph = dataset[0]
`train_mask` is deprecated, it is replaced by:
``train_mask`` is deprecated, it is replaced by:
>>> dataset = CiteseerGraphDataset()
>>> graph = dataset[0]
>>> train_mask = graph.ndata['train_mask']
`val_mask` is deprecated, it is replaced by:
``val_mask`` is deprecated, it is replaced by:
>>> dataset = CiteseerGraphDataset()
>>> graph = dataset[0]
>>> val_mask = graph.ndata['val_mask']
`test_mask` is deprecated, it is replaced by:
``test_mask`` is deprecated, it is replaced by:
>>> dataset = CiteseerGraphDataset()
>>> graph = dataset[0]
>>> test_mask = graph.ndata['test_mask']
`labels` is deprecated, it is replaced by:
``labels`` is deprecated, it is replaced by:
>>> dataset = CiteseerGraphDataset()
>>> graph = dataset[0]
>>> labels = graph.ndata['label']
`feat` is deprecated, it is replaced by:
``feat`` is deprecated, it is replaced by:
>>> dataset = CiteseerGraphDataset()
>>> graph = dataset[0]
>>> feat = graph.ndata['feat']
......@@ -426,12 +456,16 @@ class CiteseerGraphDataset(CitationGraphDataset):
task. The task is to predict the category of
certain publication.
Statistics
----------
Nodes: 3327
Edges: 9228
Number of Classes: 6
Label Split: Train: 120 ,Valid: 500, Test: 1000
Statistics:
- Nodes: 3327
- Edges: 9228
- Number of Classes: 6
- Label Split:
- Train: 120
- Valid: 500
- Test: 1000
Parameters
-----------
......@@ -441,7 +475,7 @@ class CiteseerGraphDataset(CitationGraphDataset):
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
Whether to print out progress information. Default: True.
Attributes
----------
......@@ -449,13 +483,13 @@ class CiteseerGraphDataset(CitationGraphDataset):
Number of label classes
graph: networkx.DiGraph
Graph structure
train_mask: Numpy array
train_mask: numpy.ndarray
Mask of training nodes
val_mask: Numpy array
val_mask: numpy.ndarray
Mask of validation nodes
test_mask: Numpy array
test_mask: numpy.ndarray
Mask of test nodes
labels: Numpy array
labels: numpy.ndarray
Ground truth labels of each node
features: Tensor
Node features
......@@ -502,13 +536,15 @@ class CiteseerGraphDataset(CitationGraphDataset):
Return
------
dgl.DGLGraph
:class:`dgl.DGLGraph`
graph structure, node features and labels.
- ndata['train_mask']: mask for training node set
- ndata['val_mask']: mask for validation node set
- ndata['test_mask']: mask for test node set
- ndata['feat']: node feature
- ndata['label']: ground truth labels
- ``ndata['train_mask']``: mask for training node set
- ``ndata['val_mask']``: mask for validation node set
- ``ndata['test_mask']``: mask for test node set
- ``ndata['feat']``: node feature
- ``ndata['label']``: ground truth labels
"""
return super(CiteseerGraphDataset, self).__getitem__(idx)
......@@ -520,26 +556,38 @@ class PubmedGraphDataset(CitationGraphDataset):
r""" Pubmed citation network dataset.
.. deprecated:: 0.5.0
`graph` is deprecated, it is replaced by:
``graph`` is deprecated, it is replaced by:
>>> dataset = PubmedGraphDataset()
>>> graph = dataset[0]
`train_mask` is deprecated, it is replaced by:
``train_mask`` is deprecated, it is replaced by:
>>> dataset = PubmedGraphDataset()
>>> graph = dataset[0]
>>> train_mask = graph.ndata['train_mask']
`val_mask` is deprecated, it is replaced by:
``val_mask`` is deprecated, it is replaced by:
>>> dataset = PubmedGraphDataset()
>>> graph = dataset[0]
>>> val_mask = graph.ndata['val_mask']
`test_mask` is deprecated, it is replaced by:
``test_mask`` is deprecated, it is replaced by:
>>> dataset = PubmedGraphDataset()
>>> graph = dataset[0]
>>> test_mask = graph.ndata['test_mask']
`labels` is deprecated, it is replaced by:
``labels`` is deprecated, it is replaced by:
>>> dataset = PubmedGraphDataset()
>>> graph = dataset[0]
>>> labels = graph.ndata['label']
`feat` is deprecated, it is replaced by:
``feat`` is deprecated, it is replaced by:
>>> dataset = PubmedGraphDataset()
>>> graph = dataset[0]
>>> feat = graph.ndata['feat']
......@@ -551,12 +599,16 @@ class PubmedGraphDataset(CitationGraphDataset):
task. The task is to predict the category of
certain publication.
Statistics
----------
Nodes: 19717
Edges: 88651
Number of Classes: 3
Label Split: Train: 60 ,Valid: 500, Test: 1000
Statistics:
- Nodes: 19717
- Edges: 88651
- Number of Classes: 3
- Label Split:
- Train: 60
- Valid: 500
- Test: 1000
Parameters
-----------
......@@ -566,7 +618,7 @@ class PubmedGraphDataset(CitationGraphDataset):
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
Whether to print out progress information. Default: True.
Attributes
----------
......@@ -574,13 +626,13 @@ class PubmedGraphDataset(CitationGraphDataset):
Number of label classes
graph: networkx.DiGraph
Graph structure
train_mask: Numpy array
train_mask: numpy.ndarray
Mask of training nodes
val_mask: Numpy array
val_mask: numpy.ndarray
Mask of validation nodes
test_mask: Numpy array
test_mask: numpy.ndarray
Mask of test nodes
labels: Numpy array
labels: numpy.ndarray
Ground truth labels of each node
features: Tensor
Node features
......@@ -624,13 +676,15 @@ class PubmedGraphDataset(CitationGraphDataset):
Return
------
dgl.DGLGraph
:class:`dgl.DGLGraph`
graph structure, node features and labels.
- ndata['train_mask']: mask for training node set
- ndata['val_mask']: mask for validation node set
- ndata['test_mask']: mask for test node set
- ndata['feat']: node feature
- ndata['label']: ground truth labels
- ``ndata['train_mask']``: mask for training node set
- ``ndata['val_mask']``: mask for validation node set
- ``ndata['test_mask']``: mask for test node set
- ``ndata['feat']``: node feature
- ``ndata['label']``: ground truth labels
"""
return super(PubmedGraphDataset, self).__getitem__(idx)
......
......@@ -14,13 +14,13 @@ class DGLDataset(object):
The following steps will are executed automatically:
1. Check whether there is a dataset cache on disk
(already processed and stored on the disk) by
invoking ``has_cache()``. If true, goto 5.
(already processed and stored on the disk) by
invoking ``has_cache()``. If true, goto 5.
2. Call ``download()`` to download the data.
3. Call ``process()`` to process the data.
4. Call ``save()`` to save the processed dataset on disk and goto 6.
5. Call ``load()`` to load the processed dataset from disk.
6. Done
6. Done.
Users can overwrite these functions with their
own data processing logic.
......@@ -43,11 +43,31 @@ class DGLDataset(object):
A tuple of values as the input for the hash function.
Users can distinguish instances (and their caches on the disk)
from the same dataset class by comparing the hash values.
Default: (), the corresponding hash value is 'f9065fa7'.
Default: (), the corresponding hash value is ``'f9065fa7'``.
force_reload : bool
Whether to reload the dataset. Default: False
verbose : bool
Whether to print out progress information
Attributes
----------
url : str
The URL to download the dataset
name : str
The dataset name
raw_dir : str
Raw file directory contains the input data folder
raw_path : str
Directory contains the input data files.
Default : ``os.path.join(self.raw_dir, self.name)``
save_dir : str
Directory to save the processed dataset
save_path : str
File path to save the processed dataset
verbose : bool
Whether to print information
hash : str
Hash value for the dataset and the setting.
"""
def __init__(self, name, url=None, raw_dir=None, save_dir=None,
hash_key=(), force_reload=False, verbose=False):
......@@ -75,7 +95,7 @@ class DGLDataset(object):
It is recommended to download the to the :obj:`self.raw_dir`
folder. Can be ignored if the dataset is
already in self.raw_dir
already in :obj:`self.raw_dir`.
"""
pass
......@@ -83,9 +103,9 @@ class DGLDataset(object):
r"""Overwrite to realize your own logic of
saving the processed dataset into files.
It is recommended to use dgl.utils.data.save_graphs
It is recommended to use ``dgl.utils.data.save_graphs``
to save dgl graph into files and use
dgl.utils.data.save_info to save extra
``dgl.utils.data.save_info`` to save extra
information into files.
"""
pass
......@@ -94,9 +114,9 @@ class DGLDataset(object):
r"""Overwrite to realize your own logic of
loading the saved dataset from files.
It is recommended to use dgl.utils.data.load_graphs
It is recommended to use ``dgl.utils.data.load_graphs``
to load dgl graph from files and use
dgl.utils.data.load_info to load extra information
``dgl.utils.data.load_info`` to load extra information
into python dict object.
"""
pass
......@@ -116,9 +136,9 @@ class DGLDataset(object):
@retry_method_with_fix(download)
def _download(self):
r"""Download dataset by calling self.download() if the dataset does not exists under self.raw_path.
By default self.raw_path = os.path.join(self.raw_dir, self.name)
One can overwrite raw_path() function to change the path.
r"""Download dataset by calling ``self.download()`` if the dataset does not exist under ``self.raw_path``.
By default ``self.raw_path = os.path.join(self.raw_dir, self.name)``
One can overwrite ``raw_path()`` function to change the path.
"""
if os.path.exists(self.raw_path): # pragma: no cover
return
......@@ -185,13 +205,13 @@ class DGLDataset(object):
@property
def raw_dir(self):
r"""Raw file directory contains the input data directory.
r"""Raw file directory contains the input data folder.
"""
return self._raw_dir
@property
def raw_path(self):
r"""File directory contains the input data.
r"""Directory contains the input data files.
By default raw_path = os.path.join(self.raw_dir, self.name)
"""
return os.path.join(self.raw_dir, self.name)
......@@ -216,7 +236,7 @@ class DGLDataset(object):
@property
def hash(self):
r"""Hash value for the dataset.
r"""Hash value for the dataset and the setting.
"""
return self._hash
......@@ -235,6 +255,7 @@ class DGLBuiltinDataset(DGLDataset):
r"""The Basic DGL Builtin Dataset.
Parameters
----------
name : str
Name of the dataset.
url : str
......
......@@ -18,16 +18,15 @@ class GDELTDataset(DGLBuiltinDataset):
(15 minutes time granularity).
Reference:
- `Recurrent Event Network for Reasoning over Temporal
Knowledge Graphs` <https://arxiv.org/abs/1904.05530>
- `The Global Database of Events, Language, and Tone (GDELT) `
<https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/28075>
Statistics
----------
Train examples: 2,304
Valid examples: 288
Test examples: 384
- `Recurrent Event Network for Reasoning over Temporal Knowledge Graphs <https://arxiv.org/abs/1904.05530>`_
- `The Global Database of Events, Language, and Tone (GDELT) <https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/28075>`_
Statistics:
- Train examples: 2,304
- Valid examples: 288
- Test examples: 384
Parameters
----------
......@@ -135,9 +134,11 @@ class GDELTDataset(DGLBuiltinDataset):
Returns
-------
dgl.DGLGraph
graph structure and edge feature
- edata['rel_type']: edge type
:class:`dgl.DGLGraph`
The graph contains:
- ``edata['rel_type']``: edge type
"""
if t >= len(self) or t < 0:
raise IndexError("Index out of range")
......@@ -150,7 +151,12 @@ class GDELTDataset(DGLBuiltinDataset):
return g
def __len__(self):
r"""Number of graphs in the dataset"""
r"""Number of graphs in the dataset.
Returns
-------
int
"""
return self._end_time - self._start_time + 1
@property
......
......@@ -17,21 +17,19 @@ from ..convert import graph as dgl_graph
class GINDataset(DGLBuiltinDataset):
"""Datasets for Graph Isomorphism Network (GIN)
Adapted from https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip.
Adapted from `<https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip>`_.
The dataset contains the compact format of popular graph kernel datasets, which includes:
MUTAG, COLLAB, IMDBBINARY, IMDBMULTI, NCI1, PROTEINS, PTC, REDDITBINARY, REDDITMULTI5K
This datset class processes all data sets listed above. For more graph kernel datasets,
see :class:`TUDataset`.
The dataset contains the compact format of popular graph kernel datasets.
For more graph kernel datasets, see :class:`TUDataset`.
Paramters
Parameters
---------
name: str
dataset name, one of below -
('MUTAG', 'COLLAB', \
'IMDBBINARY', 'IMDBMULTI', \
'NCI1', 'PROTEINS', 'PTC', \
'REDDITBINARY', 'REDDITMULTI5K')
dataset name, one of
(``'MUTAG'``, ``'COLLAB'``, \
``'IMDBBINARY'``, ``'IMDBMULTI'``, \
``'NCI1'``, ``'PROTEINS'``, ``'PTC'``, \
``'REDDITBINARY'``, ``'REDDITMULTI5K'``)
self_loop: bool
add self to self edge if true
degree_as_nlabel: bool
......@@ -41,7 +39,7 @@ class GINDataset(DGLBuiltinDataset):
--------
>>> data = GINDataset(name='MUTAG', self_loop=False)
**The dataset instance is an iterable**
The dataset instance is an iterable
>>> len(data)
188
......@@ -53,7 +51,7 @@ class GINDataset(DGLBuiltinDataset):
>>> label
tensor(1)
**Batch the graphs and labels for mini-batch training**
Batch the graphs and labels for mini-batch training
>>> graphs, labels = zip(*[data[i] for i in range(16)])
>>> batched_graphs = dgl.batch(graphs)
......@@ -118,14 +116,14 @@ class GINDataset(DGLBuiltinDataset):
def __getitem__(self, idx):
"""Get the idx-th sample.
Paramters
Parameters
---------
idx : int
The sample index.
Returns
-------
(dgl.Graph, int)
(:class:`dgl.Graph`, Tensor)
The graph and its label.
"""
return self.graphs[idx], self.labels[idx]
......
......@@ -118,10 +118,12 @@ class GNNBenchmarkDataset(DGLBuiltinDataset):
Returns
-------
dgl.DGLGraph
graph structure, node features and node labels
- ndata['feat']: node features
- ndata['label']: node labels
:class:`dgl.DGLGraph`
The graph contains:
- ``ndata['feat']``: node features
- ``ndata['label']``: node labels
"""
assert idx == 0, "This dataset has only one graph"
return self._graph
......@@ -135,7 +137,9 @@ class CoraFullDataset(GNNBenchmarkDataset):
r"""CORA-Full dataset for node classification task.
.. deprecated:: 0.5.0
`data` is deprecated, it is repalced by:
- ``data`` is deprecated, it is replaced by:
>>> dataset = CoraFullDataset()
>>> graph = dataset[0]
......@@ -143,14 +147,14 @@ class CoraFullDataset(GNNBenchmarkDataset):
Unsupervised Inductive Learning via Ranking`.
Nodes represent paper and edges represent citations.
Reference: https://github.com/shchur/gnn-benchmark#datasets
Reference: `<https://github.com/shchur/gnn-benchmark#datasets>`_
Statistics
----------
Nodes: 19,793
Edges: 130,622
Number of Classes: 70
Node feature size: 8,710
Statistics:
- Nodes: 19,793
- Edges: 130,622
- Number of Classes: 70
- Node feature size: 8,710
Parameters
----------
......@@ -185,7 +189,12 @@ class CoraFullDataset(GNNBenchmarkDataset):
@property
def num_classes(self):
"""Number of classes."""
"""Number of classes.
Returns
-------
int
"""
return 70
......@@ -193,7 +202,9 @@ class CoauthorCSDataset(GNNBenchmarkDataset):
r""" 'Computer Science (CS)' part of the Coauthor dataset for node classification task.
.. deprecated:: 0.5.0
`data` is deprecated, it is repalced by:
- ``data`` is deprecated, it is replaced by:
>>> dataset = CoauthorCSDataset()
>>> graph = dataset[0]
......@@ -202,14 +213,14 @@ class CoauthorCSDataset(GNNBenchmarkDataset):
co-authored a paper; node features represent paper keywords for each author’s papers, and class
labels indicate most active fields of study for each author.
Reference: https://github.com/shchur/gnn-benchmark#datasets
Reference: `<https://github.com/shchur/gnn-benchmark#datasets>`_
Statistics
----------
Nodes: 18,333
Edges: 327,576
Number of classes: 15
Node feature size: 6,805
Statistics:
- Nodes: 18,333
- Edges: 327,576
- Number of classes: 15
- Node feature size: 6,805
Parameters
----------
......@@ -244,7 +255,12 @@ class CoauthorCSDataset(GNNBenchmarkDataset):
@property
def num_classes(self):
"""Number of classes."""
"""Number of classes.
Returns
-------
int
"""
return 15
......@@ -252,7 +268,9 @@ class CoauthorPhysicsDataset(GNNBenchmarkDataset):
r""" 'Physics' part of the Coauthor dataset for node classification task.
.. deprecated:: 0.5.0
`data` is deprecated, it is repalced by:
- ``data`` is deprecated, it is replaced by:
>>> dataset = CoauthorPhysicsDataset()
>>> graph = dataset[0]
......@@ -261,14 +279,14 @@ class CoauthorPhysicsDataset(GNNBenchmarkDataset):
co-authored a paper; node features represent paper keywords for each author’s papers, and class
labels indicate most active fields of study for each author.
Reference: https://github.com/shchur/gnn-benchmark#datasets
Reference: `<https://github.com/shchur/gnn-benchmark#datasets>`_
Statistics
----------
Nodes: 34,493
Edges: 991,848
Number of classes: 5
Node feature size: 8,415
- Nodes: 34,493
- Edges: 991,848
- Number of classes: 5
- Node feature size: 8,415
Parameters
----------
......@@ -303,7 +321,12 @@ class CoauthorPhysicsDataset(GNNBenchmarkDataset):
@property
def num_classes(self):
"""Number of classes."""
"""Number of classes.
Returns
-------
int
"""
return 5
......@@ -311,7 +334,9 @@ class AmazonCoBuyComputerDataset(GNNBenchmarkDataset):
r""" 'Computer' part of the AmazonCoBuy dataset for node classification task.
.. deprecated:: 0.5.0
`data` is deprecated, it is repalced by:
- ``data`` is deprecated, it is replaced by:
>>> dataset = AmazonCoBuyComputerDataset()
>>> graph = dataset[0]
......@@ -319,14 +344,14 @@ class AmazonCoBuyComputerDataset(GNNBenchmarkDataset):
where nodes represent goods, edges indicate that two goods are frequently bought together, node
features are bag-of-words encoded product reviews, and class labels are given by the product category.
Reference: https://github.com/shchur/gnn-benchmark#datasets
Reference: `<https://github.com/shchur/gnn-benchmark#datasets>`_
Statistics
----------
Nodes: 13,752
Edges: 574,418
Number of classes: 5
Node feature size: 767
Statistics:
- Nodes: 13,752
- Edges: 574,418
- Number of classes: 5
- Node feature size: 767
Parameters
----------
......@@ -361,7 +386,12 @@ class AmazonCoBuyComputerDataset(GNNBenchmarkDataset):
@property
def num_classes(self):
"""Number of classes."""
"""Number of classes.
Returns
-------
int
"""
return 5
......@@ -369,7 +399,9 @@ class AmazonCoBuyPhotoDataset(GNNBenchmarkDataset):
r"""AmazonCoBuy dataset for node classification task.
.. deprecated:: 0.5.0
`data` is deprecated, it is repalced by:
- ``data`` is deprecated, it is replaced by:
>>> dataset = AmazonCoBuyPhotoDataset()
>>> graph = dataset[0]
......@@ -377,14 +409,14 @@ class AmazonCoBuyPhotoDataset(GNNBenchmarkDataset):
where nodes represent goods, edges indicate that two goods are frequently bought together, node
features are bag-of-words encoded product reviews, and class labels are given by the product category.
Reference: https://github.com/shchur/gnn-benchmark#datasets
Reference: `<https://github.com/shchur/gnn-benchmark#datasets>`_
Statistics
----------
Nodes: 7,650
Edges: 287,326
Number of classes: 5
Node feature size: 745
- Nodes: 7,650
- Edges: 287,326
- Number of classes: 5
- Node feature size: 745
Parameters
----------
......@@ -419,7 +451,12 @@ class AmazonCoBuyPhotoDataset(GNNBenchmarkDataset):
@property
def num_classes(self):
"""Number of classes."""
"""Number of classes.
Returns
-------
int
"""
return 5
......
......@@ -12,22 +12,23 @@ class ICEWS18Dataset(DGLBuiltinDataset):
r""" ICEWS18 dataset for temporal graph
Integrated Crisis Early Warning System (ICEWS18)
Event data consists of coded interactions between socio-political
actors (i.e., cooperative or hostile actions between individuals,
groups, sectors and nation states). This Dataset consists of events
from 1/1/2018 to 10/31/2018 (24 hours time granularity).
Reference:
- `Recurrent Event Network for Reasoning over Temporal
Knowledge Graphs <https://arxiv.org/abs/1904.05530>`_
- `ICEWS Coded Event Data
<https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/28075>`_
Statistics
----------
Train examples: 240
Valid examples: 30
Test examples: 34
Nodes per graph: 23033
- `Recurrent Event Network for Reasoning over Temporal Knowledge Graphs <https://arxiv.org/abs/1904.05530>`_
- `ICEWS Coded Event Data <https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/28075>`_
Statistics:
- Train examples: 240
- Valid examples: 30
- Test examples: 34
- Nodes per graph: 23033
Parameters
----------
......@@ -111,14 +112,21 @@ class ICEWS18Dataset(DGLBuiltinDataset):
Returns
-------
dgl.DGLGraph
graph structure and edge feature
- edata['rel_type']: edge type
:class:`dgl.DGLGraph`
The graph contains:
- ``edata['rel_type']``: edge type
"""
return self._graphs[idx]
def __len__(self):
r"""Number of graphs in the dataset"""
r"""Number of graphs in the dataset.
Returns
-------
int
"""
return len(self._graphs)
@property
......
......@@ -15,7 +15,9 @@ class KarateClubDataset(DGLDataset):
r""" Karate Club dataset for Node Classification
.. deprecated:: 0.5.0
`data` is deprecated, it is replaced by:
``data`` is deprecated, it is replaced by:
>>> dataset = KarateClubDataset()
>>> g = dataset[0]
......@@ -24,19 +26,20 @@ class KarateClubDataset(DGLDataset):
Model for Conflict and Fission in Small Groups" by Wayne W. Zachary.
The network became a popular example of community structure in
networks after its use by Michelle Girvan and Mark Newman in 2002.
Official website: http://konect.cc/networks/ucidata-zachary/
Official website: `<http://konect.cc/networks/ucidata-zachary/>`_
Karate Club dataset statistics:
Nodes: 34
Edges: 156
Number of Classes: 2
- Nodes: 34
- Edges: 156
- Number of Classes: 2
Attributes
----------
num_classes : int
Number of node classes
data : list
A list of DGLGraph objects
A list of :class:`dgl.DGLGraph` objects
Examples
--------
......@@ -78,9 +81,11 @@ class KarateClubDataset(DGLDataset):
Returns
-------
dgl.DGLGraph
:class:`dgl.DGLGraph`
graph structure and labels.
- ndata['label']: ground truth labels
- ``ndata['label']``: ground truth labels
"""
assert idx == 0, "This dataset has only one graph"
return self._graph
......
......@@ -336,21 +336,27 @@ class FB15k237Dataset(KnowledgeGraphDataset):
r"""FB15k237 link prediction dataset.
.. deprecated:: 0.5.0
`train` is deprecated, it is replaced by:
- ``train`` is deprecated, it is replaced by:
>>> dataset = FB15k237Dataset()
>>> graph = dataset[0]
>>> train_mask = graph.edata['train_mask']
>>> train_idx = th.nonzero(train_mask).squeeze()
>>> src, dst = graph.edges(train_idx)
>>> rel = graph.edata['etype'][train_idx]
`valid` is deprecated, it is replaced by:
- ``valid`` is deprecated, it is replaced by:
>>> dataset = FB15k237Dataset()
>>> graph = dataset[0]
>>> val_mask = graph.edata['val_mask']
>>> val_idx = th.nonzero(val_mask).squeeze()
>>> src, dst = graph.edges(val_idx)
>>> rel = graph.edata['etype'][val_idx]
`test` is deprecated, it is replaced by:
- ``test`` is deprecated, it is replaced by:
>>> dataset = FB15k237Dataset()
>>> graph = dataset[0]
>>> test_mask = graph.edata['test_mask']
......@@ -364,10 +370,15 @@ class FB15k237Dataset(KnowledgeGraphDataset):
created for each edge by default.
FB15k237 dataset statistics:
Nodes: 14541
Number of relation types: 237
Number of reversed relation types: 237
Label Split: Train: 272115 ,Valid: 17535, Test: 20466
- Nodes: 14541
- Number of relation types: 237
- Number of reversed relation types: 237
- Label Split:
- Train: 272115
- Valid: 17535
- Test: 20466
Parameters
----------
......@@ -387,11 +398,11 @@ class FB15k237Dataset(KnowledgeGraphDataset):
Number of nodes
num_rels: int
Number of relation types
train: numpy array
train: numpy.ndarray
A numpy array of triplets (src, rel, dst) for the training graph
valid: numpy array
valid: numpy.ndarray
A numpy array of triplets (src, rel, dst) for the validation graph
test: numpy array
test: numpy.ndarray
A numpy array of triplets (src, rel, dst) for the test graph
Examples
......@@ -421,7 +432,6 @@ class FB15k237Dataset(KnowledgeGraphDataset):
>>> val_g.edata['e_type'] = e_type[val_edges];
>>>
>>> # Train, Validation and Test
>>>
"""
def __init__(self, reverse=True, raw_dir=None, force_reload=False, verbose=True):
name = 'FB15k-237'
......@@ -437,16 +447,18 @@ class FB15k237Dataset(KnowledgeGraphDataset):
Returns
-------
dgl.DGLGraph
The graph contain
- edata['e_type']: edge relation type
- edata['train_edge_mask']: positive training edge mask
- edata['val_edge_mask']: positive validation edge mask
- edata['test_edge_mask']: positive testing edge mask
- edata['train_mask']: training edge set mask (include reversed training edges)
- edata['val_mask']: validation edge set mask (include reversed validation edges)
- edata['test_mask']: testing edge set mask (include reversed testing edges)
- ndata['ntype']: node type. All 0 in this dataset
:class:`dgl.DGLGraph`
The graph contains:
- ``edata['e_type']``: edge relation type
- ``edata['train_edge_mask']``: positive training edge mask
- ``edata['val_edge_mask']``: positive validation edge mask
- ``edata['test_edge_mask']``: positive testing edge mask
- ``edata['train_mask']``: training edge set mask (include reversed training edges)
- ``edata['val_mask']``: validation edge set mask (include reversed validation edges)
- ``edata['test_mask']``: testing edge set mask (include reversed testing edges)
- ``ndata['ntype']``: node type. All 0 in this dataset
"""
return super(FB15k237Dataset, self).__getitem__(idx)
......@@ -458,21 +470,27 @@ class FB15kDataset(KnowledgeGraphDataset):
r"""FB15k link prediction dataset.
.. deprecated:: 0.5.0
`train` is deprecated, it is replaced by:
- ``train`` is deprecated, it is replaced by:
>>> dataset = FB15kDataset()
>>> graph = dataset[0]
>>> train_mask = graph.edata['train_mask']
>>> train_idx = th.nonzero(train_mask).squeeze()
>>> src, dst = graph.edges(train_idx)
>>> rel = graph.edata['etype'][train_idx]
`valid` is deprecated, it is replaced by:
- ``valid`` is deprecated, it is replaced by:
>>> dataset = FB15kDataset()
>>> graph = dataset[0]
>>> val_mask = graph.edata['val_mask']
>>> val_idx = th.nonzero(val_mask).squeeze()
>>> src, dst = graph.edges(val_idx)
>>> rel = graph.edata['etype'][val_idx]
`test` is deprecated, it is replaced by:
- ``test`` is deprecated, it is replaced by:
>>> dataset = FB15kDataset()
>>> graph = dataset[0]
>>> test_mask = graph.edata['test_mask']
......@@ -480,7 +498,8 @@ class FB15kDataset(KnowledgeGraphDataset):
>>> src, dst = graph.edges(test_idx)
>>> rel = graph.edata['etype'][test_idx]
The FB15K dataset was introduced in http://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data.pdf,
The FB15K dataset was introduced in `Translating Embeddings for Modeling
Multi-relational Data <http://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data.pdf>`_.
It is a subset of Freebase which contains about
14,951 entities with 1,345 different relations.
When creating the dataset, a reverse edge with
......@@ -488,10 +507,15 @@ class FB15kDataset(KnowledgeGraphDataset):
by default.
FB15k dataset statistics:
Nodes: 14,951
Number of relation types: 1,345
Number of reversed relation types: 1,345
Label Split: Train: 483142 ,Valid: 50000, Test: 59071
- Nodes: 14,951
- Number of relation types: 1,345
- Number of reversed relation types: 1,345
- Label Split:
- Train: 483142
- Valid: 50000
- Test: 59071
Parameters
----------
......@@ -511,11 +535,11 @@ class FB15kDataset(KnowledgeGraphDataset):
Number of nodes
num_rels: int
Number of relation types
train: numpy array
train: numpy.ndarray
A numpy array of triplets (src, rel, dst) for the training graph
valid: numpy array
valid: numpy.ndarray
A numpy array of triplets (src, rel, dst) for the validation graph
test: numpy array
test: numpy.ndarray
A numpy array of triplets (src, rel, dst) for the test graph
Examples
......@@ -560,16 +584,18 @@ class FB15kDataset(KnowledgeGraphDataset):
Returns
-------
dgl.DGLGraph
The graph contain
- edata['e_type']: edge relation type
- edata['train_edge_mask']: positive training edge mask
- edata['val_edge_mask']: positive validation edge mask
- edata['test_edge_mask']: positive testing edge mask
- edata['train_mask']: training edge set mask (include reversed training edges)
- edata['val_mask']: validation edge set mask (include reversed validation edges)
- edata['test_mask']: testing edge set mask (include reversed testing edges)
- ndata['ntype']: node type. All 0 in this dataset
:class:`dgl.DGLGraph`
The graph contains:
- ``edata['e_type']``: edge relation type
- ``edata['train_edge_mask']``: positive training edge mask
- ``edata['val_edge_mask']``: positive validation edge mask
- ``edata['test_edge_mask']``: positive testing edge mask
- ``edata['train_mask']``: training edge set mask (include reversed training edges)
- ``edata['val_mask']``: validation edge set mask (include reversed validation edges)
- ``edata['test_mask']``: testing edge set mask (include reversed testing edges)
- ``ndata['ntype']``: node type. All 0 in this dataset
"""
return super(FB15kDataset, self).__getitem__(idx)
......@@ -581,21 +607,27 @@ class WN18Dataset(KnowledgeGraphDataset):
r""" WN18 link prediction dataset.
.. deprecated:: 0.5.0
`train` is deprecated, it is replaced by:
- ``train`` is deprecated, it is replaced by:
>>> dataset = WN18Dataset()
>>> graph = dataset[0]
>>> train_mask = graph.edata['train_mask']
>>> train_idx = th.nonzero(train_mask).squeeze()
>>> src, dst = graph.edges(train_idx)
>>> rel = graph.edata['etype'][train_idx]
`valid` is deprecated, it is replaced by:
- ``valid`` is deprecated, it is replaced by:
>>> dataset = WN18Dataset()
>>> graph = dataset[0]
>>> val_mask = graph.edata['val_mask']
>>> val_idx = th.nonzero(val_mask).squeeze()
>>> src, dst = graph.edges(val_idx)
>>> rel = graph.edata['etype'][val_idx]
`test` is deprecated, it is replaced by:
- ``test`` is deprecated, it is replaced by:
>>> dataset = WN18Dataset()
>>> graph = dataset[0]
>>> test_mask = graph.edata['test_mask']
......@@ -603,17 +635,23 @@ class WN18Dataset(KnowledgeGraphDataset):
>>> src, dst = graph.edges(test_idx)
>>> rel = graph.edata['etype'][test_idx]
The WN18 dataset was introduced in http://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data.pdf,
The WN18 dataset was introduced in `Translating Embeddings for Modeling
Multi-relational Data <http://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data.pdf>`_.
It included the full 18 relations scraped from
WordNet for roughly 41,000 synsets. When creating
the dataset, a reverse edge with reversed relation
types are created for each edge by default.
WN18 dataset tatistics:
Nodes: 40943
Number of relation types: 18
Number of reversed relation types: 18
Label Split: Train: 141442 ,Valid: 5000, Test: 5000
WN18 dataset statistics:
- Nodes: 40943
- Number of relation types: 18
- Number of reversed relation types: 18
- Label Split:
- Train: 141442
- Valid: 5000
- Test: 5000
Parameters
----------
......@@ -633,11 +671,11 @@ class WN18Dataset(KnowledgeGraphDataset):
Number of nodes
num_rels: int
Number of relation types
train: numpy array
train: numpy.ndarray
A numpy array of triplets (src, rel, dst) for the training graph
valid: numpy array
valid: numpy.ndarray
A numpy array of triplets (src, rel, dst) for the validation graph
test: numpy array
test: numpy.ndarray
A numpy array of triplets (src, rel, dst) for the test graph
Examples
......@@ -682,16 +720,18 @@ class WN18Dataset(KnowledgeGraphDataset):
Returns
-------
dgl.DGLGraph
The graph contain
- edata['e_type']: edge relation type
- edata['train_edge_mask']: positive training edge mask
- edata['val_edge_mask']: positive validation edge mask
- edata['test_edge_mask']: positive testing edge mask
- edata['train_mask']: training edge set mask (include reversed training edges)
- edata['val_mask']: validation edge set mask (include reversed validation edges)
- edata['test_mask']: testing edge set mask (include reversed testing edges)
- ndata['ntype']: node type. All 0 in this dataset
:class:`dgl.DGLGraph`
The graph contains:
- ``edata['e_type']``: edge relation type
- ``edata['train_edge_mask']``: positive training edge mask
- ``edata['val_edge_mask']``: positive validation edge mask
- ``edata['test_edge_mask']``: positive testing edge mask
- ``edata['train_mask']``: training edge set mask (include reversed training edges)
- ``edata['val_mask']``: validation edge set mask (include reversed validation edges)
- ``edata['test_mask']``: testing edge set mask (include reversed testing edges)
- ``ndata['ntype']``: node type. All 0 in this dataset
"""
return super(WN18Dataset, self).__getitem__(idx)
......
......@@ -12,17 +12,18 @@ from ..transform import add_self_loop
__all__ = ['MiniGCDataset']
class MiniGCDataset(DGLDataset):
"""The dataset class.
"""The synthetic graph classification dataset class.
The dataset contains 8 different types of graphs.
* class 0 : cycle graph
* class 1 : star graph
* class 2 : wheel graph
* class 3 : lollipop graph
* class 4 : hypercube graph
* class 5 : grid graph
* class 6 : clique graph
* class 7 : circular ladder graph
- class 0 : cycle graph
- class 1 : star graph
- class 2 : wheel graph
- class 3 : lollipop graph
- class 4 : hypercube graph
- class 5 : grid graph
- class 6 : clique graph
- class 7 : circular ladder graph
Parameters
----------
......@@ -50,7 +51,7 @@ class MiniGCDataset(DGLDataset):
--------
>>> data = MiniGCDataset(100, 16, 32, seed=0)
**The dataset instance is an iterable**
The dataset instance is an iterable
>>> len(data)
100
......@@ -62,7 +63,7 @@ class MiniGCDataset(DGLDataset):
>>> label
tensor(5)
**Batch the graphs and labels for mini-batch training**
Batch the graphs and labels for mini-batch training
>>> graphs, labels = zip(*[data[i] for i in range(16)])
>>> batched_graphs = dgl.batch(graphs)
......@@ -96,13 +97,15 @@ class MiniGCDataset(DGLDataset):
def __getitem__(self, idx):
"""Get the idx-th sample.
Paramters
Parameters
---------
idx : int
The sample index.
Returns
-------
(dgl.Graph, int)
(:class:`dgl.Graph`, Tensor)
The graph and its label.
"""
return self.graphs[idx], self.labels[idx]
......@@ -143,7 +146,7 @@ class MiniGCDataset(DGLDataset):
self._gen_circular_ladder(self.num_graphs - len(self.graphs))
# preprocess
for i in range(self.num_graphs):
# convert to Graph, and add self loops
# convert to DGLGraph, and add self loops
self.graphs[i] = add_self_loop(dgl_graph(self.graphs[i]))
self.labels = F.tensor(np.array(self.labels).astype(np.int))
......
......@@ -15,13 +15,17 @@ class PPIDataset(DGLBuiltinDataset):
r""" Protein-Protein Interaction dataset for inductive node classification
.. deprecated:: 0.5.0
`lables` is deprecated, it is replaced by:
- ``labels`` is deprecated, it is replaced by:
>>> dataset = PPIDataset()
>>> for g in dataset:
.... labels = g.ndata['label']
....
>>>
`features` is deprecated, it is replaced by:
- ``features`` is deprecated, it is replaced by:
>>> dataset = PPIDataset()
>>> for g in dataset:
.... features = g.ndata['feat']
......@@ -33,12 +37,13 @@ class PPIDataset(DGLBuiltinDataset):
50 features and 121 labels. 20 graphs for training, 2 for validation
and 2 for testing.
Reference: http://snap.stanford.edu/graphsage/
Reference: `<http://snap.stanford.edu/graphsage/>`_
Statistics:
PPI dataset statistics:
Train examples: 20
Valid examples: 2
Test examples: 2
- Train examples: 20
- Valid examples: 2
- Test examples: 2
Parameters
----------
......@@ -167,10 +172,11 @@ class PPIDataset(DGLBuiltinDataset):
Returns
-------
dgl.DGLGraph
:class:`dgl.DGLGraph`
graph structure, node features and node labels.
- ndata['feat']: node features
- ndata['label']: nodel labels
- ``ndata['feat']``: node features
- ``ndata['label']``: node labels
"""
return self.graphs[item]
......
......@@ -16,15 +16,16 @@ class QM7bDataset(DGLDataset):
This dataset consists of 7,211 molecules with 14 regression targets.
Nodes means atoms and edges means bonds. Edge data 'h' means
the entry of Coulomb matrix.
Reference: http://quantum-machine.org/datasets/
Statistics
----------
Number of graphs: 7,211
Number of regression targets: 14
Average number of nodes: 15
Average number of edges: 245
Edge feature size: 1
Reference: `<http://quantum-machine.org/datasets/>`_
Statistics:
- Number of graphs: 7,211
- Number of regression targets: 14
- Average number of nodes: 15
- Average number of edges: 245
- Edge feature size: 1
Parameters
----------
......@@ -73,10 +74,6 @@ class QM7bDataset(DGLDataset):
def process(self):
mat_path = self.raw_path + '.mat'
if not check_sha1(mat_path, self._sha1_str):
raise UserWarning('File {} is downloaded but the content hash does not match.'
'The repo may be outdated or download may be incomplete. '
'Otherwise you can create an issue for it.'.format(self.name))
self.graphs, self.label = self._load_graph(mat_path)
def _load_graph(self, filename):
......@@ -110,6 +107,10 @@ class QM7bDataset(DGLDataset):
def download(self):
file_path = os.path.join(self.raw_dir, self.name + '.mat')
download(self.url, path=file_path)
if not check_sha1(file_path, self._sha1_str):
raise UserWarning('File {} is downloaded but the content hash does not match.'
'The repo may be outdated or download may be incomplete. '
'Otherwise you can create an issue for it.'.format(self.name))
@property
def num_labels(self):
......@@ -126,12 +127,17 @@ class QM7bDataset(DGLDataset):
Returns
-------
(dgl.DGLGraph, Tensor)
(:class:`dgl.DGLGraph`, Tensor)
"""
return self.graphs[idx], self.label[idx]
def __len__(self):
r"""Number of graphs in the dataset"""
r"""Number of graphs in the dataset.
Returns
-------
int
"""
return len(self.graphs)
......
......@@ -8,7 +8,10 @@ from collections import OrderedDict
import itertools
import abc
import re
import rdflib as rdf
try:
import rdflib as rdf
except ImportError:
pass
import networkx as nx
import numpy as np
......@@ -112,6 +115,7 @@ class RDFGraphDataset(DGLBuiltinDataset):
raw_dir=None,
force_reload=False,
verbose=True):
import rdflib as rdf
self._insert_reverse = insert_reverse
self._print_every = print_every
self._predict_category = predict_category
......@@ -542,15 +546,21 @@ class AIFBDataset(RDFGraphDataset):
r"""AIFB dataset for node classification task
.. deprecated:: 0.5.0
`graph` is deprecated, it is replaced by:
- ``graph`` is deprecated, it is replaced by:
>>> dataset = AIFBDataset()
>>> graph = dataset[0]
`train_idx` is deprecated, it can be replaced by:
- ``train_idx`` is deprecated, it can be replaced by:
>>> dataset = AIFBDataset()
>>> graph = dataset[0]
>>> train_mask = graph.nodes[dataset.category].data['train_mask']
>>> train_idx = th.nonzero(train_mask).squeeze()
`test_idx` is deprecated, it can be replaced by:
- ``test_idx`` is deprecated, it can be replaced by:
>>> dataset = AIFBDataset()
>>> graph = dataset[0]
>>> test_mask = graph.nodes[dataset.category].data['test_mask']
......@@ -561,11 +571,15 @@ class AIFBDataset(RDFGraphDataset):
University of Karlsruhe.
AIFB dataset statistics:
Nodes: 7262
Edges: 48810 (including reverse edges)
Target Category: Personen
Number of Classes: 4
Label Split: Train: 140, Test: 36
- Nodes: 7262
- Edges: 48810 (including reverse edges)
- Target Category: Personen
- Number of Classes: 4
- Label Split:
- Train: 140
- Test: 36
Parameters
-----------
......@@ -589,7 +603,7 @@ class AIFBDataset(RDFGraphDataset):
The entity category (node type) that has labels for prediction
labels : Tensor
All the labels of the entities in ``predict_category``
graph : dgl.DGLGraph
graph : :class:`dgl.DGLGraph`
Graph structure
train_idx : Tensor
Entity IDs for training. All IDs are local IDs w.r.t. to ``predict_category``.
......@@ -608,8 +622,6 @@ class AIFBDataset(RDFGraphDataset):
>>> labels = g.nodes[category].data.pop('labels')
"""
employs = rdf.term.URIRef("http://swrc.ontoware.org/ontology#employs")
affiliation = rdf.term.URIRef("http://swrc.ontoware.org/ontology#affiliation")
entity_prefix = 'http://www.aifb.uni-karlsruhe.de/'
relation_prefix = 'http://swrc.ontoware.org/'
......@@ -619,6 +631,9 @@ class AIFBDataset(RDFGraphDataset):
raw_dir=None,
force_reload=False,
verbose=True):
import rdflib as rdf
self.employs = rdf.term.URIRef("http://swrc.ontoware.org/ontology#employs")
self.affiliation = rdf.term.URIRef("http://swrc.ontoware.org/ontology#affiliation")
url = _get_dgl_url('dataset/rdf/aifb-hetero.zip')
name = 'aifb-hetero'
predict_category = 'Personen'
......@@ -639,16 +654,23 @@ class AIFBDataset(RDFGraphDataset):
Returns
-------
dgl.DGLGraph
graph structure, node features and labels.
- ndata['train_mask']: mask for training node set
- ndata['test_mask']: mask for testing node set
- ndata['labels']: mask for labels
:class:`dgl.DGLGraph`
The graph contains:
- ``ndata['train_mask']``: mask for training node set
- ``ndata['test_mask']``: mask for testing node set
- ``ndata['labels']``: mask for labels
"""
return super(AIFBDataset, self).__getitem__(idx)
def __len__(self):
r"""The number of graphs in the dataset."""
r"""The number of graphs in the dataset.
Returns
-------
int
"""
return super(AIFBDataset, self).__len__()
def parse_entity(self, term):
......@@ -703,26 +725,36 @@ class MUTAGDataset(RDFGraphDataset):
r"""MUTAG dataset for node classification task
.. deprecated:: 0.5.0
`graph` is deprecated, it is replaced by:
- ``graph`` is deprecated, it is replaced by:
>>> dataset = MUTAGDataset()
>>> graph = dataset[0]
`train_idx` is deprecated, it can be replaced by:
- ``train_idx`` is deprecated, it can be replaced by:
>>> dataset = MUTAGDataset()
>>> graph = dataset[0]
>>> train_mask = graph.nodes[dataset.category].data['train_mask']
>>> train_idx = th.nonzero(train_mask).squeeze()
`test_idx` is deprecated, it can be replaced by:
- ``test_idx`` is deprecated, it can be replaced by:
>>> dataset = MUTAGDataset()
>>> graph = dataset[0]
>>> test_mask = graph.nodes[dataset.category].data['test_mask']
>>> test_idx = th.nonzero(test_mask).squeeze()
Mutag dataset statistics:
Nodes: 27163
Edges: 148100 (including reverse edges)
Target Category: d
Number of Classes: 2
Label Split: Train: 272, Test: 68
- Nodes: 27163
- Edges: 148100 (including reverse edges)
- Target Category: d
- Number of Classes: 2
- Label Split:
- Train: 272
- Test: 68
Parameters
-----------
......@@ -746,7 +778,7 @@ class MUTAGDataset(RDFGraphDataset):
The entity category (node type) that has labels for prediction
labels : Tensor
All the labels of the entities in ``predict_category``
graph : dgl.DGLGraph
graph : :class:`dgl.DGLGraph`
Graph structure
train_idx : Tensor
Entity IDs for training. All IDs are local IDs w.r.t. to ``predict_category``.
......@@ -768,11 +800,6 @@ class MUTAGDataset(RDFGraphDataset):
d_entity = re.compile("d[0-9]")
bond_entity = re.compile("bond[0-9]")
is_mutagenic = rdf.term.URIRef("http://dl-learner.org/carcinogenesis#isMutagenic")
rdf_type = rdf.term.URIRef("http://www.w3.org/1999/02/22-rdf-syntax-ns#type")
rdf_subclassof = rdf.term.URIRef("http://www.w3.org/2000/01/rdf-schema#subClassOf")
rdf_domain = rdf.term.URIRef("http://www.w3.org/2000/01/rdf-schema#domain")
entity_prefix = 'http://dl-learner.org/carcinogenesis#'
relation_prefix = entity_prefix
......@@ -782,6 +809,12 @@ class MUTAGDataset(RDFGraphDataset):
raw_dir=None,
force_reload=False,
verbose=True):
import rdflib as rdf
self.is_mutagenic = rdf.term.URIRef("http://dl-learner.org/carcinogenesis#isMutagenic")
self.rdf_type = rdf.term.URIRef("http://www.w3.org/1999/02/22-rdf-syntax-ns#type")
self.rdf_subclassof = rdf.term.URIRef("http://www.w3.org/2000/01/rdf-schema#subClassOf")
self.rdf_domain = rdf.term.URIRef("http://www.w3.org/2000/01/rdf-schema#domain")
url = _get_dgl_url('dataset/rdf/mutag-hetero.zip')
name = 'mutag-hetero'
predict_category = 'd'
......@@ -802,16 +835,23 @@ class MUTAGDataset(RDFGraphDataset):
Returns
-------
dgl.DGLGraph
graph structure, node features and labels.
- ndata['train_mask']: mask for training node set
- ndata['test_mask']: mask for testing node set
- ndata['labels']: mask for labels
:class:`dgl.DGLGraph`
The graph contains:
- ``ndata['train_mask']``: mask for training node set
- ``ndata['test_mask']``: mask for testing node set
- ``ndata['labels']``: mask for labels
"""
return super(MUTAGDataset, self).__getitem__(idx)
def __len__(self):
r"""The number of graphs in the dataset."""
r"""The number of graphs in the dataset.
Returns
-------
int
"""
return super(MUTAGDataset, self).__len__()
def parse_entity(self, term):
......@@ -882,32 +922,42 @@ class BGSDataset(RDFGraphDataset):
r"""BGS dataset for node classification task
.. deprecated:: 0.5.0
`graph` is deprecated, it is replaced by:
- ``graph`` is deprecated, it is replaced by:
>>> dataset = BGSDataset()
>>> graph = dataset[0]
`train_idx` is deprecated, it can be replaced by:
- ``train_idx`` is deprecated, it can be replaced by:
>>> dataset = BGSDataset()
>>> graph = dataset[0]
>>> train_mask = graph.nodes[dataset.category].data['train_mask']
>>> train_idx = th.nonzero(train_mask).squeeze()
`test_idx` is deprecated, it can be replaced by:
- ``test_idx`` is deprecated, it can be replaced by:
>>> dataset = BGSDataset()
>>> graph = dataset[0]
>>> test_mask = graph.nodes[dataset.category].data['test_mask']
>>> test_idx = th.nonzero(test_mask).squeeze()
BGS namespace convention:
http://data.bgs.ac.uk/(ref|id)/<Major Concept>/<Sub Concept>/INSTANCE
``http://data.bgs.ac.uk/(ref|id)/<Major Concept>/<Sub Concept>/INSTANCE``.
We ignored all literal nodes and the relations connecting them in the
output graph. We also ignored the relation used to mark whether a
term is CURRENT or DEPRECATED.
BGS dataset statistics:
Nodes: 94806
Edges: 672884 (including reverse edges)
Target Category: Lexicon/NamedRockUnit
Number of Classes: 2
Label Split: Train: 117, Test: 29
- Nodes: 94806
- Edges: 672884 (including reverse edges)
- Target Category: Lexicon/NamedRockUnit
- Number of Classes: 2
- Label Split:
- Train: 117
- Test: 29
Parameters
-----------
......@@ -931,7 +981,7 @@ class BGSDataset(RDFGraphDataset):
The entity category (node type) that has labels for prediction
labels : Tensor
All the labels of the entities in ``predict_category``
graph : dgl.DGLGraph
graph : :class:`dgl.DGLGraph`
Graph structure
train_idx : Tensor
Entity IDs for training. All IDs are local IDs w.r.t. to ``predict_category``.
......@@ -950,7 +1000,6 @@ class BGSDataset(RDFGraphDataset):
>>> labels = g.nodes[category].data.pop('labels')
"""
lith = rdf.term.URIRef("http://data.bgs.ac.uk/ref/Lexicon/hasLithogenesis")
entity_prefix = 'http://data.bgs.ac.uk/'
status_prefix = 'http://data.bgs.ac.uk/ref/CurrentStatus'
relation_prefix = 'http://data.bgs.ac.uk/ref'
......@@ -961,9 +1010,11 @@ class BGSDataset(RDFGraphDataset):
raw_dir=None,
force_reload=False,
verbose=True):
import rdflib as rdf
url = _get_dgl_url('dataset/rdf/bgs-hetero.zip')
name = 'bgs-hetero'
predict_category = 'Lexicon/NamedRockUnit'
self.lith = rdf.term.URIRef("http://data.bgs.ac.uk/ref/Lexicon/hasLithogenesis")
super(BGSDataset, self).__init__(name, url, predict_category,
print_every=print_every,
insert_reverse=insert_reverse,
......@@ -981,16 +1032,23 @@ class BGSDataset(RDFGraphDataset):
Returns
-------
dgl.DGLGraph
graph structure, node features and labels.
- ndata['train_mask']: mask for training node set
- ndata['test_mask']: mask for testing node set
- ndata['labels']: mask for labels
:class:`dgl.DGLGraph`
The graph contains:
- ``ndata['train_mask']``: mask for training node set
- ``ndata['test_mask']``: mask for testing node set
- ``ndata['labels']``: mask for labels
"""
return super(BGSDataset, self).__getitem__(idx)
def __len__(self):
r"""The number of graphs in the dataset."""
r"""The number of graphs in the dataset.
Returns
-------
int
"""
return super(BGSDataset, self).__len__()
def parse_entity(self, term):
......@@ -1057,32 +1115,44 @@ class AMDataset(RDFGraphDataset):
"""AM dataset. for node classification task
.. deprecated:: 0.5.0
`graph` is deprecated, it is replaced by:
- ``graph`` is deprecated, it is replaced by:
>>> dataset = AMDataset()
>>> graph = dataset[0]
`train_idx` is deprecated, it can be replaced by:
- ``train_idx`` is deprecated, it can be replaced by:
>>> dataset = AMDataset()
>>> graph = dataset[0]
>>> train_mask = graph.nodes[dataset.category].data['train_mask']
>>> train_idx = th.nonzero(train_mask).squeeze()
`test_idx` is deprecated, it can be replaced by:
- ``test_idx`` is deprecated, it can be replaced by:
>>> dataset = AMDataset()
>>> graph = dataset[0]
>>> test_mask = graph.nodes[dataset.category].data['test_mask']
>>> test_idx = th.nonzero(test_mask).squeeze()
Namespace convention:
Instance: http://purl.org/collections/nl/am/<type>-<id>
Relation: http://purl.org/collections/nl/am/<name>
- Instance: ``http://purl.org/collections/nl/am/<type>-<id>``
- Relation: ``http://purl.org/collections/nl/am/<name>``
We ignored all literal nodes and the relations connecting them in the
output graph.
AM dataset statistics:
Nodes: 881680
Edges: 5668682 (including reverse edges)
Target Category: proxy
Number of Classes: 11
Label Split: Train: 802, Test: 198
- Nodes: 881680
- Edges: 5668682 (including reverse edges)
- Target Category: proxy
- Number of Classes: 11
- Label Split:
- Train: 802
- Test: 198
Parameters
-----------
......@@ -1106,7 +1176,7 @@ class AMDataset(RDFGraphDataset):
The entity category (node type) that has labels for prediction
labels : Tensor
All the labels of the entities in ``predict_category``
graph : dgl.DGLGraph
graph : :class:`dgl.DGLGraph`
Graph structure
train_idx : Tensor
Entity IDs for training. All IDs are local IDs w.r.t. to ``predict_category``.
......@@ -1125,8 +1195,6 @@ class AMDataset(RDFGraphDataset):
>>> labels = g.nodes[category].data.pop('labels')
"""
objectCategory = rdf.term.URIRef("http://purl.org/collections/nl/am/objectCategory")
material = rdf.term.URIRef("http://purl.org/collections/nl/am/material")
entity_prefix = 'http://purl.org/collections/nl/am/'
relation_prefix = entity_prefix
......@@ -1136,6 +1204,9 @@ class AMDataset(RDFGraphDataset):
raw_dir=None,
force_reload=False,
verbose=True):
import rdflib as rdf
self.objectCategory = rdf.term.URIRef("http://purl.org/collections/nl/am/objectCategory")
self.material = rdf.term.URIRef("http://purl.org/collections/nl/am/material")
url = _get_dgl_url('dataset/rdf/am-hetero.zip')
name = 'am-hetero'
predict_category = 'proxy'
......@@ -1156,16 +1227,23 @@ class AMDataset(RDFGraphDataset):
Returns
-------
dgl.DGLGraph
graph structure, node features and labels.
- ndata['train_mask']: mask for training node set
- ndata['test_mask']: mask for testing node set
- ndata['labels']: mask for labels
:class:`dgl.DGLGraph`
The graph contains:
- ``ndata['train_mask']``: mask for training node set
- ``ndata['test_mask']``: mask for testing node set
- ``ndata['labels']``: mask for labels
"""
return super(AMDataset, self).__getitem__(idx)
def __len__(self):
r"""The number of graphs in the dataset."""
r"""The number of graphs in the dataset.
Returns
-------
int
"""
return super(AMDataset, self).__len__()
def parse_entity(self, term):
......
......@@ -15,29 +15,43 @@ class RedditDataset(DGLBuiltinDataset):
r""" Reddit dataset for community detection (node classification)
.. deprecated:: 0.5.0
`graph` is deprecated, it is replaced by:
- ``graph`` is deprecated, it is replaced by:
>>> dataset = RedditDataset()
>>> graph = dataset[0]
`num_labels` is deprecated, it is replaced by:
- ``num_labels`` is deprecated, it is replaced by:
>>> dataset = RedditDataset()
>>> num_classes = dataset.num_classes
`train_mask` is deprecated, it is replaced by:
- ``train_mask`` is deprecated, it is replaced by:
>>> dataset = RedditDataset()
>>> graph = dataset[0]
>>> train_mask = graph.ndata['train_mask']
`val_mask` is deprecated, it is replaced by:
- ``val_mask`` is deprecated, it is replaced by:
>>> dataset = RedditDataset()
>>> graph = dataset[0]
>>> val_mask = graph.ndata['val_mask']
`test_mask` is deprecated, it is replaced by:
- ``test_mask`` is deprecated, it is replaced by:
>>> dataset = RedditDataset()
>>> graph = dataset[0]
>>> test_mask = graph.ndata['test_mask']
`features` is deprecated, it is replaced by:
- ``features`` is deprecated, it is replaced by:
>>> dataset = RedditDataset()
>>> graph = dataset[0]
>>> features = graph.ndata['feat']
`labels` is deprecated, it is replaced by:
- ``labels`` is deprecated, it is replaced by:
>>> dataset = RedditDataset()
>>> graph = dataset[0]
>>> labels = graph.ndata['label']
......@@ -49,16 +63,16 @@ class RedditDataset(DGLBuiltinDataset):
posts with an average degree of 492. We use the first 20 days for training and the
remaining days for testing (with 30% used for validation).
Reference: http://snap.stanford.edu/graphsage/
Reference: `<http://snap.stanford.edu/graphsage/>`_
Statistics
----------
Nodes: 232,965
Edges: 114,615,892
Node feature size: 602
Number of training samples: 153,431
Number of validation samples: 23,831
Number of test samples: 55,703
- Nodes: 232,965
- Edges: 114,615,892
- Node feature size: 602
- Number of training samples: 153,431
- Number of validation samples: 23,831
- Number of test samples: 55,703
Parameters
----------
......@@ -76,7 +90,7 @@ class RedditDataset(DGLBuiltinDataset):
----------
num_classes : int
Number of classes for each node
graph : dgl.DGLGraph
graph : :class:`dgl.DGLGraph`
Graph of the dataset
num_labels : int
Number of classes for each node
......@@ -220,13 +234,14 @@ class RedditDataset(DGLBuiltinDataset):
Returns
-------
dgl.DGLGraph
graph structure, node labels, node features and splitting masks
- ndata['label']: node label
- ndata['feat']: node feature
- ndata['train_mask']: mask for training node set
- ndata['val_mask']: mask for validation node set
- ndata['test_mask']: mask for test node set
:class:`dgl.DGLGraph`
graph structure, node labels, node features and splitting masks:
- ``ndata['label']``: node label
- ``ndata['feat']``: node feature
- ``ndata['train_mask']``: mask for training node set
- ``ndata['val_mask']``: mask for validation node set
- ``ndata['test_mask']``: mask for test node set
"""
assert idx == 0, "Reddit Dataset only has one graph"
return self._graph
......
......@@ -59,8 +59,7 @@ def sbm(n_blocks, block_size, p, q, rng=None):
class SBMMixtureDataset(DGLDataset):
r""" Symmetric Stochastic Block Model Mixture
Reference: Appendix C of "Supervised Community Detection with Hierarchical
Graph Neural Networks" (https://arxiv.org/abs/1705.08415).
Reference: Appendix C of `Supervised Community Detection with Hierarchical Graph Neural Networks <https://arxiv.org/abs/1705.08415>`_
Parameters
----------
......@@ -175,15 +174,15 @@ class SBMMixtureDataset(DGLDataset):
Returns
-------
graph : dgl.DGLGraph
graph: :class:`dgl.DGLGraph`
The original graph
line_graph : dgl.DGLGraph
line_graph: :class:`dgl.DGLGraph`
The line graph of `graph`
graph_degree : numpy.ndarray
graph_degree: numpy.ndarray
In degrees for each node in `graph`
line_graph_degree : numpy.ndarray
line_graph_degree: numpy.ndarray
In degrees for each node in `line_graph`
pm_pd : numpy.ndarray
pm_pd: numpy.ndarray
Edge indicator matrices Pm and Pd
"""
return self._graphs[idx], self._line_graphs[idx], \
......@@ -203,29 +202,30 @@ class SBMMixtureDataset(DGLDataset):
Parameters
----------
x : tuple
a batch of data that contains
graph : dgl.DGLGraph
a batch of data that contains:
- graph: :class:`dgl.DGLGraph`
The original graph
line_graph : dgl.DGLGraph
- line_graph: :class:`dgl.DGLGraph`
The line graph of `graph`
graph_degree : numpy.ndarray
- graph_degree: numpy.ndarray
In degrees for each node in `graph`
line_graph_degree : numpy.ndarray
- line_graph_degree: numpy.ndarray
In degrees for each node in `line_graph`
pm_pd : numpy.ndarray
- pm_pd: numpy.ndarray
Edge indicator matrices Pm and Pd
Returns
-------
g_batch : dgl.DGLGraph
g_batch: :class:`dgl.DGLGraph`
Batched graphs
lg_batch : dgl.DGLGraph
lg_batch: :class:`dgl.DGLGraph`
Batched line graphs
degg_batch : numpy.ndarray
degg_batch: numpy.ndarray
A batch of in degrees for each node in `g_batch`
deglg_batch : numpy.ndarray
deglg_batch: numpy.ndarray
A batch of in degrees for each node in `lg_batch`
pm_pd_batch : numpy.ndarray
pm_pd_batch: numpy.ndarray
A batch of edge indicator matrices Pm and Pd
"""
g, lg, deg_g, deg_lg, pm_pd = zip(*x)
......
......@@ -23,13 +23,14 @@ class SSTDataset(DGLBuiltinDataset):
r"""Stanford Sentiment Treebank dataset.
.. deprecated:: 0.5.0
`trees` is deprecated, it is replaced by:
- ``trees`` is deprecated, it is replaced by:
>>> dataset = SSTDataset()
>>> for tree in dataset:
...     # your code here
...
`num_vocabs` is deprecated, it is replaced by `vocab_size`
- ``num_vocabs`` is deprecated, it is replaced by ``vocab_size``.
Each sample is the constituency tree of a sentence. The leaf nodes
represent words. The word is a int value stored in the ``x`` feature field.
......@@ -37,14 +38,14 @@ class SSTDataset(DGLBuiltinDataset):
Each node also has a sentiment annotation: 5 classes (very negative,
negative, neutral, positive and very positive). The sentiment label is a
int value stored in the ``y`` feature field.
Official site: http://nlp.stanford.edu/sentiment/index.html
Official site: `<http://nlp.stanford.edu/sentiment/index.html>`_
Statistics
----------
Train examples: 8,544
Dev examples: 1,101
Test examples: 2,210
Number of classes for each node: 5
Statistics:
- Train examples: 8,544
- Dev examples: 1,101
- Test examples: 2,210
- Number of classes for each node: 5
Parameters
----------
......@@ -100,16 +101,14 @@ class SSTDataset(DGLBuiltinDataset):
>>> train_data.vocab_size
19536
>>> train_data[0]
DGLGraph(num_nodes=71, num_edges=70,
ndata_schemes={'x': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(),
dtype=torch.int64), 'mask': Scheme(shape=(), dtype=torch.int64)}
edata_schemes={})
Graph(num_nodes=71, num_edges=70,
ndata_schemes={'x': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(), dtype=torch.int64), 'mask': Scheme(shape=(), dtype=torch.int64)}
edata_schemes={})
>>> for tree in train_data:
... input_ids = tree.ndata['x']
... labels = tree.ndata['y']
... mask = tree.ndata['mask']
... # your code here
>>>
"""
PAD_WORD = -1 # special pad word id
......@@ -247,11 +246,13 @@ class SSTDataset(DGLBuiltinDataset):
Returns
-------
dgl.DGLGraph
graph structure, word id for each node, node labels and masks
- ndata['x']: word id of the node
- ndata['y']: label of the node
- ndata['mask']: 1 if the node is a leaf, otherwise 0
:class:`dgl.DGLGraph`
graph structure, word id for each node, node labels and masks.
- ``ndata['x']``: word id of the node
- ``ndata['y']``: label of the node
- ``ndata['mask']``: 1 if the node is a leaf, otherwise 0
"""
return self._trees[idx]
......
......@@ -16,17 +16,18 @@ class LegacyTUDataset(DGLBuiltinDataset):
Parameters
----------
name : str
Dataset Name, such as `ENZYMES`, `DD`, `COLLAB`
Dataset Name, such as ``ENZYMES``, ``DD``, ``COLLAB``, ``MUTAG``, can be the
datasets name on `<https://chrsmrrs.github.io/datasets/docs/datasets/>`_.
use_pandas : bool
Numpy's file read function has performance issue when file is large,
using pandas can be faster.
Default: False
hidden_size : int
Some dataset doesn't contain features.
Use constant node features initialization instead, with hidden size as `hidden_size`.
Use constant node features initialization instead, with hidden size as ``hidden_size``.
Default : 10
max_allow_node : int
Remove graphs that contains more nodes than `max_allow_node`.
Remove graphs that contains more nodes than ``max_allow_node``.
Default : None
Attributes
......@@ -40,7 +41,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
--------
>>> data = LegacyTUDataset('DD')
**The dataset instance is an iterable**
The dataset instance is an iterable
>>> len(data)
1178
......@@ -52,7 +53,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
>>> label
tensor(1)
**Batch the graphs and labels for mini-batch training*
Batch the graphs and labels for mini-batch training
>>> graphs, labels = zip(*[data[i] for i in range(16)])
>>> batched_graphs = dgl.batch(graphs)
......@@ -190,20 +191,23 @@ class LegacyTUDataset(DGLBuiltinDataset):
def __getitem__(self, idx):
"""Get the idx-th sample.
Paramters
Parameters
----------
idx : int
The sample index.
Returns
-------
(dgl.Graph, int)
Graph with node feature stored in `feat` field and node label in `node_label` if available.
(:class:`dgl.DGLGraph`, Tensor)
Graph with node feature stored in ``feat`` field and node label in ``node_label`` if available.
And its label.
"""
g = self.graph_lists[idx]
return g, self.graph_labels[idx]
def __len__(self):
"""Return the number of graphs in the dataset."""
return len(self.graph_lists)
def _file_path(self, category):
......@@ -234,8 +238,8 @@ class TUDataset(DGLBuiltinDataset):
Parameters
----------
name : str
Dataset Name, such as `ENZYMES`, `DD`, `COLLAB`, `MUTAG`, can be the
datasets name on https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets.
Dataset Name, such as ``ENZYMES``, ``DD``, ``COLLAB``, ``MUTAG``, can be the
datasets name on `<https://chrsmrrs.github.io/datasets/docs/datasets/>`_.
Attributes
----------
......@@ -248,7 +252,7 @@ class TUDataset(DGLBuiltinDataset):
--------
>>> data = TUDataset('DD')
**The dataset instance is an iterable**
The dataset instance is an iterable
>>> len(data)
188
......@@ -260,7 +264,7 @@ class TUDataset(DGLBuiltinDataset):
>>> label
tensor([1])
**Batch the graphs and labels for mini-batch training*
Batch the graphs and labels for mini-batch training
>>> graphs, labels = zip(*[data[i] for i in range(16)])
>>> batched_graphs = dgl.batch(graphs)
......@@ -355,20 +359,23 @@ class TUDataset(DGLBuiltinDataset):
def __getitem__(self, idx):
"""Get the idx-th sample.
Paramters
Parameters
----------
idx : int
The sample index.
Returns
-------
(dgl.Graph, int)
Graph with node feature stored in `feat` field and node label in `node_label` if available.
(:class:`dgl.DGLGraph`, Tensor)
Graph with node feature stored in ``feat`` field and node label in ``node_label`` if available.
And its label.
"""
g = self.graph_lists[idx]
return g, self.graph_labels[idx]
def __len__(self):
"""Return the number of graphs in the dataset."""
return len(self.graph_lists)
def _file_path(self, category):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment