"...csrc/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "a8ebd0b3e69407e06ad4e8a4f2ade5147dda6628"
Unverified Commit 17829024 authored by Guangyu Zhou, committed by GitHub

[Dataset] Add PATTERN dataset (#5422)



* add PATTERN dataset

* fix bug

* fix bugs

* fix issues

* refine according to dongyu's comments

---------
Co-authored-by: BuptTab <gyzhou2000@gmail.com>
Co-authored-by: rudongyu <ru_dongyu@outlook.com>
parent a454734f
@@ -54,6 +54,7 @@ Datasets for node classification/regression tasks
     WikiCSDataset
     FlickrDataset
     YelpDataset
+    PATTERNDataset
     CLUSTERDataset

 Edge Prediction Datasets
@@ -53,6 +53,7 @@ from .tree import SST, SSTDataset
 from .tu import LegacyTUDataset, TUDataset
 from .utils import *
 from .cluster import CLUSTERDataset
+from .pattern import PATTERNDataset
 from .wikics import WikiCSDataset
 from .yelp import YelpDataset
""" PATTERNDataset for inductive learning. """
import os
from .dgl_dataset import DGLBuiltinDataset
from .utils import _get_dgl_url, load_graphs
class PATTERNDataset(DGLBuiltinDataset):
r"""PATTERN dataset for graph pattern recognition task.
Each graph G contains 5 communities with sizes randomly selected between [5, 35].
The SBM of each community is p = 0.5, q = 0.35, and the node features on G are
generated with a uniform random distribution with a vocabulary of size 3, i.e. {0, 1, 2}.
Then randomly generate 100 patterns P composed of 20 nodes with intra-probability :math:`p_P` = 0.5
and extra-probability :math:`q_P` = 0.5 (i.e. 50% of nodes in P are connected to G). The node features
for P are also generated as a random signal with values {0, 1, 2}. The graphs are of sizes
44-188 nodes. The output node labels have value 1 if the node belongs to P and value 0 if it is in G.
Reference `<https://arxiv.org/pdf/2003.00982.pdf>`_
Statistics:
- Train examples: 10,000
- Valid examples: 2,000
- Test examples: 2,000
- Number of classes for each node: 2
Parameters
----------
mode : str
Must be one of ('train', 'valid', 'test').
Default: 'train'
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset.
Default: False
verbose : bool
Whether to print out progress information.
Default: False
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
num_classes : int
Number of classes for each node.
Examples
—-------
>>> from dgl.data import PATTERNDataset
>>> data = PATTERNDataset(mode='train')
>>> data.num_classes
2
>>> len(trainset)
10000
>>> data[0]
Graph(num_nodes=108, num_edges=4884, ndata_schemes={'feat': Scheme(shape=(), dtype=torch.int64), 'label': Scheme(shape=(), dtype=torch.int16)}
edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)})
"""
    def __init__(
        self,
        mode="train",
        raw_dir=None,
        force_reload=False,
        verbose=False,
        transform=None,
    ):
        assert mode in ["train", "valid", "test"]
        self.mode = mode
        _url = _get_dgl_url("dataset/SBM_PATTERN.zip")
        super(PATTERNDataset, self).__init__(
            name="pattern",
            url=_url,
            raw_dir=raw_dir,
            force_reload=force_reload,
            verbose=verbose,
            transform=transform,
        )

    def process(self):
        self.load()

    def has_cache(self):
        graph_path = os.path.join(
            self.save_path, "SBM_PATTERN_{}.bin".format(self.mode)
        )
        return os.path.exists(graph_path)

    def load(self):
        graph_path = os.path.join(
            self.save_path, "SBM_PATTERN_{}.bin".format(self.mode)
        )
        self._graphs, _ = load_graphs(graph_path)

    @property
    def num_classes(self):
        r"""Number of classes for each node."""
        return 2

    def __len__(self):
        r"""The number of examples in the dataset."""
        return len(self._graphs)

    def __getitem__(self, idx):
        r"""Get the idx-th sample.

        Parameters
        ----------
        idx : int
            The sample index.

        Returns
        -------
        :class:`dgl.DGLGraph`
            Graph structure, node features, node labels and edge features.

            - ``ndata['feat']``: node features
            - ``ndata['label']``: node labels
            - ``edata['feat']``: edge features
        """
        if self._transform is None:
            return self._graphs[idx]
        else:
            return self._transform(self._graphs[idx])
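
For readers trying out the new dataset, here is a minimal usage sketch (not part of this commit): it batches PATTERN graphs with dgl.dataloading.GraphDataLoader for node classification. The PyTorch backend and the batch size of 32 are assumptions for illustration, not from the PR.

from dgl.data import PATTERNDataset
from dgl.dataloading import GraphDataLoader

# Load the training split and batch graphs for a node-classification model.
dataset = PATTERNDataset(mode="train")
loader = GraphDataLoader(dataset, batch_size=32, shuffle=True)

for batched_graph in loader:
    feats = batched_graph.ndata["feat"]           # node features, values in {0, 1, 2}
    labels = batched_graph.ndata["label"].long()  # 0 = node in G, 1 = node in pattern P
    # ... feed `batched_graph`, `feats`, and `labels` to a GNN here ...
    break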
@@ -364,6 +364,28 @@ def test_flickr():
     assert g2.num_edges() - g.num_edges() == g.num_nodes()


+@unittest.skipIf(
+    F._default_context_str == "gpu",
+    reason="Datasets don't need to be tested on GPU.",
+)
+@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
+def test_pattern():
+    mode_n_graphs = {
+        "train": 10000,
+        "valid": 2000,
+        "test": 2000,
+    }
+    transform = dgl.AddSelfLoop(allow_duplicate=True)
+    for mode, n_graphs in mode_n_graphs.items():
+        ds = data.PATTERNDataset(mode=mode)
+        assert len(ds) == n_graphs, (len(ds), mode)
+        g1 = ds[0]
+        ds = data.PATTERNDataset(mode=mode, transform=transform)
+        g2 = ds[0]
+        assert g2.num_edges() - g1.num_edges() == g1.num_nodes()
+        assert ds.num_classes == 2
+
+
 @unittest.skipIf(
     F._default_context_str == "gpu",
     reason="Datasets don't need to be tested on GPU.",
 )
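
A note on the transform check in the test above (not part of the diff): dgl.AddSelfLoop(allow_duplicate=True) appends one self-loop per node without removing existing self-loops, so the transformed graph has exactly g1.num_nodes() more edges than the original. A toy illustration, assuming the PyTorch backend:

import torch
import dgl

g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))  # 3 nodes, 2 edges
g2 = dgl.AddSelfLoop(allow_duplicate=True)(g)
assert g2.num_edges() - g.num_edges() == g.num_nodes()  # 5 - 2 == 3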