Unverified Commit 8b8fd2c0 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Dataset] Add transform argument to built-in datasets (#3733)

* Update

* Fix

* Update
parent b3d3a2c4
......@@ -39,6 +39,10 @@ class BitcoinOTCDataset(DGLBuiltinDataset):
verbose: bool
Whether to print out progress information.
Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
......@@ -67,12 +71,13 @@ class BitcoinOTCDataset(DGLBuiltinDataset):
_url = 'https://snap.stanford.edu/data/soc-sign-bitcoinotc.csv.gz'
_sha1_str = 'c14281f9e252de0bd0b5f1c6e2bae03123938641'
def __init__(self, raw_dir=None, force_reload=False, verbose=False):
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
super(BitcoinOTCDataset, self).__init__(name='bitcoinotc',
url=self._url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def download(self):
gz_file_path = os.path.join(self.raw_dir, self.name + '.csv.gz')
......@@ -143,7 +148,10 @@ class BitcoinOTCDataset(DGLBuiltinDataset):
- ``edata['h']`` : edge weights
"""
return self.graphs[item]
if self._transform is None:
return self.graphs[item]
else:
return self._transform(self.graphs[item])
@property
def is_temporal(self):
......
......@@ -43,10 +43,14 @@ class CitationGraphDataset(DGLBuiltinDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
reverse_edge: bool
reverse_edge : bool
Whether to add reverse edges in graph. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
"""
_urls = {
'cora_v2' : 'dataset/cora_v2.zip',
......@@ -54,7 +58,8 @@ class CitationGraphDataset(DGLBuiltinDataset):
'pubmed' : 'dataset/pubmed.zip',
}
def __init__(self, name, raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
def __init__(self, name, raw_dir=None, force_reload=False,
verbose=True, reverse_edge=True, transform=None):
assert name.lower() in ['cora', 'citeseer', 'pubmed']
# Previously we use the pre-processing in pygcn (https://github.com/tkipf/pygcn)
......@@ -69,7 +74,8 @@ class CitationGraphDataset(DGLBuiltinDataset):
url=url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def process(self):
"""Loads input data from data directory and reorder graph for better locality
......@@ -213,7 +219,10 @@ class CitationGraphDataset(DGLBuiltinDataset):
def __getitem__(self, idx):
assert idx == 0, "This dataset has only one graph"
return self._g
if self._transform is None:
return self._g
else:
return self._transform(self._g)
def __len__(self):
return 1
......@@ -267,7 +276,7 @@ class CitationGraphDataset(DGLBuiltinDataset):
@property
def reverse_edge(self):
return self._reverse_edge
def _preprocess_features(features):
"""Row-normalize feature matrix and convert to tuple representation"""
......@@ -356,10 +365,14 @@ class CoraGraphDataset(CitationGraphDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
reverse_edge: bool
reverse_edge : bool
Whether to add reverse edges in graph. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
......@@ -400,10 +413,12 @@ class CoraGraphDataset(CitationGraphDataset):
>>> label = g.ndata['label']
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
def __init__(self, raw_dir=None, force_reload=False, verbose=True,
reverse_edge=True, transform=None):
name = 'cora'
super(CoraGraphDataset, self).__init__(name, raw_dir, force_reload, verbose, reverse_edge)
super(CoraGraphDataset, self).__init__(name, raw_dir, force_reload,
verbose, reverse_edge, transform)
def __getitem__(self, idx):
r"""Gets the graph object
......@@ -496,10 +511,14 @@ class CiteseerGraphDataset(CitationGraphDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
reverse_edge: bool
reverse_edge : bool
Whether to add reverse edges in graph. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
......@@ -543,10 +562,12 @@ class CiteseerGraphDataset(CitationGraphDataset):
>>> label = g.ndata['label']
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
def __init__(self, raw_dir=None, force_reload=False,
verbose=True, reverse_edge=True, transform=None):
name = 'citeseer'
super(CiteseerGraphDataset, self).__init__(name, raw_dir, force_reload, verbose, reverse_edge)
super(CiteseerGraphDataset, self).__init__(name, raw_dir, force_reload,
verbose, reverse_edge, transform)
def __getitem__(self, idx):
r"""Gets the graph object
......@@ -639,10 +660,14 @@ class PubmedGraphDataset(CitationGraphDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
reverse_edge: bool
reverse_edge : bool
Whether to add reverse edges in graph. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
......@@ -683,10 +708,12 @@ class PubmedGraphDataset(CitationGraphDataset):
>>> label = g.ndata['label']
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
def __init__(self, raw_dir=None, force_reload=False, verbose=True,
reverse_edge=True, transform=None):
name = 'pubmed'
super(PubmedGraphDataset, self).__init__(name, raw_dir, force_reload, verbose, reverse_edge)
super(PubmedGraphDataset, self).__init__(name, raw_dir, force_reload,
verbose, reverse_edge, transform)
def __getitem__(self, idx):
r"""Gets the graph object
......@@ -714,7 +741,7 @@ class PubmedGraphDataset(CitationGraphDataset):
r"""The number of graphs in the dataset."""
return super(PubmedGraphDataset, self).__len__()
def load_cora(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
def load_cora(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True, transform=None):
"""Get CoraGraphDataset
Parameters
......@@ -724,19 +751,24 @@ def load_cora(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True)
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
reverse_edge: bool
verbose : bool
Whether to print out progress information. Default: True.
reverse_edge : bool
Whether to add reverse edges in graph. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Return
-------
CoraGraphDataset
"""
data = CoraGraphDataset(raw_dir, force_reload, verbose, reverse_edge)
data = CoraGraphDataset(raw_dir, force_reload, verbose, reverse_edge, transform)
return data
def load_citeseer(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
def load_citeseer(raw_dir=None, force_reload=False, verbose=True,
reverse_edge=True, transform=None):
"""Get CiteseerGraphDataset
Parameters
......@@ -746,38 +778,47 @@ def load_citeseer(raw_dir=None, force_reload=False, verbose=True, reverse_edge=T
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
reverse_edge: bool
verbose : bool
Whether to print out progress information. Default: True.
reverse_edge : bool
Whether to add reverse edges in graph. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Return
-------
CiteseerGraphDataset
"""
data = CiteseerGraphDataset(raw_dir, force_reload, verbose, reverse_edge)
data = CiteseerGraphDataset(raw_dir, force_reload, verbose, reverse_edge, transform)
return data
def load_pubmed(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
def load_pubmed(raw_dir=None, force_reload=False, verbose=True,
reverse_edge=True, transform=None):
"""Get PubmedGraphDataset
Parameters
-----------
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose : bool
Whether to print out progress information. Default: True.
reverse_edge: bool
reverse_edge : bool
Whether to add reverse edges in graph. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Return
-------
PubmedGraphDataset
"""
data = PubmedGraphDataset(raw_dir, force_reload, verbose, reverse_edge)
data = PubmedGraphDataset(raw_dir, force_reload, verbose, reverse_edge, transform)
return data
class CoraBinary(DGLBuiltinDataset):
......@@ -798,15 +839,20 @@ class CoraBinary(DGLBuiltinDataset):
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=True):
def __init__(self, raw_dir=None, force_reload=False, verbose=True, transform=None):
name = 'cora_binary'
url = _get_dgl_url('dataset/cora_binary.zip')
super(CoraBinary, self).__init__(name,
url=url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def process(self):
root = self.raw_path
......@@ -894,7 +940,11 @@ class CoraBinary(DGLBuiltinDataset):
(dgl.DGLGraph, scipy.sparse.coo_matrix, int)
The graph, scipy sparse coo_matrix and its label.
"""
return (self.graphs[i], self.pmpds[i], self.labels[i])
if self._transform is None:
g = self.graphs[i]
else:
g = self._transform(self.graphs[i])
return (g, self.pmpds[i], self.labels[i])
@property
def save_name(self):
......
......@@ -33,6 +33,10 @@ class DGLCSVDataset(DGLDataset):
A callable object which is used to parse corresponding column graph
data. Default: None. If None, a default data parser is applied
which load data directly and tries to convert list into array.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
......@@ -46,7 +50,8 @@ class DGLCSVDataset(DGLDataset):
"""
META_YAML_NAME = 'meta.yaml'
def __init__(self, data_path, force_reload=False, verbose=True, node_data_parser=None, edge_data_parser=None, graph_data_parser=None):
def __init__(self, data_path, force_reload=False, verbose=True, node_data_parser=None,
edge_data_parser=None, graph_data_parser=None, transform=None):
from .csv_dataset_base import load_yaml_with_sanity_check, DefaultDataParser
self.graphs = None
self.data = None
......@@ -61,7 +66,7 @@ class DGLCSVDataset(DGLDataset):
self.meta_yaml = load_yaml_with_sanity_check(meta_yaml_path)
ds_name = self.meta_yaml.dataset_name
super().__init__(ds_name, raw_dir=os.path.dirname(
meta_yaml_path), force_reload=force_reload, verbose=verbose)
meta_yaml_path), force_reload=force_reload, verbose=verbose, transform=transform)
def process(self):
......@@ -122,10 +127,15 @@ class DGLCSVDataset(DGLDataset):
self.graphs, self.data = load_graphs(graph_path)
def __getitem__(self, i):
if self._transform is None:
g = self.graphs[i]
else:
g = self._transform(self.graphs[i])
if 'label' in self.data:
return self.graphs[i], self.data['label'][i]
return g, self.data['label'][i]
else:
return self.graphs[i]
return g
def __len__(self):
return len(self.graphs)
......@@ -49,6 +49,10 @@ class DGLDataset(object):
Whether to reload the dataset. Default: False
verbose : bool
Whether to print out progress information
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
......@@ -72,13 +76,14 @@ class DGLDataset(object):
Hash value for the dataset and the setting.
"""
def __init__(self, name, url=None, raw_dir=None, save_dir=None,
hash_key=(), force_reload=False, verbose=False):
hash_key=(), force_reload=False, verbose=False, transform=None):
self._name = name
self._url = url
self._force_reload = force_reload
self._verbose = verbose
self._hash_key = hash_key
self._hash = self._get_hash()
self._transform = transform
# if no dir is provided, the default dgl download dir is used.
if raw_dir is None:
......@@ -142,7 +147,7 @@ class DGLDataset(object):
def _download(self):
"""Download dataset by calling ``self.download()``
if the dataset does not exists under ``self.raw_path``.
By default ``self.raw_path = os.path.join(self.raw_dir, self.name)``
One can overwrite ``raw_path()`` function to change the path.
"""
......@@ -161,7 +166,7 @@ class DGLDataset(object):
- If loadin process fails, re-download and process the dataset.
else:
- Download the dataset if needed.
- Process the dataset and build the dgl graph.
- Save the processed dataset into files.
......@@ -287,17 +292,23 @@ class DGLBuiltinDataset(DGLDataset):
from the same dataset class by comparing the hash values.
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: False
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
"""
def __init__(self, name, url, raw_dir=None, hash_key=(), force_reload=False, verbose=False):
def __init__(self, name, url, raw_dir=None, hash_key=(),
force_reload=False, verbose=False, transform=None):
super(DGLBuiltinDataset, self).__init__(name,
url=url,
raw_dir=raw_dir,
save_dir=None,
hash_key=hash_key,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def download(self):
r""" Automatically download data and extract it.
......
......@@ -76,6 +76,10 @@ class FakeNewsDataset(DGLBuiltinDataset):
downloaded data or the directory that
already stores the input data.
Default: ~/.dgl/
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
......@@ -113,7 +117,7 @@ class FakeNewsDataset(DGLBuiltinDataset):
'politifact': 'dataset/FakeNewsPOL.zip'
}
def __init__(self, name, feature_name, raw_dir=None):
def __init__(self, name, feature_name, raw_dir=None, transform=None):
assert name in ['gossipcop', 'politifact'], \
"Only supports 'gossipcop' or 'politifact'."
url = _get_dgl_url(self.file_urls[name])
......@@ -123,7 +127,8 @@ class FakeNewsDataset(DGLBuiltinDataset):
self.feature_name = feature_name
super(FakeNewsDataset, self).__init__(name=name,
url=url,
raw_dir=raw_dir)
raw_dir=raw_dir,
transform=transform)
def process(self):
"""process raw data to graph, labels and masks"""
......@@ -213,7 +218,11 @@ class FakeNewsDataset(DGLBuiltinDataset):
-------
(:class:`dgl.DGLGraph`, Tensor)
"""
return self.graphs[i], self.labels[i]
if self._transform is None:
g = self.graphs[i]
else:
g = self._transform(self.graphs[i])
return g, self.labels[i]
def __len__(self):
r"""Number of graphs in the dataset.
......
......@@ -48,8 +48,12 @@ class FraudDataset(DGLBuiltinDataset):
Default: 0.1
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
......@@ -88,9 +92,9 @@ class FraudDataset(DGLBuiltinDataset):
'yelp': 'review',
'amazon': 'user'
}
def __init__(self, name, raw_dir=None, random_seed=717, train_size=0.7,
val_size=0.1, force_reload=False, verbose=True):
val_size=0.1, force_reload=False, verbose=True, transform=None):
assert name in ['yelp', 'amazon'], "only supports 'yelp', or 'amazon'"
url = _get_dgl_url(self.file_urls[name])
self.seed = random_seed
......@@ -101,30 +105,31 @@ class FraudDataset(DGLBuiltinDataset):
raw_dir=raw_dir,
hash_key=(random_seed, train_size, val_size),
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def process(self):
"""process raw data to graph, labels, splitting masks"""
file_path = os.path.join(self.raw_path, self.file_names[self.name])
data = io.loadmat(file_path)
node_features = data['features'].todense()
# remove additional dimension of length 1 in raw .mat file
node_labels = data['label'].squeeze()
graph_data = {}
for relation in self.relations[self.name]:
adj = data[relation].tocoo()
row, col = adj.row, adj.col
graph_data[(self.node_name[self.name], relation, self.node_name[self.name])] = (row, col)
g = heterograph(graph_data)
g.ndata['feature'] = F.tensor(node_features, dtype=F.data_type_dict['float32'])
g.ndata['label'] = F.tensor(node_labels, dtype=F.data_type_dict['int64'])
self.graph = g
self._random_split(g.ndata['feature'], self.seed, self.train_size, self.val_size)
def __getitem__(self, idx):
r""" Get graph object
......@@ -145,12 +150,15 @@ class FraudDataset(DGLBuiltinDataset):
- ``ndata['test_mask']``: mask of testing set
"""
assert idx == 0, "This dataset has only one graph"
return self.graph
if self._transform is None:
return self.graph
else:
return self._transform(self.graph)
def __len__(self):
"""number of data examples"""
return len(self.graph)
@property
def num_classes(self):
"""Number of classes.
......@@ -160,37 +168,37 @@ class FraudDataset(DGLBuiltinDataset):
int
"""
return 2
def save(self):
"""save processed data to directory `self.save_path`"""
graph_path = os.path.join(self.save_path, self.name + '_dgl_graph_{}.bin'.format(self.hash))
save_graphs(str(graph_path), self.graph)
def load(self):
"""load processed data from directory `self.save_path`"""
graph_path = os.path.join(self.save_path, self.name + '_dgl_graph_{}.bin'.format(self.hash))
graph_list, _ = load_graphs(str(graph_path))
g = graph_list[0]
self.graph = g
def has_cache(self):
"""check whether there are processed data in `self.save_path`"""
graph_path = os.path.join(self.save_path, self.name + '_dgl_graph_{}.bin'.format(self.hash))
return os.path.exists(graph_path)
def _random_split(self, x, seed=717, train_size=0.7, val_size=0.1):
"""split the dataset into training set, validation set and testing set"""
assert 0 <= train_size + val_size <= 1, \
"The sum of valid training set size and validation set size " \
"must between 0 and 1 (inclusive)."
N = x.shape[0]
index = np.arange(N)
if self.name == 'amazon':
# 0-3304 are unlabeled nodes
index = np.arange(3305, N)
index = np.random.RandomState(seed).permutation(index)
train_idx = index[:int(train_size * len(index))]
val_idx = index[len(index) - int(val_size * len(index)):]
......@@ -254,8 +262,12 @@ class FraudYelpDataset(FraudDataset):
Default: 0.1
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Examples
--------
......@@ -265,16 +277,17 @@ class FraudYelpDataset(FraudDataset):
>>> feat = graph.ndata['feature']
>>> label = graph.ndata['label']
"""
def __init__(self, raw_dir=None, random_seed=717, train_size=0.7,
val_size=0.1, force_reload=False, verbose=True):
val_size=0.1, force_reload=False, verbose=True, transform=None):
super(FraudYelpDataset, self).__init__(name='yelp',
raw_dir=raw_dir,
random_seed=random_seed,
train_size=train_size,
val_size=val_size,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
class FraudAmazonDataset(FraudDataset):
......@@ -330,8 +343,12 @@ class FraudAmazonDataset(FraudDataset):
Default: 0.1
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Examples
--------
......@@ -341,13 +358,14 @@ class FraudAmazonDataset(FraudDataset):
>>> feat = graph.ndata['feature']
>>> label = graph.ndata['label']
"""
def __init__(self, raw_dir=None, random_seed=717, train_size=0.7,
val_size=0.1, force_reload=False, verbose=True):
val_size=0.1, force_reload=False, verbose=True, transform=None):
super(FraudAmazonDataset, self).__init__(name='amazon',
raw_dir=raw_dir,
random_seed=random_seed,
train_size=train_size,
val_size=val_size,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
......@@ -37,8 +37,12 @@ class GDELTDataset(DGLBuiltinDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
......@@ -65,7 +69,8 @@ class GDELTDataset(DGLBuiltinDataset):
....
>>>
"""
def __init__(self, mode='train', raw_dir=None, force_reload=False, verbose=False):
def __init__(self, mode='train', raw_dir=None,
force_reload=False, verbose=False, transform=None):
mode = mode.lower()
assert mode in ['train', 'valid', 'test'], "Mode not valid."
self.mode = mode
......@@ -75,7 +80,8 @@ class GDELTDataset(DGLBuiltinDataset):
url=_url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def process(self):
file_path = os.path.join(self.raw_path, self.mode + '.txt')
......@@ -148,6 +154,8 @@ class GDELTDataset(DGLBuiltinDataset):
rate = self.data[row_mask][:, 1]
g = dgl_graph((edges[:, 0], edges[:, 1]))
g.edata['rel_type'] = F.tensor(rate.reshape(-1, 1), dtype=F.data_type_dict['int64'])
if self._transform is not None:
g = self._transform(g)
return g
def __len__(self):
......
......@@ -18,9 +18,9 @@ from ..convert import graph as dgl_graph
class GINDataset(DGLBuiltinDataset):
"""Dataset Class for `How Powerful Are Graph Neural Networks? <https://arxiv.org/abs/1810.00826>`_.
This is adapted from `<https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip>`_.
The class provides an interface for nine datasets used in the paper along with the paper-specific
settings. The datasets are ``'MUTAG'``, ``'COLLAB'``, ``'IMDBBINARY'``, ``'IMDBMULTI'``,
``'NCI1'``, ``'PROTEINS'``, ``'PTC'``, ``'REDDITBINARY'``, ``'REDDITMULTI5K'``.
......@@ -44,6 +44,10 @@ class GINDataset(DGLBuiltinDataset):
add self to self edge if true
degree_as_nlabel: bool
take node degree as label and feature if true
transform: callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Examples
--------
......@@ -73,7 +77,7 @@ class GINDataset(DGLBuiltinDataset):
"""
def __init__(self, name, self_loop, degree_as_nlabel=False,
raw_dir=None, force_reload=False, verbose=False):
raw_dir=None, force_reload=False, verbose=False, transform=None):
self._name = name # MUTAG
gin_url = 'https://raw.githubusercontent.com/weihua916/powerful-gnns/master/dataset.zip'
......@@ -106,7 +110,8 @@ class GINDataset(DGLBuiltinDataset):
self.nlabels_flag = False
super(GINDataset, self).__init__(name=name, url=gin_url, hash_key=(name, self_loop, degree_as_nlabel),
raw_dir=raw_dir, force_reload=force_reload, verbose=verbose)
raw_dir=raw_dir, force_reload=force_reload,
verbose=verbose, transform=transform)
@property
def raw_path(self):
......@@ -136,7 +141,11 @@ class GINDataset(DGLBuiltinDataset):
(:class:`dgl.Graph`, Tensor)
The graph and its label.
"""
return self.graphs[idx], self.labels[idx]
if self._transform is None:
g = self.graphs[idx]
else:
g = self._transform(self.graphs[idx])
return g, self.labels[idx]
def _file_path(self):
return os.path.join(self.raw_dir, "GINDataset", 'dataset', self.name, "{}.txt".format(self.name))
......
......@@ -27,13 +27,14 @@ class GNNBenchmarkDataset(DGLBuiltinDataset):
Reference: https://github.com/shchur/gnn-benchmark#datasets
"""
def __init__(self, name, raw_dir=None, force_reload=False, verbose=False):
def __init__(self, name, raw_dir=None, force_reload=False, verbose=False, transform=None):
_url = _get_dgl_url('dataset/' + name + '.zip')
super(GNNBenchmarkDataset, self).__init__(name=name,
url=_url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def process(self):
npz_path = os.path.join(self.raw_path, self.name + '.npz')
......@@ -128,7 +129,10 @@ class GNNBenchmarkDataset(DGLBuiltinDataset):
- ``ndata['label']``: node labels
"""
assert idx == 0, "This dataset has only one graph"
return self._graph
if self._transform is None:
return self._graph
else:
return self._transform(self._graph)
def __len__(self):
r"""Number of graphs in the dataset"""
......@@ -164,8 +168,12 @@ class CoraFullDataset(GNNBenchmarkDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
......@@ -182,11 +190,12 @@ class CoraFullDataset(GNNBenchmarkDataset):
>>> feat = g.ndata['feat'] # get node feature
>>> label = g.ndata['label'] # get node labels
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=False):
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
super(CoraFullDataset, self).__init__(name="cora_full",
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
@property
def num_classes(self):
......@@ -231,8 +240,12 @@ class CoauthorCSDataset(GNNBenchmarkDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
......@@ -249,11 +262,12 @@ class CoauthorCSDataset(GNNBenchmarkDataset):
>>> feat = g.ndata['feat'] # get node feature
>>> label = g.ndata['label'] # get node labels
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=False):
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
super(CoauthorCSDataset, self).__init__(name='coauthor_cs',
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
@property
def num_classes(self):
......@@ -298,8 +312,12 @@ class CoauthorPhysicsDataset(GNNBenchmarkDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
......@@ -316,11 +334,12 @@ class CoauthorPhysicsDataset(GNNBenchmarkDataset):
>>> feat = g.ndata['feat'] # get node feature
>>> label = g.ndata['label'] # get node labels
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=False):
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
super(CoauthorPhysicsDataset, self).__init__(name='coauthor_physics',
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
@property
def num_classes(self):
......@@ -364,8 +383,12 @@ class AmazonCoBuyComputerDataset(GNNBenchmarkDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
......@@ -382,11 +405,12 @@ class AmazonCoBuyComputerDataset(GNNBenchmarkDataset):
>>> feat = g.ndata['feat'] # get node feature
>>> label = g.ndata['label'] # get node labels
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=False):
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
super(AmazonCoBuyComputerDataset, self).__init__(name='amazon_co_buy_computer',
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
@property
def num_classes(self):
......@@ -430,8 +454,12 @@ class AmazonCoBuyPhotoDataset(GNNBenchmarkDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
......@@ -448,11 +476,12 @@ class AmazonCoBuyPhotoDataset(GNNBenchmarkDataset):
>>> feat = g.ndata['feat'] # get node feature
>>> label = g.ndata['label'] # get node labels
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=False):
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
super(AmazonCoBuyPhotoDataset, self).__init__(name='amazon_co_buy_photo',
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
@property
def num_classes(self):
......
......@@ -39,8 +39,12 @@ class ICEWS18Dataset(DGLBuiltinDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
-------
......@@ -61,7 +65,7 @@ class ICEWS18Dataset(DGLBuiltinDataset):
....
>>>
"""
def __init__(self, mode='train', raw_dir=None, force_reload=False, verbose=False):
def __init__(self, mode='train', raw_dir=None, force_reload=False, verbose=False, transform=None):
mode = mode.lower()
assert mode in ['train', 'valid', 'test'], "Mode not valid"
self.mode = mode
......@@ -70,7 +74,8 @@ class ICEWS18Dataset(DGLBuiltinDataset):
url=_url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def process(self):
data = loadtxt(os.path.join(self.save_path, '{}.txt'.format(self.mode)),
......@@ -118,7 +123,10 @@ class ICEWS18Dataset(DGLBuiltinDataset):
- ``edata['rel_type']``: edge type
"""
return self._graphs[idx]
if self._transform is None:
return self._graphs[idx]
else:
return self._transform(self._graphs[idx])
def __len__(self):
r"""Number of graphs in the dataset.
......
......@@ -34,6 +34,13 @@ class KarateClubDataset(DGLDataset):
- Edges: 156
- Number of Classes: 2
Parameters
----------
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
num_classes : int
......@@ -48,8 +55,8 @@ class KarateClubDataset(DGLDataset):
>>> g = dataset[0]
>>> labels = g.ndata['label']
"""
def __init__(self):
super(KarateClubDataset, self).__init__(name='karate_club')
def __init__(self, transform=None):
super(KarateClubDataset, self).__init__(name='karate_club', transform=transform)
def process(self):
kc_graph = nx.karate_club_graph()
......@@ -88,7 +95,10 @@ class KarateClubDataset(DGLDataset):
- ``ndata['label']``: ground truth labels
"""
assert idx == 0, "This dataset has only one graph"
return self._graph
if self._transform is None:
return self._graph
else:
return self._transform(self._graph)
def __len__(self):
r"""The number of graphs in the dataset."""
......
......@@ -25,19 +25,24 @@ class KnowledgeGraphDataset(DGLBuiltinDataset):
Parameters
-----------
name: str
name : str
Name can be 'FB15k-237', 'FB15k' or 'wn18'.
reverse: bool
reverse : bool
Whether add reverse edges. Default: True.
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
"""
def __init__(self, name, reverse=True, raw_dir=None, force_reload=False, verbose=True):
def __init__(self, name, reverse=True, raw_dir=None, force_reload=False,
verbose=True, transform=None):
self._name = name
self.reverse = reverse
url = _get_dgl_url('dataset/') + '{}.tgz'.format(name)
......@@ -45,7 +50,8 @@ class KnowledgeGraphDataset(DGLBuiltinDataset):
url=url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def download(self):
r""" Automatically download data and extract it.
......@@ -112,7 +118,10 @@ class KnowledgeGraphDataset(DGLBuiltinDataset):
def __getitem__(self, idx):
assert idx == 0, "This dataset has only one graph"
return self._g
if self._transform is None:
return self._g
else:
return self._transform(self._g)
def __len__(self):
return 1
......@@ -389,8 +398,12 @@ class FB15k237Dataset(KnowledgeGraphDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
......@@ -433,9 +446,11 @@ class FB15k237Dataset(KnowledgeGraphDataset):
>>>
>>> # Train, Validation and Test
"""
def __init__(self, reverse=True, raw_dir=None, force_reload=False, verbose=True):
def __init__(self, reverse=True, raw_dir=None, force_reload=False,
verbose=True, transform=None):
name = 'FB15k-237'
super(FB15k237Dataset, self).__init__(name, reverse, raw_dir, force_reload, verbose)
super(FB15k237Dataset, self).__init__(name, reverse, raw_dir,
force_reload, verbose, transform)
def __getitem__(self, idx):
r"""Gets the graph object
......@@ -526,8 +541,12 @@ class FB15kDataset(KnowledgeGraphDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
......@@ -570,9 +589,11 @@ class FB15kDataset(KnowledgeGraphDataset):
>>> # Train, Validation and Test
>>>
"""
def __init__(self, reverse=True, raw_dir=None, force_reload=False, verbose=True):
def __init__(self, reverse=True, raw_dir=None, force_reload=False,
verbose=True, transform=None):
name = 'FB15k'
super(FB15kDataset, self).__init__(name, reverse, raw_dir, force_reload, verbose)
super(FB15kDataset, self).__init__(name, reverse, raw_dir,
force_reload, verbose, transform)
def __getitem__(self, idx):
r"""Gets the graph object
......@@ -662,8 +683,12 @@ class WN18Dataset(KnowledgeGraphDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
......@@ -706,9 +731,11 @@ class WN18Dataset(KnowledgeGraphDataset):
>>> # Train, Validation and Test
>>>
"""
def __init__(self, reverse=True, raw_dir=None, force_reload=False, verbose=True):
def __init__(self, reverse=True, raw_dir=None, force_reload=False,
verbose=True, transform=None):
name = 'wn18'
super(WN18Dataset, self).__init__(name, reverse, raw_dir, force_reload, verbose)
super(WN18Dataset, self).__init__(name, reverse, raw_dir,
force_reload, verbose, transform)
def __getitem__(self, idx):
r"""Gets the graph object
......
......@@ -33,8 +33,12 @@ class MiniGCDataset(DGLDataset):
Minimum number of nodes for graphs
max_num_v: int
Maximum number of nodes for graphs
seed : int, default is 0
seed: int, default is 0
Random seed for data generation
transform: callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
......@@ -75,7 +79,7 @@ class MiniGCDataset(DGLDataset):
"""
def __init__(self, num_graphs, min_num_v, max_num_v, seed=0,
save_graph=True, force_reload=False, verbose=False):
save_graph=True, force_reload=False, verbose=False, transform=None):
self.num_graphs = num_graphs
self.min_num_v = min_num_v
self.max_num_v = max_num_v
......@@ -84,7 +88,7 @@ class MiniGCDataset(DGLDataset):
super(MiniGCDataset, self).__init__(name="minigc", hash_key=(num_graphs, min_num_v, max_num_v, seed),
force_reload=force_reload,
verbose=verbose)
verbose=verbose, transform=transform)
def process(self):
self.graphs = []
......@@ -108,7 +112,11 @@ class MiniGCDataset(DGLDataset):
(:class:`dgl.Graph`, Tensor)
The graph and its label.
"""
return self.graphs[idx], self.labels[idx]
if self._transform is None:
g = self.graphs[idx]
else:
g = self._transform(self.graphs[idx])
return g, self.labels[idx]
def has_cache(self):
graph_path = os.path.join(self.save_path, 'dgl_graph_{}.bin'.format(self.hash))
......
......@@ -56,9 +56,13 @@ class PPIDataset(DGLBuiltinDataset):
force_reload : bool
Whether to reload the dataset.
Default: False
verbose: bool
verbose : bool
Whether to print out progress information.
Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
......@@ -79,7 +83,8 @@ class PPIDataset(DGLBuiltinDataset):
.... # your code here
>>>
"""
def __init__(self, mode='train', raw_dir=None, force_reload=False, verbose=False):
def __init__(self, mode='train', raw_dir=None, force_reload=False,
verbose=False, transform=None):
assert mode in ['train', 'valid', 'test']
self.mode = mode
_url = _get_dgl_url('dataset/ppi.zip')
......@@ -87,7 +92,8 @@ class PPIDataset(DGLBuiltinDataset):
url=_url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def process(self):
graph_file = os.path.join(self.save_path, '{}_graph.json'.format(self.mode))
......@@ -178,7 +184,10 @@ class PPIDataset(DGLBuiltinDataset):
- ``ndata['feat']``: node features
- ``ndata['label']``: node labels
"""
return self.graphs[item]
if self._transform is None:
return self.graphs[item]
else:
return self._transform(self.graphs[item])
class LegacyPPIDataset(PPIDataset):
......@@ -198,5 +207,8 @@ class LegacyPPIDataset(PPIDataset):
(dgl.DGLGraph, Tensor, Tensor)
The graph, features and its label.
"""
return self.graphs[item], self.graphs[item].ndata['feat'], self.graphs[item].ndata['label']
if self._transform is None:
g = self.graphs[item]
else:
g = self._transform(self.graphs[item])
return g, g.ndata['feat'], g.ndata['label']
......@@ -34,8 +34,12 @@ class QM7bDataset(DGLDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
......@@ -65,12 +69,13 @@ class QM7bDataset(DGLDataset):
'datasets/qm7b.mat'
_sha1_str = '4102c744bb9d6fd7b40ac67a300e49cd87e28392'
def __init__(self, raw_dir=None, force_reload=False, verbose=False):
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
super(QM7bDataset, self).__init__(name='qm7b',
url=self._url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def process(self):
mat_path = self.raw_path + '.mat'
......@@ -129,7 +134,11 @@ class QM7bDataset(DGLDataset):
-------
(:class:`dgl.DGLGraph`, Tensor)
"""
return self.graphs[idx], self.label[idx]
if self._transform is None:
g = self.graphs[idx]
else:
g = self._transform(self.graphs[idx])
return g, self.label[idx]
def __len__(self):
r"""Number of graphs in the dataset.
......
......@@ -20,11 +20,11 @@ class QM9Dataset(DGLDataset):
2. It only provides atoms' coordinates and atomic numbers as node features
3. It only provides 12 regression targets.
Reference:
Reference:
- `"Quantum-Machine.org" <http://quantum-machine.org/datasets/>`_,
- `"Directional Message Passing for Molecular Graphs" <https://arxiv.org/abs/2003.03123>`_
Statistics:
- Number of graphs: 130,831
......@@ -60,9 +60,9 @@ class QM9Dataset(DGLDataset):
Parameters
----------
label_keys: list
label_keys : list
Names of the regression property, which should be a subset of the keys in the table above.
cutoff: float
cutoff : float
Cutoff distance for interatomic interactions, i.e. two atoms are connected in the corresponding graph if the distance between them is no larger than this.
Default: 5.0 Angstrom
raw_dir : str
......@@ -70,8 +70,12 @@ class QM9Dataset(DGLDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
......@@ -82,7 +86,7 @@ class QM9Dataset(DGLDataset):
------
UserWarning
If the raw data is changed in the remote server by the author.
Examples
--------
>>> data = QM9Dataset(label_keys=['mu', 'gap'], cutoff=5.0)
......@@ -102,8 +106,9 @@ class QM9Dataset(DGLDataset):
cutoff=5.0,
raw_dir=None,
force_reload=False,
verbose=False):
verbose=False,
transform=None):
self.cutoff = cutoff
self.label_keys = label_keys
self._url = _get_dgl_url('dataset/qm9_eV.npz')
......@@ -112,7 +117,8 @@ class QM9Dataset(DGLDataset):
url=self._url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def process(self):
npz_path = f'{self.raw_dir}/qm9_eV.npz'
......@@ -148,7 +154,7 @@ class QM9Dataset(DGLDataset):
----------
idx : int
Item index
Returns
-------
dgl.DGLGraph
......@@ -170,8 +176,12 @@ class QM9Dataset(DGLDataset):
g = dgl_graph((u, v))
g = to_bidirected(g)
g.ndata['R'] = F.tensor(R, dtype=F.data_type_dict['float32'])
g.ndata['Z'] = F.tensor(self.Z[self.N_cumsum[idx]:self.N_cumsum[idx + 1]],
g.ndata['Z'] = F.tensor(self.Z[self.N_cumsum[idx]:self.N_cumsum[idx + 1]],
dtype=F.data_type_dict['int64'])
if self._transform is not None:
g = self._transform(g)
return g, label
def __len__(self):
......
......@@ -14,34 +14,34 @@ class QM9EdgeDataset(DGLDataset):
This dataset consists of 130,831 molecules with 19 regression targets.
Nodes correspond to atoms and edges correspond to bonds.
This dataset differs from :class:`~dgl.data.QM9Dataset` in the following aspects:
1. It includes the bonds in a molecule in the edges of the corresponding graph while the edges in :class:`~dgl.data.QM9Dataset` are purely distance-based.
2. It provides edge features, and node features in addition to the atoms' coordinates and atomic numbers.
3. It provides another 7 regression tasks(from 12 to 19).
This class is built based on a preprocessed version of the dataset, and we provide the preprocessing datails `here <https://gist.github.com/hengruizhang98/a2da30213b2356fff18b25385c9d3cd2>`_.
Reference:
- `"MoleculeNet: A Benchmark for Molecular Machine Learning" <https://arxiv.org/abs/1703.00564>`_
- `"Neural Message Passing for Quantum Chemistry" <https://arxiv.org/abs/1704.01212>`_
For
For
Statistics:
- Number of graphs: 130,831.
- Number of regression targets: 19.
Node attributes:
- pos: the 3D coordinates of each atom.
- attr: the 11D atom features.
- pos: the 3D coordinates of each atom.
- attr: the 11D atom features.
Edge attributes:
- edge_attr: the 4D bond features.
Regression targets:
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
......@@ -85,10 +85,10 @@ class QM9EdgeDataset(DGLDataset):
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| C | :math:`C` | Rotational constant | :math:`\textrm{GHz}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
Parameters
----------
label_keys: list
label_keys : list
Names of the regression property, which should be a subset of the keys in the table above.
If not provided, it will load all the labels.
raw_dir : str
......@@ -96,8 +96,12 @@ class QM9EdgeDataset(DGLDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False.
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
......@@ -108,13 +112,13 @@ class QM9EdgeDataset(DGLDataset):
------
UserWarning
If the raw data is changed in the remote server by the author.
Examples
--------
>>> data = QM9EdgeDataset(label_keys=['mu', 'alpha'])
>>> data.num_labels
2
>>> # iterate over the dataset
>>> for graph, labels in data:
... print(graph) # get information of each graph
......@@ -122,47 +126,49 @@ class QM9EdgeDataset(DGLDataset):
... # your code here...
>>>
"""
keys = ['mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv', 'U0_atom',
'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C']
map_dict = {}
for i, key in enumerate(keys):
map_dict[key] = i
def __init__(self,
def __init__(self,
label_keys=None,
raw_dir=None,
force_reload=False,
verbose=True):
raw_dir=None,
force_reload=False,
verbose=True,
transform=None):
if label_keys is None:
self.label_keys = None
self.num_labels = 19
else:
self.label_keys = [self.map_dict[i] for i in label_keys]
self.num_labels = len(label_keys)
self._url = _get_dgl_url('dataset/qm9_edge.npz')
super(QM9EdgeDataset, self).__init__(name='qm9Edge',
raw_dir=raw_dir,
url=self._url,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def download(self):
file_path = f'{self.raw_dir}/qm9_edge.npz'
if not os.path.exists(file_path):
download(self._url, path=file_path)
def process(self):
self.load()
def has_cache(self):
npz_path = f'{self.raw_dir}/qm9_edge.npz'
return os.path.exists(npz_path)
def save(self):
np.savez_compressed(f'{self.raw_dir}/qm9_edge.npz',
n_node=self.n_node,
......@@ -171,7 +177,7 @@ class QM9EdgeDataset(DGLDataset):
node_pos=self.node_pos,
edge_attr=self.edge_attr,
src=self.src,
dst=self.dst,
dst=self.dst,
targets=self.targets)
def load(self):
......@@ -184,52 +190,55 @@ class QM9EdgeDataset(DGLDataset):
self.node_pos = data_dict['node_pos']
self.edge_attr = data_dict['edge_attr']
self.targets = data_dict['targets']
self.src = data_dict['src']
self.dst = data_dict['dst']
self.n_cumsum = np.concatenate([[0], np.cumsum(self.n_node)])
self.ne_cumsum = np.concatenate([[0], np.cumsum(self.n_edge)])
def __getitem__(self, idx):
r""" Get graph and label by index
r""" Get graph and label by index
Parameters
----------
idx : int
Item index
Returns
-------
dgl.DGLGraph
The graph contains:
- ``ndata['pos']``: the coordinates of each atom
- ``ndata['attr']``: the features of each atom
- ``edata['edge_attr']``: the features of each bond
Tensor
Property values of molecular graphs
"""
pos = self.node_pos[self.n_cumsum[idx]:self.n_cumsum[idx+1]]
src = self.src[self.ne_cumsum[idx]:self.ne_cumsum[idx+1]]
dst = self.dst[self.ne_cumsum[idx]:self.ne_cumsum[idx+1]]
g = dgl_graph((src, dst))
g.ndata['pos'] = F.tensor(pos, dtype=F.data_type_dict['float32'])
g.ndata['attr'] = F.tensor(self.node_attr[self.n_cumsum[idx]:self.n_cumsum[idx+1]], dtype=F.data_type_dict['float32'])
g.edata['edge_attr'] = F.tensor(self.edge_attr[self.ne_cumsum[idx]:self.ne_cumsum[idx+1]], dtype=F.data_type_dict['float32'])
label = F.tensor(self.targets[idx][self.label_keys], dtype=F.data_type_dict['float32'])
if self._transform is not None:
g = self._transform(g)
return g, label
def __len__(self):
r""" Number of graphs in the dataset.
Returns
-------
int
......
......@@ -94,15 +94,20 @@ class RDFGraphDataset(DGLBuiltinDataset):
Default: ~/.dgl/
force_reload : bool, optional
If true, force load and process from raw data. Ignore cached pre-processed data.
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
"""
def __init__(self, name, url, predict_category,
print_every=10000,
insert_reverse=True,
raw_dir=None,
force_reload=False,
verbose=True):
verbose=True,
transform=None):
self._insert_reverse = insert_reverse
self._print_every = print_every
self._predict_category = predict_category
......@@ -110,7 +115,8 @@ class RDFGraphDataset(DGLBuiltinDataset):
super(RDFGraphDataset, self).__init__(name, url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def process(self):
raw_tuples = self.load_raw_tuples(self.raw_path)
......@@ -409,6 +415,8 @@ class RDFGraphDataset(DGLBuiltinDataset):
r"""Gets the graph object
"""
g = self._hg
if self._transform is not None:
g = self._transform(g)
return g
def __len__(self):
......@@ -523,17 +531,21 @@ class AIFBDataset(RDFGraphDataset):
Parameters
-----------
print_every: int
print_every : int
Preprocessing log for every X tuples. Default: 10000.
insert_reverse: bool
insert_reverse : bool
If true, add reverse edge and reverse relations to the final graph. Default: True.
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
......@@ -562,7 +574,8 @@ class AIFBDataset(RDFGraphDataset):
insert_reverse=True,
raw_dir=None,
force_reload=False,
verbose=True):
verbose=True,
transform=None):
import rdflib as rdf
self.employs = rdf.term.URIRef("http://swrc.ontoware.org/ontology#employs")
self.affiliation = rdf.term.URIRef("http://swrc.ontoware.org/ontology#affiliation")
......@@ -574,7 +587,8 @@ class AIFBDataset(RDFGraphDataset):
insert_reverse=insert_reverse,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def __getitem__(self, idx):
r"""Gets the graph object
......@@ -653,17 +667,21 @@ class MUTAGDataset(RDFGraphDataset):
Parameters
-----------
print_every: int
print_every : int
Preprocessing log for every X tuples. Default: 10000.
insert_reverse: bool
insert_reverse : bool
If true, add reverse edge and reverse relations to the final graph. Default: True.
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
......@@ -697,7 +715,8 @@ class MUTAGDataset(RDFGraphDataset):
insert_reverse=True,
raw_dir=None,
force_reload=False,
verbose=True):
verbose=True,
transform=None):
import rdflib as rdf
self.is_mutagenic = rdf.term.URIRef("http://dl-learner.org/carcinogenesis#isMutagenic")
self.rdf_type = rdf.term.URIRef("http://www.w3.org/1999/02/22-rdf-syntax-ns#type")
......@@ -712,7 +731,8 @@ class MUTAGDataset(RDFGraphDataset):
insert_reverse=insert_reverse,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def __getitem__(self, idx):
r"""Gets the graph object
......@@ -814,17 +834,21 @@ class BGSDataset(RDFGraphDataset):
Parameters
-----------
print_every: int
print_every : int
Preprocessing log for every X tuples. Default: 10000.
insert_reverse: bool
insert_reverse : bool
If true, add reverse edge and reverse relations to the final graph. Default: True.
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
......@@ -854,7 +878,8 @@ class BGSDataset(RDFGraphDataset):
insert_reverse=True,
raw_dir=None,
force_reload=False,
verbose=True):
verbose=True,
transform=None):
import rdflib as rdf
url = _get_dgl_url('dataset/rdf/bgs-hetero.zip')
name = 'bgs-hetero'
......@@ -865,7 +890,8 @@ class BGSDataset(RDFGraphDataset):
insert_reverse=insert_reverse,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def __getitem__(self, idx):
r"""Gets the graph object
......@@ -964,17 +990,21 @@ class AMDataset(RDFGraphDataset):
Parameters
-----------
print_every: int
print_every : int
Preprocessing log for every X tuples. Default: 10000.
insert_reverse: bool
insert_reverse : bool
If true, add reverse edge and reverse relations to the final graph. Default: True.
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
......@@ -1003,7 +1033,8 @@ class AMDataset(RDFGraphDataset):
insert_reverse=True,
raw_dir=None,
force_reload=False,
verbose=True):
verbose=True,
transform=None):
import rdflib as rdf
self.objectCategory = rdf.term.URIRef("http://purl.org/collections/nl/am/objectCategory")
self.material = rdf.term.URIRef("http://purl.org/collections/nl/am/material")
......@@ -1015,7 +1046,8 @@ class AMDataset(RDFGraphDataset):
insert_reverse=insert_reverse,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def __getitem__(self, idx):
r"""Gets the graph object
......
......@@ -84,8 +84,12 @@ class RedditDataset(DGLBuiltinDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
......@@ -125,7 +129,8 @@ class RedditDataset(DGLBuiltinDataset):
>>>
>>> # Train, Validation and Test
"""
def __init__(self, self_loop=False, raw_dir=None, force_reload=False, verbose=False):
def __init__(self, self_loop=False, raw_dir=None, force_reload=False,
verbose=False, transform=None):
self_loop_str = ""
if self_loop:
self_loop_str = "_self_loop"
......@@ -135,7 +140,8 @@ class RedditDataset(DGLBuiltinDataset):
url=_url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def process(self):
# graph
......@@ -251,7 +257,10 @@ class RedditDataset(DGLBuiltinDataset):
- ``ndata['test_mask']:`` mask for test node set
"""
assert idx == 0, "Reddit Dataset only has one graph"
return self._graph
if self._transform is None:
return self._graph
else:
return self._transform(self._graph)
def __len__(self):
r"""Number of graphs in the dataset"""
......
......@@ -23,7 +23,7 @@ class SSTDataset(DGLBuiltinDataset):
r"""Stanford Sentiment Treebank dataset.
.. deprecated:: 0.5.0
- ``trees`` is deprecated, it is replaced by:
>>> dataset = SSTDataset()
......@@ -63,8 +63,12 @@ class SSTDataset(DGLBuiltinDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
......@@ -120,7 +124,8 @@ class SSTDataset(DGLBuiltinDataset):
vocab_file=None,
raw_dir=None,
force_reload=False,
verbose=False):
verbose=False,
transform=None):
assert mode in ['train', 'dev', 'test', 'tiny']
_url = _get_dgl_url('dataset/sst.zip')
self._glove_embed_file = glove_embed_file if mode == 'train' else None
......@@ -130,7 +135,8 @@ class SSTDataset(DGLBuiltinDataset):
url=_url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def process(self):
from nltk.corpus.reader import BracketParseCorpusReader
......@@ -255,7 +261,10 @@ class SSTDataset(DGLBuiltinDataset):
- ``ndata['y']:`` label of the node
- ``ndata['mask']``: 1 if the node is a leaf, otherwise 0
"""
return self._trees[idx]
if self._transform is None:
return self._trees[idx]
else:
return self._transform(self._trees[idx])
def __len__(self):
r"""Number of graphs in the dataset."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment