Unverified Commit 690f37bb authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files
parent 89b321b8
...@@ -57,6 +57,7 @@ Datasets for node classification/regression tasks ...@@ -57,6 +57,7 @@ Datasets for node classification/regression tasks
PATTERNDataset PATTERNDataset
CLUSTERDataset CLUSTERDataset
ChameleonDataset ChameleonDataset
SquirrelDataset
Edge Prediction Datasets Edge Prediction Datasets
--------------------------------------- ---------------------------------------
......
...@@ -54,7 +54,7 @@ from .tu import LegacyTUDataset, TUDataset ...@@ -54,7 +54,7 @@ from .tu import LegacyTUDataset, TUDataset
from .utils import * from .utils import *
from .cluster import CLUSTERDataset from .cluster import CLUSTERDataset
from .pattern import PATTERNDataset from .pattern import PATTERNDataset
from .wiki_network import ChameleonDataset from .wiki_network import ChameleonDataset, SquirrelDataset
from .wikics import WikiCSDataset from .wikics import WikiCSDataset
from .yelp import YelpDataset from .yelp import YelpDataset
from .zinc import ZINCDataset from .zinc import ZINCDataset
......
""" """
Wikipedia page-page networks on the chameleon topic. Wikipedia page-page networks on two topics: chameleons and squirrels.
""" """
import os import os
...@@ -23,8 +23,7 @@ class WikiNetworkDataset(DGLBuiltinDataset): ...@@ -23,8 +23,7 @@ class WikiNetworkDataset(DGLBuiltinDataset):
raw_dir : str raw_dir : str
Raw file directory to store the processed data. Raw file directory to store the processed data.
force_reload : bool force_reload : bool
Whether to always generate the data from scratch rather than load a Whether to re-download the data source.
cached version.
verbose : bool verbose : bool
Whether to print progress information. Whether to print progress information.
transform : callable transform : callable
...@@ -123,7 +122,7 @@ class ChameleonDataset(WikiNetworkDataset): ...@@ -123,7 +122,7 @@ class ChameleonDataset(WikiNetworkDataset):
- Nodes: 2277 - Nodes: 2277
- Edges: 36101 - Edges: 36101
- Number of Classes: 5 - Number of Classes: 5
- 10 splits with 60/20/20 train/val/test ratio - 10 train/val/test splits
- Train: 1092 - Train: 1092
- Val: 729 - Val: 729
...@@ -134,8 +133,7 @@ class ChameleonDataset(WikiNetworkDataset): ...@@ -134,8 +133,7 @@ class ChameleonDataset(WikiNetworkDataset):
raw_dir : str, optional raw_dir : str, optional
Raw file directory to store the processed data. Default: ~/.dgl/ Raw file directory to store the processed data. Default: ~/.dgl/
force_reload : bool, optional force_reload : bool, optional
Whether to always generate the data from scratch rather than load a Whether to re-download the data source. Default: False
cached version. Default: False
verbose : bool, optional verbose : bool, optional
Whether to print progress information. Default: True Whether to print progress information. Default: True
transform : callable, optional transform : callable, optional
...@@ -182,3 +180,79 @@ class ChameleonDataset(WikiNetworkDataset): ...@@ -182,3 +180,79 @@ class ChameleonDataset(WikiNetworkDataset):
verbose=verbose, verbose=verbose,
transform=transform, transform=transform,
) )
class SquirrelDataset(WikiNetworkDataset):
    r"""Wikipedia page-page network on squirrels from `Multi-scale Attributed
    Node Embedding <https://arxiv.org/abs/1909.13021>`__ and later modified by
    `Geom-GCN: Geometric Graph Convolutional Networks
    <https://arxiv.org/abs/2002.05287>`__

    Nodes represent articles from the English Wikipedia, edges reflect mutual
    links between them. Node features indicate the presence of particular nouns
    in the articles. The nodes were classified into 5 classes in terms of their
    average monthly traffic.

    Statistics:

    - Nodes: 5201
    - Edges: 217073
    - Number of Classes: 5
    - 10 train/val/test splits

        - Train: 2496
        - Val: 1664
        - Test: 1041

    Parameters
    ----------
    raw_dir : str, optional
        Raw file directory to store the processed data. Default: ~/.dgl/
    force_reload : bool, optional
        Whether to re-download the data source. Default: False
    verbose : bool, optional
        Whether to print progress information. Default: True
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access. Default: None

    Attributes
    ----------
    num_classes : int
        Number of node classes

    Notes
    -----
    The graph does not come with edges for both directions.

    Examples
    --------

    >>> from dgl.data import SquirrelDataset
    >>> dataset = SquirrelDataset()
    >>> g = dataset[0]
    >>> num_classes = dataset.num_classes

    >>> # get node features
    >>> feat = g.ndata["feat"]

    >>> # get data split
    >>> train_mask = g.ndata["train_mask"]
    >>> val_mask = g.ndata["val_mask"]
    >>> test_mask = g.ndata["test_mask"]

    >>> # get labels
    >>> label = g.ndata['label']
    """

    def __init__(
        self, raw_dir=None, force_reload=False, verbose=True, transform=None
    ):
        # The dataset name is the only thing that differs from the sibling
        # wiki-network dataset; every other option is forwarded unchanged
        # to the shared base class.
        super(SquirrelDataset, self).__init__(
            name="squirrel",
            raw_dir=raw_dir,
            force_reload=force_reload,
            verbose=verbose,
            transform=transform,
        )
...@@ -15,9 +15,25 @@ import dgl ...@@ -15,9 +15,25 @@ import dgl
def test_chameleon():
    """Check the Chameleon graph size and that a transform is applied."""
    graph = dgl.data.ChameleonDataset(force_reload=True)[0]
    assert graph.num_nodes() == 2277
    assert graph.num_edges() == 36101

    # With duplicates allowed, AddSelfLoop contributes one extra edge per
    # node, so the edge counts must differ by exactly the node count.
    add_self_loop = dgl.AddSelfLoop(allow_duplicate=True)
    graph_with_loops = dgl.data.ChameleonDataset(
        force_reload=True, transform=add_self_loop
    )[0]
    assert (
        graph_with_loops.num_edges() - graph.num_edges() == graph.num_nodes()
    )
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="only supports pytorch"
)
def test_squirrel():
    """Check the Squirrel graph size and that a transform is applied."""
    graph = dgl.data.SquirrelDataset(force_reload=True)[0]
    assert graph.num_nodes() == 5201
    assert graph.num_edges() == 217073

    # With duplicates allowed, AddSelfLoop contributes one extra edge per
    # node, so the edge counts must differ by exactly the node count.
    add_self_loop = dgl.AddSelfLoop(allow_duplicate=True)
    graph_with_loops = dgl.data.SquirrelDataset(
        force_reload=True, transform=add_self_loop
    )[0]
    assert (
        graph_with_loops.num_edges() - graph.num_edges() == graph.num_nodes()
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment