Unverified Commit 8b8fd2c0 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Dataset] Add transform argument to built-in datasets (#3733)

* Update

* Fix

* Update
parent b3d3a2c4
......@@ -13,7 +13,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
Parameters
----------
name : str
Dataset Name, such as ``ENZYMES``, ``DD``, ``COLLAB``, ``MUTAG``, can be the
Dataset Name, such as ``ENZYMES``, ``DD``, ``COLLAB``, ``MUTAG``, can be the
datasets name on `<https://chrsmrrs.github.io/datasets/docs/datasets/>`_.
use_pandas : bool
Numpy's file read function has performance issue when file is large,
......@@ -26,6 +26,10 @@ class LegacyTUDataset(DGLBuiltinDataset):
max_allow_node : int
Remove graphs that contains more nodes than ``max_allow_node``.
Default : None
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
......@@ -39,7 +43,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
LegacyTUDataset uses provided node feature by default. If no feature provided, it uses one-hot node label instead.
If neither labels provided, it uses constant for node feature.
The dataset sorts graphs by their labels.
The dataset sorts graphs by their labels.
Shuffle is preferred before manual train/val split.
Examples
......@@ -73,7 +77,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
def __init__(self, name, use_pandas=False,
hidden_size=10, max_allow_node=None,
raw_dir=None, force_reload=False, verbose=False):
raw_dir=None, force_reload=False, verbose=False, transform=None):
url = self._url.format(name)
self.hidden_size = hidden_size
......@@ -81,7 +85,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
self.use_pandas = use_pandas
super(LegacyTUDataset, self).__init__(name=name, url=url, raw_dir=raw_dir,
hash_key=(name, use_pandas, hidden_size, max_allow_node),
force_reload=force_reload, verbose=verbose)
force_reload=force_reload, verbose=verbose, transform=transform)
def process(self):
self.data_mode = None
......@@ -100,7 +104,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
DS_graph_labels = self._idx_from_zero(
np.genfromtxt(self._file_path("graph_labels"), dtype=int))
self.num_labels = max(DS_graph_labels) + 1
self.graph_labels = DS_graph_labels
self.graph_labels = DS_graph_labels
elif os.path.exists(self._file_path("graph_attributes")):
DS_graph_labels = np.genfromtxt(self._file_path("graph_attributes"), dtype=float)
self.num_labels = None
......@@ -211,6 +215,8 @@ class LegacyTUDataset(DGLBuiltinDataset):
And its label.
"""
g = self.graph_lists[idx]
if self._transform is not None:
g = self._transform(g)
return g, self.graph_labels[idx]
def __len__(self):
......@@ -245,8 +251,12 @@ class TUDataset(DGLBuiltinDataset):
Parameters
----------
name : str
Dataset Name, such as ``ENZYMES``, ``DD``, ``COLLAB``, ``MUTAG``, can be the
Dataset Name, such as ``ENZYMES``, ``DD``, ``COLLAB``, ``MUTAG``, can be the
datasets name on `<https://chrsmrrs.github.io/datasets/docs/datasets/>`_.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
......@@ -271,7 +281,7 @@ class TUDataset(DGLBuiltinDataset):
label was added so that :math:`\lbrace -1, 1 \rbrace` was mapped to
:math:`\lbrace 0, 2 \rbrace`.
The dataset sorts graphs by their labels.
The dataset sorts graphs by their labels.
Shuffle is preferred before manual train/val split.
Examples
......@@ -299,32 +309,32 @@ class TUDataset(DGLBuiltinDataset):
Graph(num_nodes=9539, num_edges=47382,
ndata_schemes={'node_labels': Scheme(shape=(1,), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64)}
edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})
"""
_url = r"https://www.chrsmrrs.com/graphkerneldatasets/{}.zip"
def __init__(self, name, raw_dir=None, force_reload=False, verbose=False, transform=None):
    """Initialize a TUDataset.

    Parameters
    ----------
    name : str
        Dataset name on `<https://chrsmrrs.github.io/datasets/docs/datasets/>`_,
        used to build the download URL.
    raw_dir : str, optional
        Forwarded to :class:`DGLBuiltinDataset`.
    force_reload : bool, optional
        Forwarded to :class:`DGLBuiltinDataset`.
    verbose : bool, optional
        Forwarded to :class:`DGLBuiltinDataset`.
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. Applied before every access.
    """
    url = self._url.format(name)
    super(TUDataset, self).__init__(name=name, url=url,
                                    raw_dir=raw_dir, force_reload=force_reload,
                                    verbose=verbose, transform=transform)
def process(self):
DS_edge_list = self._idx_from_zero(
loadtxt(self._file_path("A"), delimiter=",").astype(int))
DS_indicator = self._idx_from_zero(
loadtxt(self._file_path("graph_indicator"), delimiter=",").astype(int))
if os.path.exists(self._file_path("graph_labels")):
DS_graph_labels = self._idx_reset(
loadtxt(self._file_path("graph_labels"), delimiter=",").astype(int))
loadtxt(self._file_path("graph_labels"), delimiter=",").astype(int))
self.num_labels = max(DS_graph_labels) + 1
self.graph_labels = F.tensor(DS_graph_labels)
self.graph_labels = F.tensor(DS_graph_labels)
elif os.path.exists(self._file_path("graph_attributes")):
DS_graph_labels = loadtxt(self._file_path("graph_attributes"), delimiter=",").astype(float)
self.num_labels = None
self.graph_labels = F.tensor(DS_graph_labels)
self.graph_labels = F.tensor(DS_graph_labels)
else:
raise Exception("Unknown graph label or graph attributes")
......@@ -404,6 +414,8 @@ class TUDataset(DGLBuiltinDataset):
And its label.
"""
g = self.graph_lists[idx]
if self._transform is not None:
g = self._transform(g)
return g, self.graph_labels[idx]
def __len__(self):
......
......@@ -7,6 +7,7 @@ import os
import pandas as pd
import yaml
import pytest
import dgl
import dgl.data as data
from dgl import DGLError
import dgl
......@@ -16,7 +17,11 @@ def test_minigc():
ds = data.MiniGCDataset(16, 10, 20)
g, l = list(zip(*ds))
print(g, l)
g1 = ds[0][0]
transform = dgl.AddSelfLoop(allow_duplicate=True)
ds = data.MiniGCDataset(16, 10, 20, transform=transform)
g2 = ds[0][0]
assert g2.num_edges() - g1.num_edges() == g1.num_nodes()
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_gin():
......@@ -27,37 +32,64 @@ def test_gin():
'PROTEINS': 1113,
'PTC': 344,
}
transform = dgl.AddSelfLoop(allow_duplicate=True)
for name, n_graphs in ds_n_graphs.items():
ds = data.GINDataset(name, self_loop=False, degree_as_nlabel=False)
assert len(ds) == n_graphs, (len(ds), name)
g1 = ds[0][0]
ds = data.GINDataset(name, self_loop=False, degree_as_nlabel=False, transform=transform)
g2 = ds[0][0]
assert g2.num_edges() - g1.num_edges() == g1.num_nodes()
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_fraud():
    """Check fraud-dataset sizes and the effect of a self-loop transform."""
    add_loops = dgl.AddSelfLoop(allow_duplicate=True)

    def check(build, expected_nodes):
        # The fraud graphs have 3 edge types, so AddSelfLoop inserts one
        # self-loop per node for each edge type.
        base = build()[0]
        assert base.num_nodes() == expected_nodes
        looped = build(transform=add_loops)[0]
        assert looped.num_edges() - base.num_edges() == base.num_nodes() * 3

    check(lambda **kwargs: data.FraudDataset('amazon', **kwargs), 11944)
    check(lambda **kwargs: data.FraudAmazonDataset(**kwargs), 11944)
    check(lambda **kwargs: data.FraudYelpDataset(**kwargs), 45954)
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_fakenews():
    """Check FakeNewsDataset sizes and the effect of a self-loop transform."""
    add_loops = dgl.AddSelfLoop(allow_duplicate=True)
    cases = [
        ('politifact', 'bert', 314),
        ('gossipcop', 'profile', 5464),
    ]
    for name, feature, expected_len in cases:
        plain = data.FakeNewsDataset(name, feature)
        assert len(plain) == expected_len
        base = plain[0][0]
        # AddSelfLoop adds exactly one extra edge per node.
        looped = data.FakeNewsDataset(name, feature, transform=add_loops)[0][0]
        assert looped.num_edges() - base.num_edges() == base.num_nodes()
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_tudataset_regression():
    """Check the ZINC_test TUDataset size and the self-loop transform."""
    plain = data.TUDataset('ZINC_test', force_reload=True)
    assert len(plain) == 5000
    base = plain[0][0]

    # Re-load with a transform that adds one self-loop per node.
    add_loops = dgl.AddSelfLoop(allow_duplicate=True)
    looped = data.TUDataset('ZINC_test', force_reload=True, transform=add_loops)[0][0]
    assert looped.num_edges() - base.num_edges() == base.num_nodes()
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_data_hash():
......@@ -78,12 +110,16 @@ def test_data_hash():
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_citation_graph():
transform = dgl.AddSelfLoop(allow_duplicate=True)
# cora
g = data.CoraGraphDataset()[0]
assert g.num_nodes() == 2708
assert g.num_edges() == 10556
dst = F.asnumpy(g.edges()[1])
assert np.array_equal(dst, np.sort(dst))
g2 = data.CoraGraphDataset(transform=transform)[0]
assert g2.num_edges() - g.num_edges() == g.num_nodes()
# Citeseer
g = data.CiteseerGraphDataset()[0]
......@@ -91,6 +127,8 @@ def test_citation_graph():
assert g.num_edges() == 9228
dst = F.asnumpy(g.edges()[1])
assert np.array_equal(dst, np.sort(dst))
g2 = data.CiteseerGraphDataset(transform=transform)[0]
assert g2.num_edges() - g.num_edges() == g.num_nodes()
# Pubmed
g = data.PubmedGraphDataset()[0]
......@@ -98,16 +136,22 @@ def test_citation_graph():
assert g.num_edges() == 88651
dst = F.asnumpy(g.edges()[1])
assert np.array_equal(dst, np.sort(dst))
g2 = data.PubmedGraphDataset(transform=transform)[0]
assert g2.num_edges() - g.num_edges() == g.num_nodes()
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_gnn_benchmark():
transform = dgl.AddSelfLoop(allow_duplicate=True)
# AmazonCoBuyComputerDataset
g = data.AmazonCoBuyComputerDataset()[0]
assert g.num_nodes() == 13752
assert g.num_edges() == 491722
dst = F.asnumpy(g.edges()[1])
assert np.array_equal(dst, np.sort(dst))
g2 = data.AmazonCoBuyComputerDataset(transform=transform)[0]
assert g2.num_edges() - g.num_edges() == g.num_nodes()
# AmazonCoBuyPhotoDataset
g = data.AmazonCoBuyPhotoDataset()[0]
......@@ -115,6 +159,8 @@ def test_gnn_benchmark():
assert g.num_edges() == 238163
dst = F.asnumpy(g.edges()[1])
assert np.array_equal(dst, np.sort(dst))
g2 = data.AmazonCoBuyPhotoDataset(transform=transform)[0]
assert g2.num_edges() - g.num_edges() == g.num_nodes()
# CoauthorPhysicsDataset
g = data.CoauthorPhysicsDataset()[0]
......@@ -122,6 +168,8 @@ def test_gnn_benchmark():
assert g.num_edges() == 495924
dst = F.asnumpy(g.edges()[1])
assert np.array_equal(dst, np.sort(dst))
g2 = data.CoauthorPhysicsDataset(transform=transform)[0]
assert g2.num_edges() - g.num_edges() == g.num_nodes()
# CoauthorCSDataset
g = data.CoauthorCSDataset()[0]
......@@ -129,6 +177,8 @@ def test_gnn_benchmark():
assert g.num_edges() == 163788
dst = F.asnumpy(g.edges()[1])
assert np.array_equal(dst, np.sort(dst))
g2 = data.CoauthorCSDataset(transform=transform)[0]
assert g2.num_edges() - g.num_edges() == g.num_nodes()
# CoraFullDataset
g = data.CoraFullDataset()[0]
......@@ -136,6 +186,8 @@ def test_gnn_benchmark():
assert g.num_edges() == 126842
dst = F.asnumpy(g.edges()[1])
assert np.array_equal(dst, np.sort(dst))
g2 = data.CoraFullDataset(transform=transform)[0]
assert g2.num_edges() - g.num_edges() == g.num_nodes()
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
......@@ -147,6 +199,10 @@ def test_reddit():
dst = F.asnumpy(g.edges()[1])
assert np.array_equal(dst, np.sort(dst))
transform = dgl.AddSelfLoop(allow_duplicate=True)
g2 = data.RedditDataset(transform=transform)[0]
assert g2.num_edges() - g.num_edges() == g.num_nodes()
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_extract_archive():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment