Unverified Commit d9c25521 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Data] AsGraphPredDataset (#4073)

* Update

* CI

* Update

* Update

* Fix

* Fix
parent 9922f41f
......@@ -99,6 +99,7 @@ Dataset adapters
AsNodePredDataset
AsLinkPredDataset
AsGraphPredDataset
Utilities
-----------------
......
......@@ -389,6 +389,8 @@ After loaded, the dataset has multiple homographs with features and labels:
>>> print(data1)
{'feat': tensor([0.5348, 0.2864, 0.1155], dtype=torch.float64), 'label': tensor(0)}
If there is a single feature column in ``graphs.csv``, ``data0`` will directly be a tensor for the feature.
Custom Data Parser
~~~~~~~~~~~~~~~~~~
......
......@@ -78,13 +78,13 @@ for details of ``self._load_graph()`` and ``__getitem__``.
One can also add properties to the class to indicate some useful
information of the dataset. In :class:`~dgl.data.QM7bDataset`, one can add a property
``num_labels`` to indicate the total number of prediction tasks in this
``num_tasks`` to indicate the total number of prediction tasks in this
multi-task dataset:
.. code::
@property
def num_labels(self):
def num_tasks(self):
"""Number of labels for each graph, i.e. number of prediction tasks."""
return 14
......@@ -100,7 +100,7 @@ follows:
# load data
dataset = QM7bDataset()
num_labels = dataset.num_labels
num_tasks = dataset.num_tasks
# create dataloaders
dataloader = GraphDataLoader(dataset, batch_size=1, shuffle=True)
......@@ -178,7 +178,7 @@ part in below example for more details.
# node features
g.ndata['feat'] = torch.tensor(_preprocess_features(features),
dtype=F.data_type_dict['float32'])
self._num_labels = onehot_labels.shape[1]
self._num_tasks = onehot_labels.shape[1]
self._labels = labels
# reorder graph to obtain better locality.
self._g = dgl.reorder_graph(g)
......
......@@ -66,12 +66,12 @@ DGL建议让 ``__getitem__(idx)`` 返回如上面代码所示的元组 ``(图,
以获得 ``self._load_graph()`` 和 ``__getitem__`` 的详细信息。
用户还可以向类添加属性以指示一些有用的数据集信息。在 :class:`~dgl.data.QM7bDataset` 中,
用户可以添加属性 ``num_labels`` 来指示此多任务数据集中的预测任务总数:
用户可以添加属性 ``num_tasks`` 来指示此多任务数据集中的预测任务总数:
.. code::
@property
def num_labels(self):
def num_tasks(self):
"""每个图的标签数,即预测任务数。"""
return 14
......@@ -86,7 +86,7 @@ DGL建议让 ``__getitem__(idx)`` 返回如上面代码所示的元组 ``(图,
# 数据导入
dataset = QM7bDataset()
num_labels = dataset.num_labels
num_tasks = dataset.num_tasks
# 创建 dataloaders
dataloader = GraphDataLoader(dataset, batch_size=1, shuffle=True)
......@@ -162,7 +162,7 @@ DGL提供了名为 :func:`dgl.reorder_graph` 的API用于此优化。更多细
# 节点的特征
g.ndata['feat'] = torch.tensor(_preprocess_features(features),
dtype=F.data_type_dict['float32'])
self._num_labels = onehot_labels.shape[1]
self._num_tasks = onehot_labels.shape[1]
self._labels = labels
# 重排图以获得更优的局部性
self._g = dgl.reorder_graph(g)
......
......@@ -58,12 +58,12 @@
``process()`` 함수에서 처리되지 않은 데이터는 그래프들의 리스트와 레이블들의 리스트로 변환된다. Iteration을 위해서 ``__getitem__(idx)`` 와 ``__len__()`` 를 구현해야 한다. 위의 예제에서와 같이, DGL에서는 ``__getitem__(idx)`` 가 ``(graph, label)`` tuple을 리턴하도록 권장한다. ``self._load_graph()`` 와 ``__getitem__`` 함수의 구체적인 구현은 `QM7bDataset source
code <https://docs.dgl.ai/en/0.5.x/_modules/dgl/data/qm7b.html#QM7bDataset>`__ 를 확인하자.
데이터셋의 유용한 정보들을 지정하기 위해서 클래스에 프로퍼티들을 추가하는 것이 가능하다. :class:`~dgl.data.QM7bDataset` 에 이 멀티 테스크 데이터셋의 예측 테스트의 총 개숫를 지정하기 위해 ``num_labels`` 라는 프로퍼티를 추가할 수 있다.
데이터셋의 유용한 정보들을 지정하기 위해서 클래스에 프로퍼티들을 추가하는 것이 가능하다. :class:`~dgl.data.QM7bDataset` 에 이 멀티 태스크 데이터셋의 예측 태스크의 총 개수를 지정하기 위해 ``num_tasks`` 라는 프로퍼티를 추가할 수 있다.
.. code::
@property
def num_labels(self):
def num_tasks(self):
"""Number of labels for each graph, i.e. number of prediction tasks."""
return 14
......@@ -78,7 +78,7 @@ code <https://docs.dgl.ai/en/0.5.x/_modules/dgl/data/qm7b.html#QM7bDataset>`__
# load data
dataset = QM7bDataset()
num_labels = dataset.num_labels
num_tasks = dataset.num_tasks
# create dataloaders
dataloader = GraphDataLoader(dataset, batch_size=1, shuffle=True)
......@@ -143,7 +143,7 @@ DGL의 빌트인 그래프 분류 데이터셋을 참고하면 그래프 분류
# node features
g.ndata['feat'] = torch.tensor(_preprocess_features(features),
dtype=F.data_type_dict['float32'])
self._num_labels = onehot_labels.shape[1]
self._num_tasks = onehot_labels.shape[1]
self._labels = labels
# reorder graph to obtain better locality.
self._g = dgl.reorder_graph(g)
......
......@@ -30,7 +30,7 @@ from .rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset
from .fraud import FraudDataset, FraudYelpDataset, FraudAmazonDataset
from .fakenews import FakeNewsDataset
from .csv_dataset import CSVDataset
from .adapter import AsNodePredDataset, AsLinkPredDataset
from .adapter import *
from .synthetic import BAShapeDataset, BACommunityDataset, TreeCycleDataset, TreeGridDataset, BA2MotifDataset
def register_data_args(parser):
......
......@@ -9,16 +9,17 @@ from ..convert import graph as create_dgl_graph
from ..sampling.negative import _calc_redundancy
from .dgl_dataset import DGLDataset
from . import utils
from ..base import DGLError
from .. import backend as F
__all__ = ['AsNodePredDataset', 'AsLinkPredDataset']
__all__ = ['AsNodePredDataset', 'AsLinkPredDataset', 'AsGraphPredDataset']
class AsNodePredDataset(DGLDataset):
"""Repurpose a dataset for a standard semi-supervised transductive
node prediction task.
The class converts a given dataset into a new dataset object that:
The class converts a given dataset into a new dataset object such that:
- Contains only one graph, accessible from ``dataset[0]``.
- The graph stores:
......@@ -40,7 +41,7 @@ class AsNodePredDataset(DGLDataset):
So do validation and test masks.
The class will keep only the first graph in the provided dataset and
generate train/val/test masks according to the given spplit ratio. The generated
generate train/val/test masks according to the given split ratio. The generated
masks will be cached to disk for fast re-loading. If the provided split ratio
differs from the cached one, it will re-process the dataset properly.
......@@ -49,7 +50,7 @@ class AsNodePredDataset(DGLDataset):
dataset : DGLDataset
The dataset to be converted.
split_ratio : (float, float, float), optional
Split ratios for training, validation and test sets. Must sum to one.
Split ratios for training, validation and test sets. They must sum to one.
target_ntype : str, optional
The node type to add split mask for.
......@@ -193,7 +194,7 @@ class AsLinkPredDataset(DGLDataset):
"""Repurpose a dataset for link prediction task.
The created dataset will include data needed for link prediction.
Currently only support homogeneous graph.
Currently it only supports homogeneous graphs.
It will keep only the first graph in the provided dataset and
generate train/val/test edges according to the given split ratio,
and the correspondent negative edges based on the neg_ratio. The generated
......@@ -368,3 +369,158 @@ class AsLinkPredDataset(DGLDataset):
def __len__(self):
return 1
class AsGraphPredDataset(DGLDataset):
    """Repurpose a dataset for a standard graph property prediction task.

    The created dataset will include data needed for graph property prediction.
    Currently it only supports homogeneous graphs.

    The class converts a given dataset into a new dataset object such that:

    - It stores ``len(dataset)`` graphs.
    - The i-th graph and its label is accessible from ``dataset[i]``.

    The class will generate a train/val/test split if :attr:`split_ratio` is provided.
    The generated split will be cached to disk for fast re-loading. If the provided split
    ratio differs from the cached one, it will re-process the dataset properly.

    Parameters
    ----------
    dataset : DGLDataset
        The dataset to be converted.
    split_ratio : (float, float, float), optional
        Split ratios for training, validation and test sets. They must sum to one.

    Attributes
    ----------
    num_tasks : int
        Number of tasks to predict.
    num_classes : int
        Number of classes to predict per task, None for regression datasets.
    train_idx : Tensor
        A 1-D integer tensor of training graph IDs.
    val_idx : Tensor
        A 1-D integer tensor of validation graph IDs.
    test_idx : Tensor
        A 1-D integer tensor of test graph IDs.
    node_feat_size : int
        Input node feature size, None if not applicable.
    edge_feat_size : int
        Input edge feature size, None if not applicable.

    Examples
    --------
    >>> from dgl.data import AsGraphPredDataset
    >>> from ogb.graphproppred import DglGraphPropPredDataset
    >>> dataset = DglGraphPropPredDataset(name='ogbg-molhiv')
    >>> new_dataset = AsGraphPredDataset(dataset)
    >>> print(new_dataset)
    Dataset("ogbg-molhiv-as-graphpred", num_graphs=41127, save_path=...)
    >>> print(len(new_dataset))
    41127
    >>> print(new_dataset[0])
    (Graph(num_nodes=19, num_edges=40,
           ndata_schemes={'feat': Scheme(shape=(9,), dtype=torch.int64)}
           edata_schemes={'feat': Scheme(shape=(3,), dtype=torch.int64)}), tensor([0]))
    """
    def __init__(self,
                 dataset,
                 split_ratio=None,
                 **kwargs):
        self.dataset = dataset
        self.split_ratio = split_ratio
        # Hash on (split ratio, dataset name) so that a change in either
        # invalidates the on-disk cache and triggers re-processing.
        super().__init__(dataset.name + '-as-graphpred',
                         hash_key=(split_ratio, dataset.name, 'graphpred'), **kwargs)

    def process(self):
        """Extract or generate the train/val/test split and task metadata."""
        # OGB datasets ship with a standard split exposed via get_idx_split().
        is_ogb = hasattr(self.dataset, 'get_idx_split')
        if self.split_ratio is None:
            if is_ogb:
                split = self.dataset.get_idx_split()
                self.train_idx = split['train']
                self.val_idx = split['valid']
                self.test_idx = split['test']
            else:
                # Handle datasets exposing boolean masks, e.g. FakeNewsDataset.
                try:
                    self.train_idx = F.nonzero_1d(self.dataset.train_mask)
                    self.val_idx = F.nonzero_1d(self.dataset.val_mask)
                    self.test_idx = F.nonzero_1d(self.dataset.test_mask)
                except AttributeError:
                    raise DGLError('The input dataset does not have a default '
                                   'train/val/test split. Please specify '
                                   'split_ratio to generate the split.')
        else:
            if self.verbose:
                print('Generating train/val/test split...')
            train_ratio, val_ratio, _ = self.split_ratio
            num_graphs = len(self.dataset)
            num_train = int(num_graphs * train_ratio)
            num_val = int(num_graphs * val_ratio)
            # Random permutation so the split is unbiased w.r.t. dataset order.
            idx = np.random.permutation(num_graphs)
            self.train_idx = F.tensor(idx[:num_train])
            self.val_idx = F.tensor(idx[num_train: num_train + num_val])
            self.test_idx = F.tensor(idx[num_train + num_val:])

        if hasattr(self.dataset, 'num_classes'):
            # GINDataset, MiniGCDataset, FakeNewsDataset, TUDataset,
            # LegacyTUDataset, BA2MotifDataset
            self.num_classes = self.dataset.num_classes
        else:
            # None for multi-label classification and regression
            self.num_classes = None

        if hasattr(self.dataset, 'num_tasks'):
            # OGB datasets
            self.num_tasks = self.dataset.num_tasks
        else:
            self.num_tasks = 1

    def has_cache(self):
        """Return True if cached split/metadata exists for the current hash."""
        return os.path.isfile(os.path.join(self.save_path, 'info_{}.json'.format(self.hash)))

    def load(self):
        """Load the cached split and task metadata from disk.

        Raises
        ------
        ValueError
            If the cached split ratio differs from the requested one,
            which triggers re-processing of the dataset.
        """
        with open(os.path.join(self.save_path, 'info_{}.json'.format(self.hash)), 'r') as f:
            info = json.load(f)
            if info['split_ratio'] != self.split_ratio:
                raise ValueError('Provided split ratio is different from the cached file. '
                                 'Re-process the dataset.')
            self.split_ratio = info['split_ratio']
            self.num_tasks = info['num_tasks']
            self.num_classes = info['num_classes']
        split = np.load(os.path.join(self.save_path, 'split_{}.npz'.format(self.hash)))
        self.train_idx = F.zerocopy_from_numpy(split['train_idx'])
        self.val_idx = F.zerocopy_from_numpy(split['val_idx'])
        self.test_idx = F.zerocopy_from_numpy(split['test_idx'])

    def save(self):
        """Cache the split indices and task metadata to disk."""
        if not os.path.exists(self.save_path):
            os.makedirs(self.save_path)
        with open(os.path.join(self.save_path, 'info_{}.json'.format(self.hash)), 'w') as f:
            json.dump({
                'split_ratio': self.split_ratio,
                'num_tasks': self.num_tasks,
                'num_classes': self.num_classes}, f)
        np.savez(os.path.join(self.save_path, 'split_{}.npz'.format(self.hash)),
                 train_idx=F.zerocopy_to_numpy(self.train_idx),
                 val_idx=F.zerocopy_to_numpy(self.val_idx),
                 test_idx=F.zerocopy_to_numpy(self.test_idx))

    def __getitem__(self, idx):
        # Delegate directly to the wrapped dataset: returns (graph, label).
        return self.dataset[idx]

    def __len__(self):
        return len(self.dataset)

    @property
    def node_feat_size(self):
        # Size of the 'feat' ndata of the first graph; None if absent.
        g = self[0][0]
        return g.ndata['feat'].shape[-1] if 'feat' in g.ndata else None

    @property
    def edge_feat_size(self):
        # Size of the 'feat' edata of the first graph; None if absent.
        g = self[0][0]
        return g.edata['feat'].shape[-1] if 'feat' in g.edata else None
import os
import numpy as np
from .dgl_dataset import DGLDataset
from .utils import save_graphs, load_graphs
from .utils import save_graphs, load_graphs, Subset
from .. import backend as F
from ..base import DGLError
......@@ -120,6 +121,8 @@ class CSVDataset(DGLDataset):
# construct graphs
self.graphs, self.data = DGLGraphConstructor.construct_graphs(
node_data, edge_data, graph_data)
if len(self.data) == 1:
self.labels = list(self.data.values())[0]
def has_cache(self):
graph_path = os.path.join(self.save_path,
......@@ -141,14 +144,21 @@ class CSVDataset(DGLDataset):
graph_path = os.path.join(self.save_path,
self.name + '.bin')
self.graphs, self.data = load_graphs(graph_path)
if len(self.data) == 1:
self.labels = list(self.data.values())[0]
def __getitem__(self, i):
if F.is_tensor(i) and F.ndim(i) == 1:
return Subset(self, F.copy_to(i, F.cpu()))
if self._transform is None:
g = self.graphs[i]
else:
g = self._transform(self.graphs[i])
if len(self.data) > 0:
if len(self.data) == 1:
return g, self.labels[i]
elif len(self.data) > 0:
data = {k: v[i] for (k, v) in self.data.items()}
return g, data
else:
......
......@@ -239,7 +239,7 @@ class GraphData(BaseData):
{('_V', '_E', '_V'): ([], [])})
for graph_id in graph_ids:
graphs.append(graphs_dict[graph_id])
data = {k: _tensor(v) for k, v in graph_data.data.items()}
data = {k: F.reshape(_tensor(v), (len(graphs), -1)) for k, v in graph_data.data.items()}
return graphs, data
......
......@@ -49,6 +49,11 @@ class GINDataset(DGLBuiltinDataset):
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
num_classes : int
Number of classes for multiclass classification
Examples
--------
>>> data = GINDataset(name='MUTAG', self_loop=False)
......@@ -362,3 +367,7 @@ class GINDataset(DGLBuiltinDataset):
if os.path.exists(graph_path) and os.path.exists(info_path):
return True
return False
@property
def num_classes(self):
    """Number of graph classes for multiclass classification (alias of ``gclasses``)."""
    return self.gclasses
......@@ -43,8 +43,10 @@ class QM7bDataset(DGLDataset):
Attributes
----------
num_tasks : int
Number of prediction tasks
num_labels : int
Number of labels for each graph, i.e. number of prediction tasks
(DEPRECATED, use num_tasks instead) Number of prediction tasks
Raises
------
......@@ -54,7 +56,7 @@ class QM7bDataset(DGLDataset):
Examples
--------
>>> data = QM7bDataset()
>>> data.num_labels
>>> data.num_tasks
14
>>>
>>> # iterate over the dataset
......@@ -117,9 +119,14 @@ class QM7bDataset(DGLDataset):
'The repo may be outdated or download may be incomplete. '
'Otherwise you can create an issue for it.'.format(self.name))
@property
def num_tasks(self):
"""Number of prediction tasks."""
return self.num_labels
@property
def num_labels(self):
"""Number of labels for each graph, i.e. number of prediction tasks."""
"""Number of prediction tasks."""
return 14
def __getitem__(self, idx):
......
......@@ -79,8 +79,10 @@ class QM9Dataset(DGLDataset):
Attributes
----------
num_tasks : int
Number of prediction tasks
num_labels : int
Number of labels for each graph, i.e. number of prediction tasks
(DEPRECATED, use num_tasks instead) Number of prediction tasks
Raises
------
......@@ -90,7 +92,7 @@ class QM9Dataset(DGLDataset):
Examples
--------
>>> data = QM9Dataset(label_keys=['mu', 'gap'], cutoff=5.0)
>>> data.num_labels
>>> data.num_tasks
2
>>>
>>> # iterate over the dataset
......@@ -143,7 +145,17 @@ class QM9Dataset(DGLDataset):
Returns
--------
int
Number of labels for each graph, i.e. number of prediction tasks.
Number of prediction tasks.
"""
return self.label.shape[1]
@property
def num_tasks(self):
    r"""Number of prediction tasks.

    Returns
    -------
    int
        Number of prediction tasks, one per selected label key.
    """
    return self.label.shape[1]
......
......@@ -105,8 +105,10 @@ class QM9EdgeDataset(DGLDataset):
Attributes
----------
num_tasks : int
Number of prediction tasks
num_labels : int
Number of labels for each graph, i.e. number of prediction tasks
(DEPRECATED, use num_tasks instead) Number of prediction tasks
Raises
------
......@@ -116,7 +118,7 @@ class QM9EdgeDataset(DGLDataset):
Examples
--------
>>> data = QM9EdgeDataset(label_keys=['mu', 'alpha'])
>>> data.num_labels
>>> data.num_tasks
2
>>> # iterate over the dataset
......@@ -245,5 +247,15 @@ class QM9EdgeDataset(DGLDataset):
"""
return self.n_node.shape[0]
@property
def num_tasks(self):
    r"""Number of prediction tasks.

    Returns
    -------
    int
        Number of prediction tasks (alias of ``num_labels``).
    """
    return self.num_labels
QM9Edge = QM9EdgeDataset
\ No newline at end of file
......@@ -692,7 +692,7 @@ class BA2MotifDataset(DGLBuiltinDataset):
Attributes
----------
num_classes : int
Number of node classes
Number of graph classes
Examples
--------
......
......@@ -35,8 +35,10 @@ class LegacyTUDataset(DGLBuiltinDataset):
----------
max_num_node : int
Maximum number of nodes
num_labels : int
num_classes : int
Number of classes
num_labels : numpy.int64
(DEPRECATED, use num_classes instead) Number of classes
Notes
-----
......@@ -244,6 +246,10 @@ class LegacyTUDataset(DGLBuiltinDataset):
self.num_labels,\
self.max_num_node
@property
def num_classes(self):
    """Number of classes as a Python ``int`` (``num_labels`` is a ``numpy.int64``)."""
    return int(self.num_labels)
class TUDataset(DGLBuiltinDataset):
r"""
TUDataset contains lots of graph kernel datasets for graph classification.
......@@ -262,8 +268,10 @@ class TUDataset(DGLBuiltinDataset):
----------
max_num_node : int
Maximum number of nodes
num_labels : int
num_classes : int
Number of classes
num_labels : int
(DEPRECATED, use num_classes instead) Number of classes
Notes
-----
......@@ -329,7 +337,7 @@ class TUDataset(DGLBuiltinDataset):
if os.path.exists(self._file_path("graph_labels")):
DS_graph_labels = self._idx_reset(
loadtxt(self._file_path("graph_labels"), delimiter=",").astype(int))
self.num_labels = max(DS_graph_labels) + 1
self.num_labels = int(max(DS_graph_labels) + 1)
self.graph_labels = F.tensor(DS_graph_labels)
elif os.path.exists(self._file_path("graph_attributes")):
DS_graph_labels = loadtxt(self._file_path("graph_attributes"), delimiter=",").astype(float)
......@@ -442,3 +450,7 @@ class TUDataset(DGLBuiltinDataset):
return self.graph_lists[0].ndata['feat'].shape[1], \
self.num_labels, \
self.max_num_node
@property
def num_classes(self):
    """Number of graph classes (alias of ``num_labels``)."""
    return self.num_labels
......@@ -40,6 +40,7 @@ def test_gin():
ds = data.GINDataset(name, self_loop=False, degree_as_nlabel=False, transform=transform)
g2 = ds[0][0]
assert g2.num_edges() - g1.num_edges() == g1.num_nodes()
assert ds.num_classes == ds.gclasses
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_fraud():
......@@ -83,6 +84,7 @@ def test_fakenews():
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_tudataset_regression():
ds = data.TUDataset('ZINC_test', force_reload=True)
assert ds.num_classes == ds.num_labels
assert len(ds) == 5000
g = ds[0][0]
......@@ -494,7 +496,7 @@ def _test_construct_graphs_multiple():
assert len(data_dict) == len(gdata)
for k, v in data_dict.items():
assert F.dtype(v) != F.float64
assert F.array_equal(F.tensor(gdata[k], dtype=F.dtype(v)), v)
assert F.array_equal(F.reshape(F.tensor(gdata[k], dtype=F.dtype(v)), (len(graphs), -1)), v)
for i, g in enumerate(graphs):
assert g.is_homogeneous
assert g.num_nodes() == num_nodes
......@@ -1102,7 +1104,7 @@ def _test_CSVDataset_customized_data_parser():
assert 'label' in csv_dataset.data
for i, (g, g_data) in enumerate(csv_dataset):
assert not g.is_homogeneous
assert F.asnumpy(g_data['label']) == label_gdata[i] + 2
assert F.asnumpy(g_data) == label_gdata[i] + 2
for ntype in g.ntypes:
assert g.num_nodes(ntype) == num_nodes
offset = 2 if ntype == 'user' else 0
......@@ -1122,7 +1124,7 @@ def _test_CSVDataset_customized_data_parser():
assert 'label' in csv_dataset.data
for i, (g, g_data) in enumerate(csv_dataset):
assert not g.is_homogeneous
assert F.asnumpy(g_data['label']) == label_gdata[i] + 2
assert F.asnumpy(g_data) == label_gdata[i] + 2
for ntype in g.ntypes:
assert g.num_nodes(ntype) == num_nodes
offset = 2
......@@ -1384,6 +1386,141 @@ def test_as_nodepred_csvdataset():
assert 'label' in new_ds[0].ndata
assert 'train_mask' in new_ds[0].ndata
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_as_graphpred():
    """Check AsGraphPredDataset metadata (length, num_tasks, num_classes)
    across a representative set of built-in graph-level datasets."""
    # Each case: (dataset factory, split_ratio, expected length,
    #             expected num_tasks, expected num_classes — None for regression).
    cases = [
        (lambda: data.GINDataset(name='MUTAG', self_loop=True), [0.8, 0.1, 0.1], 188, 1, 2),
        (lambda: data.FakeNewsDataset('politifact', 'profile'), None, 314, 1, 2),
        (lambda: data.QM7bDataset(), [0.8, 0.1, 0.1], 7211, 14, None),
        (lambda: data.QM9Dataset(label_keys=['mu', 'gap']), [0.8, 0.1, 0.1], 130831, 2, None),
        (lambda: data.QM9EdgeDataset(label_keys=['mu', 'alpha']), [0.8, 0.1, 0.1], 130831, 2, None),
        (lambda: data.TUDataset('DD'), [0.8, 0.1, 0.1], 1178, 1, 2),
        (lambda: data.LegacyTUDataset('DD'), [0.8, 0.1, 0.1], 1178, 1, 2),
        (lambda: data.BA2MotifDataset(), [0.8, 0.1, 0.1], 1000, 1, 2),
    ]
    for build, split_ratio, num_graphs, num_tasks, num_classes in cases:
        # Passing split_ratio=None positionally matches the default.
        new_ds = data.AsGraphPredDataset(build(), split_ratio, verbose=True)
        assert len(new_ds) == num_graphs
        assert new_ds.num_tasks == num_tasks
        if num_classes is None:
            assert new_ds.num_classes is None
        else:
            assert new_ds.num_classes == num_classes
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_as_graphpred_reprocess():
    """Check that AsGraphPredDataset caches a generated split and
    re-processes when a different split ratio is requested."""
    def check_split_cache(build):
        # fresh split
        ds = data.AsGraphPredDataset(build(), [0.8, 0.1, 0.1])
        assert len(ds.train_idx) == int(len(ds) * 0.8)
        # read from cache
        ds = data.AsGraphPredDataset(build(), [0.8, 0.1, 0.1])
        assert len(ds.train_idx) == int(len(ds) * 0.8)
        # invalid cache, re-read
        ds = data.AsGraphPredDataset(build(), [0.1, 0.1, 0.8])
        assert len(ds.train_idx) == int(len(ds) * 0.1)

    check_split_cache(lambda: data.GINDataset(name='MUTAG', self_loop=True))
    check_split_cache(lambda: data.FakeNewsDataset('politifact', 'profile'))
    check_split_cache(lambda: data.QM7bDataset())
    check_split_cache(lambda: data.QM9Dataset(label_keys=['mu', 'gap']))
    check_split_cache(lambda: data.QM9EdgeDataset(label_keys=['mu', 'alpha']))
    check_split_cache(lambda: data.TUDataset('DD'))
    check_split_cache(lambda: data.LegacyTUDataset('DD'))
    check_split_cache(lambda: data.BA2MotifDataset())
@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason="ogb only supports pytorch")
def test_as_graphpred_ogb():
    """Check AsGraphPredDataset against an OGB graph property dataset."""
    from ogb.graphproppred import DglGraphPropPredDataset
    # With split_ratio=None the adapter reuses the split shipped with OGB.
    adapted = data.AsGraphPredDataset(DglGraphPropPredDataset('ogbg-molhiv'),
                                      split_ratio=None, verbose=True)
    assert len(adapted.train_idx) == 32901
    # force generate new split
    adapted = data.AsGraphPredDataset(DglGraphPropPredDataset('ogbg-molhiv'),
                                      split_ratio=[0.6, 0.2, 0.2], verbose=True)
    assert len(adapted.train_idx) == 24676
if __name__ == '__main__':
test_minigc()
test_gin()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment