Unverified Commit 8b8fd2c0 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Dataset] Add transform argument to built-in datasets (#3733)

* Update

* Fix

* Update
parent b3d3a2c4
...@@ -13,7 +13,7 @@ class LegacyTUDataset(DGLBuiltinDataset): ...@@ -13,7 +13,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
Parameters Parameters
---------- ----------
name : str name : str
Dataset Name, such as ``ENZYMES``, ``DD``, ``COLLAB``, ``MUTAG``, can be the Dataset Name, such as ``ENZYMES``, ``DD``, ``COLLAB``, ``MUTAG``, can be the
datasets name on `<https://chrsmrrs.github.io/datasets/docs/datasets/>`_. datasets name on `<https://chrsmrrs.github.io/datasets/docs/datasets/>`_.
use_pandas : bool use_pandas : bool
Numpy's file read function has performance issue when file is large, Numpy's file read function has performance issue when file is large,
...@@ -26,6 +26,10 @@ class LegacyTUDataset(DGLBuiltinDataset): ...@@ -26,6 +26,10 @@ class LegacyTUDataset(DGLBuiltinDataset):
max_allow_node : int max_allow_node : int
Remove graphs that contains more nodes than ``max_allow_node``. Remove graphs that contains more nodes than ``max_allow_node``.
Default : None Default : None
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes Attributes
---------- ----------
...@@ -39,7 +43,7 @@ class LegacyTUDataset(DGLBuiltinDataset): ...@@ -39,7 +43,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
LegacyTUDataset uses provided node feature by default. If no feature provided, it uses one-hot node label instead. LegacyTUDataset uses provided node feature by default. If no feature provided, it uses one-hot node label instead.
If neither labels provided, it uses constant for node feature. If neither labels provided, it uses constant for node feature.
The dataset sorts graphs by their labels. The dataset sorts graphs by their labels.
Shuffle is preferred before manual train/val split. Shuffle is preferred before manual train/val split.
Examples Examples
...@@ -73,7 +77,7 @@ class LegacyTUDataset(DGLBuiltinDataset): ...@@ -73,7 +77,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
def __init__(self, name, use_pandas=False, def __init__(self, name, use_pandas=False,
hidden_size=10, max_allow_node=None, hidden_size=10, max_allow_node=None,
raw_dir=None, force_reload=False, verbose=False): raw_dir=None, force_reload=False, verbose=False, transform=None):
url = self._url.format(name) url = self._url.format(name)
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -81,7 +85,7 @@ class LegacyTUDataset(DGLBuiltinDataset): ...@@ -81,7 +85,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
self.use_pandas = use_pandas self.use_pandas = use_pandas
super(LegacyTUDataset, self).__init__(name=name, url=url, raw_dir=raw_dir, super(LegacyTUDataset, self).__init__(name=name, url=url, raw_dir=raw_dir,
hash_key=(name, use_pandas, hidden_size, max_allow_node), hash_key=(name, use_pandas, hidden_size, max_allow_node),
force_reload=force_reload, verbose=verbose) force_reload=force_reload, verbose=verbose, transform=transform)
def process(self): def process(self):
self.data_mode = None self.data_mode = None
...@@ -100,7 +104,7 @@ class LegacyTUDataset(DGLBuiltinDataset): ...@@ -100,7 +104,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
DS_graph_labels = self._idx_from_zero( DS_graph_labels = self._idx_from_zero(
np.genfromtxt(self._file_path("graph_labels"), dtype=int)) np.genfromtxt(self._file_path("graph_labels"), dtype=int))
self.num_labels = max(DS_graph_labels) + 1 self.num_labels = max(DS_graph_labels) + 1
self.graph_labels = DS_graph_labels self.graph_labels = DS_graph_labels
elif os.path.exists(self._file_path("graph_attributes")): elif os.path.exists(self._file_path("graph_attributes")):
DS_graph_labels = np.genfromtxt(self._file_path("graph_attributes"), dtype=float) DS_graph_labels = np.genfromtxt(self._file_path("graph_attributes"), dtype=float)
self.num_labels = None self.num_labels = None
...@@ -211,6 +215,8 @@ class LegacyTUDataset(DGLBuiltinDataset): ...@@ -211,6 +215,8 @@ class LegacyTUDataset(DGLBuiltinDataset):
And its label. And its label.
""" """
g = self.graph_lists[idx] g = self.graph_lists[idx]
if self._transform is not None:
g = self._transform(g)
return g, self.graph_labels[idx] return g, self.graph_labels[idx]
def __len__(self): def __len__(self):
...@@ -245,8 +251,12 @@ class TUDataset(DGLBuiltinDataset): ...@@ -245,8 +251,12 @@ class TUDataset(DGLBuiltinDataset):
Parameters Parameters
---------- ----------
name : str name : str
Dataset Name, such as ``ENZYMES``, ``DD``, ``COLLAB``, ``MUTAG``, can be the Dataset Name, such as ``ENZYMES``, ``DD``, ``COLLAB``, ``MUTAG``, can be the
datasets name on `<https://chrsmrrs.github.io/datasets/docs/datasets/>`_. datasets name on `<https://chrsmrrs.github.io/datasets/docs/datasets/>`_.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes Attributes
---------- ----------
...@@ -271,7 +281,7 @@ class TUDataset(DGLBuiltinDataset): ...@@ -271,7 +281,7 @@ class TUDataset(DGLBuiltinDataset):
label was added so that :math:`\lbrace -1, 1 \rbrace` was mapped to label was added so that :math:`\lbrace -1, 1 \rbrace` was mapped to
:math:`\lbrace 0, 2 \rbrace`. :math:`\lbrace 0, 2 \rbrace`.
The dataset sorts graphs by their labels. The dataset sorts graphs by their labels.
Shuffle is preferred before manual train/val split. Shuffle is preferred before manual train/val split.
Examples Examples
...@@ -299,32 +309,32 @@ class TUDataset(DGLBuiltinDataset): ...@@ -299,32 +309,32 @@ class TUDataset(DGLBuiltinDataset):
Graph(num_nodes=9539, num_edges=47382, Graph(num_nodes=9539, num_edges=47382,
ndata_schemes={'node_labels': Scheme(shape=(1,), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64)} ndata_schemes={'node_labels': Scheme(shape=(1,), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64)}
edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}) edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})
""" """
_url = r"https://www.chrsmrrs.com/graphkerneldatasets/{}.zip" _url = r"https://www.chrsmrrs.com/graphkerneldatasets/{}.zip"
def __init__(self, name, raw_dir=None, force_reload=False, verbose=False): def __init__(self, name, raw_dir=None, force_reload=False, verbose=False, transform=None):
url = self._url.format(name) url = self._url.format(name)
super(TUDataset, self).__init__(name=name, url=url, super(TUDataset, self).__init__(name=name, url=url,
raw_dir=raw_dir, force_reload=force_reload, raw_dir=raw_dir, force_reload=force_reload,
verbose=verbose) verbose=verbose, transform=transform)
def process(self): def process(self):
DS_edge_list = self._idx_from_zero( DS_edge_list = self._idx_from_zero(
loadtxt(self._file_path("A"), delimiter=",").astype(int)) loadtxt(self._file_path("A"), delimiter=",").astype(int))
DS_indicator = self._idx_from_zero( DS_indicator = self._idx_from_zero(
loadtxt(self._file_path("graph_indicator"), delimiter=",").astype(int)) loadtxt(self._file_path("graph_indicator"), delimiter=",").astype(int))
if os.path.exists(self._file_path("graph_labels")): if os.path.exists(self._file_path("graph_labels")):
DS_graph_labels = self._idx_reset( DS_graph_labels = self._idx_reset(
loadtxt(self._file_path("graph_labels"), delimiter=",").astype(int)) loadtxt(self._file_path("graph_labels"), delimiter=",").astype(int))
self.num_labels = max(DS_graph_labels) + 1 self.num_labels = max(DS_graph_labels) + 1
self.graph_labels = F.tensor(DS_graph_labels) self.graph_labels = F.tensor(DS_graph_labels)
elif os.path.exists(self._file_path("graph_attributes")): elif os.path.exists(self._file_path("graph_attributes")):
DS_graph_labels = loadtxt(self._file_path("graph_attributes"), delimiter=",").astype(float) DS_graph_labels = loadtxt(self._file_path("graph_attributes"), delimiter=",").astype(float)
self.num_labels = None self.num_labels = None
self.graph_labels = F.tensor(DS_graph_labels) self.graph_labels = F.tensor(DS_graph_labels)
else: else:
raise Exception("Unknown graph label or graph attributes") raise Exception("Unknown graph label or graph attributes")
...@@ -404,6 +414,8 @@ class TUDataset(DGLBuiltinDataset): ...@@ -404,6 +414,8 @@ class TUDataset(DGLBuiltinDataset):
And its label. And its label.
""" """
g = self.graph_lists[idx] g = self.graph_lists[idx]
if self._transform is not None:
g = self._transform(g)
return g, self.graph_labels[idx] return g, self.graph_labels[idx]
def __len__(self): def __len__(self):
......
...@@ -7,6 +7,7 @@ import os ...@@ -7,6 +7,7 @@ import os
import pandas as pd import pandas as pd
import yaml import yaml
import pytest import pytest
import dgl
import dgl.data as data import dgl.data as data
from dgl import DGLError from dgl import DGLError
import dgl import dgl
...@@ -16,7 +17,11 @@ def test_minigc(): ...@@ -16,7 +17,11 @@ def test_minigc():
ds = data.MiniGCDataset(16, 10, 20) ds = data.MiniGCDataset(16, 10, 20)
g, l = list(zip(*ds)) g, l = list(zip(*ds))
print(g, l) print(g, l)
g1 = ds[0][0]
transform = dgl.AddSelfLoop(allow_duplicate=True)
ds = data.MiniGCDataset(16, 10, 20, transform=transform)
g2 = ds[0][0]
assert g2.num_edges() - g1.num_edges() == g1.num_nodes()
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.") @unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_gin(): def test_gin():
...@@ -27,37 +32,64 @@ def test_gin(): ...@@ -27,37 +32,64 @@ def test_gin():
'PROTEINS': 1113, 'PROTEINS': 1113,
'PTC': 344, 'PTC': 344,
} }
transform = dgl.AddSelfLoop(allow_duplicate=True)
for name, n_graphs in ds_n_graphs.items(): for name, n_graphs in ds_n_graphs.items():
ds = data.GINDataset(name, self_loop=False, degree_as_nlabel=False) ds = data.GINDataset(name, self_loop=False, degree_as_nlabel=False)
assert len(ds) == n_graphs, (len(ds), name) assert len(ds) == n_graphs, (len(ds), name)
g1 = ds[0][0]
ds = data.GINDataset(name, self_loop=False, degree_as_nlabel=False, transform=transform)
g2 = ds[0][0]
assert g2.num_edges() - g1.num_edges() == g1.num_nodes()
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.") @unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_fraud(): def test_fraud():
transform = dgl.AddSelfLoop(allow_duplicate=True)
g = data.FraudDataset('amazon')[0] g = data.FraudDataset('amazon')[0]
assert g.num_nodes() == 11944 assert g.num_nodes() == 11944
num_edges1 = g.num_edges()
g2 = data.FraudDataset('amazon', transform=transform)[0]
# 3 edge types
assert g2.num_edges() - num_edges1 == g.num_nodes() * 3
g = data.FraudAmazonDataset()[0] g = data.FraudAmazonDataset()[0]
assert g.num_nodes() == 11944 assert g.num_nodes() == 11944
g2 = data.FraudAmazonDataset(transform=transform)[0]
# 3 edge types
assert g2.num_edges() - g.num_edges() == g.num_nodes() * 3
g = data.FraudYelpDataset()[0] g = data.FraudYelpDataset()[0]
assert g.num_nodes() == 45954 assert g.num_nodes() == 45954
g2 = data.FraudYelpDataset(transform=transform)[0]
# 3 edge types
assert g2.num_edges() - g.num_edges() == g.num_nodes() * 3
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.") @unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_fakenews(): def test_fakenews():
transform = dgl.AddSelfLoop(allow_duplicate=True)
ds = data.FakeNewsDataset('politifact', 'bert') ds = data.FakeNewsDataset('politifact', 'bert')
assert len(ds) == 314 assert len(ds) == 314
g = ds[0][0]
g2 = data.FakeNewsDataset('politifact', 'bert', transform=transform)[0][0]
assert g2.num_edges() - g.num_edges() == g.num_nodes()
ds = data.FakeNewsDataset('gossipcop', 'profile') ds = data.FakeNewsDataset('gossipcop', 'profile')
assert len(ds) == 5464 assert len(ds) == 5464
g = ds[0][0]
g2 = data.FakeNewsDataset('gossipcop', 'profile', transform=transform)[0][0]
assert g2.num_edges() - g.num_edges() == g.num_nodes()
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.") @unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_tudataset_regression(): def test_tudataset_regression():
ds = data.TUDataset('ZINC_test', force_reload=True) ds = data.TUDataset('ZINC_test', force_reload=True)
assert len(ds) == 5000 assert len(ds) == 5000
g = ds[0][0]
transform = dgl.AddSelfLoop(allow_duplicate=True)
ds = data.TUDataset('ZINC_test', force_reload=True, transform=transform)
g2 = ds[0][0]
assert g2.num_edges() - g.num_edges() == g.num_nodes()
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.") @unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_data_hash(): def test_data_hash():
...@@ -78,12 +110,16 @@ def test_data_hash(): ...@@ -78,12 +110,16 @@ def test_data_hash():
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.") @unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_citation_graph(): def test_citation_graph():
transform = dgl.AddSelfLoop(allow_duplicate=True)
# cora # cora
g = data.CoraGraphDataset()[0] g = data.CoraGraphDataset()[0]
assert g.num_nodes() == 2708 assert g.num_nodes() == 2708
assert g.num_edges() == 10556 assert g.num_edges() == 10556
dst = F.asnumpy(g.edges()[1]) dst = F.asnumpy(g.edges()[1])
assert np.array_equal(dst, np.sort(dst)) assert np.array_equal(dst, np.sort(dst))
g2 = data.CoraGraphDataset(transform=transform)[0]
assert g2.num_edges() - g.num_edges() == g.num_nodes()
# Citeseer # Citeseer
g = data.CiteseerGraphDataset()[0] g = data.CiteseerGraphDataset()[0]
...@@ -91,6 +127,8 @@ def test_citation_graph(): ...@@ -91,6 +127,8 @@ def test_citation_graph():
assert g.num_edges() == 9228 assert g.num_edges() == 9228
dst = F.asnumpy(g.edges()[1]) dst = F.asnumpy(g.edges()[1])
assert np.array_equal(dst, np.sort(dst)) assert np.array_equal(dst, np.sort(dst))
g2 = data.CiteseerGraphDataset(transform=transform)[0]
assert g2.num_edges() - g.num_edges() == g.num_nodes()
# Pubmed # Pubmed
g = data.PubmedGraphDataset()[0] g = data.PubmedGraphDataset()[0]
...@@ -98,16 +136,22 @@ def test_citation_graph(): ...@@ -98,16 +136,22 @@ def test_citation_graph():
assert g.num_edges() == 88651 assert g.num_edges() == 88651
dst = F.asnumpy(g.edges()[1]) dst = F.asnumpy(g.edges()[1])
assert np.array_equal(dst, np.sort(dst)) assert np.array_equal(dst, np.sort(dst))
g2 = data.PubmedGraphDataset(transform=transform)[0]
assert g2.num_edges() - g.num_edges() == g.num_nodes()
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.") @unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_gnn_benchmark(): def test_gnn_benchmark():
transform = dgl.AddSelfLoop(allow_duplicate=True)
# AmazonCoBuyComputerDataset # AmazonCoBuyComputerDataset
g = data.AmazonCoBuyComputerDataset()[0] g = data.AmazonCoBuyComputerDataset()[0]
assert g.num_nodes() == 13752 assert g.num_nodes() == 13752
assert g.num_edges() == 491722 assert g.num_edges() == 491722
dst = F.asnumpy(g.edges()[1]) dst = F.asnumpy(g.edges()[1])
assert np.array_equal(dst, np.sort(dst)) assert np.array_equal(dst, np.sort(dst))
g2 = data.AmazonCoBuyComputerDataset(transform=transform)[0]
assert g2.num_edges() - g.num_edges() == g.num_nodes()
# AmazonCoBuyPhotoDataset # AmazonCoBuyPhotoDataset
g = data.AmazonCoBuyPhotoDataset()[0] g = data.AmazonCoBuyPhotoDataset()[0]
...@@ -115,6 +159,8 @@ def test_gnn_benchmark(): ...@@ -115,6 +159,8 @@ def test_gnn_benchmark():
assert g.num_edges() == 238163 assert g.num_edges() == 238163
dst = F.asnumpy(g.edges()[1]) dst = F.asnumpy(g.edges()[1])
assert np.array_equal(dst, np.sort(dst)) assert np.array_equal(dst, np.sort(dst))
g2 = data.AmazonCoBuyPhotoDataset(transform=transform)[0]
assert g2.num_edges() - g.num_edges() == g.num_nodes()
# CoauthorPhysicsDataset # CoauthorPhysicsDataset
g = data.CoauthorPhysicsDataset()[0] g = data.CoauthorPhysicsDataset()[0]
...@@ -122,6 +168,8 @@ def test_gnn_benchmark(): ...@@ -122,6 +168,8 @@ def test_gnn_benchmark():
assert g.num_edges() == 495924 assert g.num_edges() == 495924
dst = F.asnumpy(g.edges()[1]) dst = F.asnumpy(g.edges()[1])
assert np.array_equal(dst, np.sort(dst)) assert np.array_equal(dst, np.sort(dst))
g2 = data.CoauthorPhysicsDataset(transform=transform)[0]
assert g2.num_edges() - g.num_edges() == g.num_nodes()
# CoauthorCSDataset # CoauthorCSDataset
g = data.CoauthorCSDataset()[0] g = data.CoauthorCSDataset()[0]
...@@ -129,6 +177,8 @@ def test_gnn_benchmark(): ...@@ -129,6 +177,8 @@ def test_gnn_benchmark():
assert g.num_edges() == 163788 assert g.num_edges() == 163788
dst = F.asnumpy(g.edges()[1]) dst = F.asnumpy(g.edges()[1])
assert np.array_equal(dst, np.sort(dst)) assert np.array_equal(dst, np.sort(dst))
g2 = data.CoauthorCSDataset(transform=transform)[0]
assert g2.num_edges() - g.num_edges() == g.num_nodes()
# CoraFullDataset # CoraFullDataset
g = data.CoraFullDataset()[0] g = data.CoraFullDataset()[0]
...@@ -136,6 +186,8 @@ def test_gnn_benchmark(): ...@@ -136,6 +186,8 @@ def test_gnn_benchmark():
assert g.num_edges() == 126842 assert g.num_edges() == 126842
dst = F.asnumpy(g.edges()[1]) dst = F.asnumpy(g.edges()[1])
assert np.array_equal(dst, np.sort(dst)) assert np.array_equal(dst, np.sort(dst))
g2 = data.CoraFullDataset(transform=transform)[0]
assert g2.num_edges() - g.num_edges() == g.num_nodes()
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.") @unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
...@@ -147,6 +199,10 @@ def test_reddit(): ...@@ -147,6 +199,10 @@ def test_reddit():
dst = F.asnumpy(g.edges()[1]) dst = F.asnumpy(g.edges()[1])
assert np.array_equal(dst, np.sort(dst)) assert np.array_equal(dst, np.sort(dst))
transform = dgl.AddSelfLoop(allow_duplicate=True)
g2 = data.RedditDataset(transform=transform)[0]
assert g2.num_edges() - g.num_edges() == g.num_nodes()
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.") @unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_extract_archive(): def test_extract_archive():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment