Unverified Commit 4534650f authored by ZhenyuLU_Heliodore, committed by GitHub

[Dataset] Add ZINC Dataset (#5428)



* Update dgl.data.rst

* Add files via upload

* Add files via upload

* Add files via upload

* Update zinc.py

* Update dgl.data.rst

* Update test_data.py

* Add files via upload

* Update cluster.py

* Update pattern.py

* Update zinc.py

* Update zinc.py

* Update test_data.py

* lint

* fix

* fix path

* update test on label shape

---------
Co-authored-by: rudongyu <ru_dongyu@outlook.com>
parent acb955e1
@@ -93,6 +93,7 @@ Datasets for graph classification/regression tasks
GINDataset
FakeNewsDataset
BA2MotifDataset
ZINCDataset
Dataset adapters
-------------------
......
@@ -56,6 +56,7 @@ from .cluster import CLUSTERDataset
from .pattern import PATTERNDataset
from .wikics import WikiCSDataset
from .yelp import YelpDataset
from .zinc import ZINCDataset
def register_data_args(parser):
......
@@ -49,7 +49,7 @@ class CLUSTERDataset(DGLBuiltinDataset):
Number of classes for each node.
Examples
-------
--------
>>> from dgl.data import CLUSTERDataset
>>>
>>> trainset = CLUSTERDataset(mode='train')
......
@@ -50,7 +50,7 @@ class PATTERNDataset(DGLBuiltinDataset):
Number of classes for each node.
Examples
-------
--------
>>> from dgl.data import PATTERNDataset
>>> data = PATTERNDataset(mode='train')
>>> data.num_classes
......
import os
from .dgl_dataset import DGLBuiltinDataset
from .utils import _get_dgl_url, load_graphs
class ZINCDataset(DGLBuiltinDataset):
r"""ZINC dataset for the graph regression task.
A subset (12K graphs) of the ZINC molecular dataset (250K graphs) used to
regress a molecular property known as the constrained solubility.
For each molecular graph, the node features are the types of heavy
atoms and the edge features are the types of bonds between them.
Each graph contains 9-37 nodes and 16-84 edges.
Reference `<https://arxiv.org/pdf/2003.00982.pdf>`_
Statistics:
Train examples: 10,000
Valid examples: 1,000
Test examples: 1,000
Average number of nodes: 23.16
Average number of edges: 39.83
Number of atom types: 28
Number of bond types: 4
Parameters
----------
mode : str, optional
Should be chosen from ["train", "valid", "test"].
Default: "train".
raw_dir : str
Raw file directory that stores, or will store after download, the input data.
Default: "~/.dgl/".
force_reload : bool
Whether to reload the dataset.
Default: False.
verbose : bool
Whether to print out progress information.
Default: False.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
num_atom_types : int
Number of atom types.
num_bond_types : int
Number of bond types.
Examples
--------
>>> from dgl.data import ZINCDataset
>>> training_set = ZINCDataset(mode="train")
>>> training_set.num_atom_types
28
>>> len(training_set)
10000
>>> graph, label = training_set[0]
>>> graph
Graph(num_nodes=29, num_edges=64,
ndata_schemes={'feat': Scheme(shape=(), dtype=torch.int64)}
edata_schemes={'feat': Scheme(shape=(), dtype=torch.int64)})
"""
def __init__(
self,
mode="train",
raw_dir=None,
force_reload=False,
verbose=False,
transform=None,
):
self._url = _get_dgl_url("dataset/ZINC12k.zip")
self.mode = mode
super(ZINCDataset, self).__init__(
name="zinc",
url=self._url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
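    # The downloaded archive ships pre-built DGLGraph binaries
    # (ZincDGL_<mode>.bin), so processing amounts to loading the cached
    # graphs and their labels.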
def process(self):
self.load()
def has_cache(self):
graph_path = os.path.join(
self.save_path, "ZincDGL_{}.bin".format(self.mode)
)
return os.path.exists(graph_path)
def load(self):
graph_path = os.path.join(
self.save_path, "ZincDGL_{}.bin".format(self.mode)
)
self._graphs, self._labels = load_graphs(graph_path)
@property
def num_atom_types(self):
return 28
@property
def num_bond_types(self):
return 4
def __len__(self):
return len(self._graphs)
def __getitem__(self, idx):
r"""Get one example by index.
Parameters
----------
idx : int
The sample index.
Returns
-------
dgl.DGLGraph
Each graph contains:
- ``ndata['feat']``: Types of heavy atoms as node features
- ``edata['feat']``: Types of bonds as edge features
Tensor
Constrained solubility as graph label
"""
labels = self._labels["g_label"]
if self._transform is None:
return self._graphs[idx], labels[idx]
else:
return self._transform(self._graphs[idx]), labels[idx]
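For context, a minimal usage sketch (not part of this commit) of how the new dataset could feed a graph-regression loop. It assumes the PyTorch backend and uses dgl.dataloading.GraphDataLoader for graph-level batching; the batch size of 128 is arbitrary.

    from dgl.data import ZINCDataset
    from dgl.dataloading import GraphDataLoader

    train_set = ZINCDataset(mode="train")
    # GraphDataLoader batches the DGLGraphs and stacks the scalar labels.
    loader = GraphDataLoader(train_set, batch_size=128, shuffle=True)

    for batched_graph, labels in loader:
        # 'feat' holds integer atom/bond type ids; a model would typically
        # map them to dense vectors with an embedding layer first.
        atom_types = batched_graph.ndata["feat"]  # shape: (total_num_nodes,)
        bond_types = batched_graph.edata["feat"]  # shape: (total_num_edges,)
        # One constrained-solubility value per graph in the batch.
        assert labels.shape == (batched_graph.batch_size,)
        break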
@@ -408,6 +408,63 @@ def test_cluster():
assert ds.num_classes == 6
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(
dgl.backend.backend_name != "pytorch", reason="only supports pytorch"
)
def test_zinc():
mode_n_graphs = {
"train": 10000,
"valid": 1000,
"test": 1000,
}
transform = dgl.AddSelfLoop(allow_duplicate=True)
for mode, n_graphs in mode_n_graphs.items():
dataset1 = data.ZINCDataset(mode=mode)
g1, label = dataset1[0]
dataset2 = data.ZINCDataset(mode=mode, transform=transform)
g2, _ = dataset2[0]
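        # AddSelfLoop(allow_duplicate=True) adds exactly one self-loop per
        # node, so the edge count should grow by the number of nodes.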
assert g2.num_edges() - g1.num_edges() == g1.num_nodes()
        # The label should be a scalar tensor.
assert not label.shape
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_extract_archive():
# gzip
with tempfile.TemporaryDirectory() as src_dir:
gz_file = "gz_archive"
gz_path = os.path.join(src_dir, gz_file + ".gz")
content = b"test extract archive gzip"
with gzip.open(gz_path, "wb") as f:
f.write(content)
with tempfile.TemporaryDirectory() as dst_dir:
data.utils.extract_archive(gz_path, dst_dir, overwrite=True)
assert os.path.exists(os.path.join(dst_dir, gz_file))
# tar
with tempfile.TemporaryDirectory() as src_dir:
tar_file = "tar_archive"
tar_path = os.path.join(src_dir, tar_file + ".tar")
# default encode to utf8
content = "test extract archive tar\n".encode()
info = tarfile.TarInfo(name="tar_archive")
info.size = len(content)
with tarfile.open(tar_path, "w") as f:
f.addfile(info, io.BytesIO(content))
with tempfile.TemporaryDirectory() as dst_dir:
data.utils.extract_archive(tar_path, dst_dir, overwrite=True)
assert os.path.exists(os.path.join(dst_dir, tar_file))
def _test_construct_graphs_node_ids():
from dgl.data.csv_dataset_base import (
DGLGraphConstructor,
......