Unverified commit 8b8fd2c0, authored by Mufei Li, committed by GitHub
Browse files

[Dataset] Add transform argument to built-in datasets (#3733)

* Update

* Fix

* Update
parent b3d3a2c4
...@@ -39,6 +39,10 @@ class BitcoinOTCDataset(DGLBuiltinDataset): ...@@ -39,6 +39,10 @@ class BitcoinOTCDataset(DGLBuiltinDataset):
verbose: bool verbose: bool
Whether to print out progress information. Whether to print out progress information.
Default: False. Default: False.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes Attributes
---------- ----------
...@@ -67,12 +71,13 @@ class BitcoinOTCDataset(DGLBuiltinDataset): ...@@ -67,12 +71,13 @@ class BitcoinOTCDataset(DGLBuiltinDataset):
_url = 'https://snap.stanford.edu/data/soc-sign-bitcoinotc.csv.gz' _url = 'https://snap.stanford.edu/data/soc-sign-bitcoinotc.csv.gz'
_sha1_str = 'c14281f9e252de0bd0b5f1c6e2bae03123938641' _sha1_str = 'c14281f9e252de0bd0b5f1c6e2bae03123938641'
def __init__(self, raw_dir=None, force_reload=False, verbose=False): def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
super(BitcoinOTCDataset, self).__init__(name='bitcoinotc', super(BitcoinOTCDataset, self).__init__(name='bitcoinotc',
url=self._url, url=self._url,
raw_dir=raw_dir, raw_dir=raw_dir,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose) verbose=verbose,
transform=transform)
def download(self): def download(self):
gz_file_path = os.path.join(self.raw_dir, self.name + '.csv.gz') gz_file_path = os.path.join(self.raw_dir, self.name + '.csv.gz')
...@@ -143,7 +148,10 @@ class BitcoinOTCDataset(DGLBuiltinDataset): ...@@ -143,7 +148,10 @@ class BitcoinOTCDataset(DGLBuiltinDataset):
- ``edata['h']`` : edge weights - ``edata['h']`` : edge weights
""" """
return self.graphs[item] if self._transform is None:
return self.graphs[item]
else:
return self._transform(self.graphs[item])
@property @property
def is_temporal(self): def is_temporal(self):
......
...@@ -43,10 +43,14 @@ class CitationGraphDataset(DGLBuiltinDataset): ...@@ -43,10 +43,14 @@ class CitationGraphDataset(DGLBuiltinDataset):
Default: ~/.dgl/ Default: ~/.dgl/
force_reload : bool force_reload : bool
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose : bool
Whether to print out progress information. Default: True. Whether to print out progress information. Default: True.
reverse_edge: bool reverse_edge : bool
Whether to add reverse edges in graph. Default: True. Whether to add reverse edges in graph. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
""" """
_urls = { _urls = {
'cora_v2' : 'dataset/cora_v2.zip', 'cora_v2' : 'dataset/cora_v2.zip',
...@@ -54,7 +58,8 @@ class CitationGraphDataset(DGLBuiltinDataset): ...@@ -54,7 +58,8 @@ class CitationGraphDataset(DGLBuiltinDataset):
'pubmed' : 'dataset/pubmed.zip', 'pubmed' : 'dataset/pubmed.zip',
} }
def __init__(self, name, raw_dir=None, force_reload=False, verbose=True, reverse_edge=True): def __init__(self, name, raw_dir=None, force_reload=False,
verbose=True, reverse_edge=True, transform=None):
assert name.lower() in ['cora', 'citeseer', 'pubmed'] assert name.lower() in ['cora', 'citeseer', 'pubmed']
# Previously we use the pre-processing in pygcn (https://github.com/tkipf/pygcn) # Previously we use the pre-processing in pygcn (https://github.com/tkipf/pygcn)
...@@ -69,7 +74,8 @@ class CitationGraphDataset(DGLBuiltinDataset): ...@@ -69,7 +74,8 @@ class CitationGraphDataset(DGLBuiltinDataset):
url=url, url=url,
raw_dir=raw_dir, raw_dir=raw_dir,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose) verbose=verbose,
transform=transform)
def process(self): def process(self):
"""Loads input data from data directory and reorder graph for better locality """Loads input data from data directory and reorder graph for better locality
...@@ -213,7 +219,10 @@ class CitationGraphDataset(DGLBuiltinDataset): ...@@ -213,7 +219,10 @@ class CitationGraphDataset(DGLBuiltinDataset):
def __getitem__(self, idx): def __getitem__(self, idx):
assert idx == 0, "This dataset has only one graph" assert idx == 0, "This dataset has only one graph"
return self._g if self._transform is None:
return self._g
else:
return self._transform(self._g)
def __len__(self): def __len__(self):
return 1 return 1
...@@ -267,7 +276,7 @@ class CitationGraphDataset(DGLBuiltinDataset): ...@@ -267,7 +276,7 @@ class CitationGraphDataset(DGLBuiltinDataset):
@property @property
def reverse_edge(self): def reverse_edge(self):
return self._reverse_edge return self._reverse_edge
def _preprocess_features(features): def _preprocess_features(features):
"""Row-normalize feature matrix and convert to tuple representation""" """Row-normalize feature matrix and convert to tuple representation"""
...@@ -356,10 +365,14 @@ class CoraGraphDataset(CitationGraphDataset): ...@@ -356,10 +365,14 @@ class CoraGraphDataset(CitationGraphDataset):
Default: ~/.dgl/ Default: ~/.dgl/
force_reload : bool force_reload : bool
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose : bool
Whether to print out progress information. Default: True. Whether to print out progress information. Default: True.
reverse_edge: bool reverse_edge : bool
Whether to add reverse edges in graph. Default: True. Whether to add reverse edges in graph. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes Attributes
---------- ----------
...@@ -400,10 +413,12 @@ class CoraGraphDataset(CitationGraphDataset): ...@@ -400,10 +413,12 @@ class CoraGraphDataset(CitationGraphDataset):
>>> label = g.ndata['label'] >>> label = g.ndata['label']
""" """
def __init__(self, raw_dir=None, force_reload=False, verbose=True, reverse_edge=True): def __init__(self, raw_dir=None, force_reload=False, verbose=True,
reverse_edge=True, transform=None):
name = 'cora' name = 'cora'
super(CoraGraphDataset, self).__init__(name, raw_dir, force_reload, verbose, reverse_edge) super(CoraGraphDataset, self).__init__(name, raw_dir, force_reload,
verbose, reverse_edge, transform)
def __getitem__(self, idx): def __getitem__(self, idx):
r"""Gets the graph object r"""Gets the graph object
...@@ -496,10 +511,14 @@ class CiteseerGraphDataset(CitationGraphDataset): ...@@ -496,10 +511,14 @@ class CiteseerGraphDataset(CitationGraphDataset):
Default: ~/.dgl/ Default: ~/.dgl/
force_reload : bool force_reload : bool
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose : bool
Whether to print out progress information. Default: True. Whether to print out progress information. Default: True.
reverse_edge: bool reverse_edge : bool
Whether to add reverse edges in graph. Default: True. Whether to add reverse edges in graph. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes Attributes
---------- ----------
...@@ -543,10 +562,12 @@ class CiteseerGraphDataset(CitationGraphDataset): ...@@ -543,10 +562,12 @@ class CiteseerGraphDataset(CitationGraphDataset):
>>> label = g.ndata['label'] >>> label = g.ndata['label']
""" """
def __init__(self, raw_dir=None, force_reload=False, verbose=True, reverse_edge=True): def __init__(self, raw_dir=None, force_reload=False,
verbose=True, reverse_edge=True, transform=None):
name = 'citeseer' name = 'citeseer'
super(CiteseerGraphDataset, self).__init__(name, raw_dir, force_reload, verbose, reverse_edge) super(CiteseerGraphDataset, self).__init__(name, raw_dir, force_reload,
verbose, reverse_edge, transform)
def __getitem__(self, idx): def __getitem__(self, idx):
r"""Gets the graph object r"""Gets the graph object
...@@ -639,10 +660,14 @@ class PubmedGraphDataset(CitationGraphDataset): ...@@ -639,10 +660,14 @@ class PubmedGraphDataset(CitationGraphDataset):
Default: ~/.dgl/ Default: ~/.dgl/
force_reload : bool force_reload : bool
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose : bool
Whether to print out progress information. Default: True. Whether to print out progress information. Default: True.
reverse_edge: bool reverse_edge : bool
Whether to add reverse edges in graph. Default: True. Whether to add reverse edges in graph. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes Attributes
---------- ----------
...@@ -683,10 +708,12 @@ class PubmedGraphDataset(CitationGraphDataset): ...@@ -683,10 +708,12 @@ class PubmedGraphDataset(CitationGraphDataset):
>>> label = g.ndata['label'] >>> label = g.ndata['label']
""" """
def __init__(self, raw_dir=None, force_reload=False, verbose=True, reverse_edge=True): def __init__(self, raw_dir=None, force_reload=False, verbose=True,
reverse_edge=True, transform=None):
name = 'pubmed' name = 'pubmed'
super(PubmedGraphDataset, self).__init__(name, raw_dir, force_reload, verbose, reverse_edge) super(PubmedGraphDataset, self).__init__(name, raw_dir, force_reload,
verbose, reverse_edge, transform)
def __getitem__(self, idx): def __getitem__(self, idx):
r"""Gets the graph object r"""Gets the graph object
...@@ -714,7 +741,7 @@ class PubmedGraphDataset(CitationGraphDataset): ...@@ -714,7 +741,7 @@ class PubmedGraphDataset(CitationGraphDataset):
r"""The number of graphs in the dataset.""" r"""The number of graphs in the dataset."""
return super(PubmedGraphDataset, self).__len__() return super(PubmedGraphDataset, self).__len__()
def load_cora(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True): def load_cora(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True, transform=None):
"""Get CoraGraphDataset """Get CoraGraphDataset
Parameters Parameters
...@@ -724,19 +751,24 @@ def load_cora(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True) ...@@ -724,19 +751,24 @@ def load_cora(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True)
Default: ~/.dgl/ Default: ~/.dgl/
force_reload : bool force_reload : bool
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose : bool
Whether to print out progress information. Default: True. Whether to print out progress information. Default: True.
reverse_edge: bool reverse_edge : bool
Whether to add reverse edges in graph. Default: True. Whether to add reverse edges in graph. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Return Return
------- -------
CoraGraphDataset CoraGraphDataset
""" """
data = CoraGraphDataset(raw_dir, force_reload, verbose, reverse_edge) data = CoraGraphDataset(raw_dir, force_reload, verbose, reverse_edge, transform)
return data return data
def load_citeseer(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True): def load_citeseer(raw_dir=None, force_reload=False, verbose=True,
reverse_edge=True, transform=None):
"""Get CiteseerGraphDataset """Get CiteseerGraphDataset
Parameters Parameters
...@@ -746,38 +778,47 @@ def load_citeseer(raw_dir=None, force_reload=False, verbose=True, reverse_edge=T ...@@ -746,38 +778,47 @@ def load_citeseer(raw_dir=None, force_reload=False, verbose=True, reverse_edge=T
Default: ~/.dgl/ Default: ~/.dgl/
force_reload : bool force_reload : bool
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose : bool
Whether to print out progress information. Default: True. Whether to print out progress information. Default: True.
reverse_edge: bool reverse_edge : bool
Whether to add reverse edges in graph. Default: True. Whether to add reverse edges in graph. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Return Return
------- -------
CiteseerGraphDataset CiteseerGraphDataset
""" """
data = CiteseerGraphDataset(raw_dir, force_reload, verbose, reverse_edge) data = CiteseerGraphDataset(raw_dir, force_reload, verbose, reverse_edge, transform)
return data return data
def load_pubmed(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True): def load_pubmed(raw_dir=None, force_reload=False, verbose=True,
reverse_edge=True, transform=None):
"""Get PubmedGraphDataset """Get PubmedGraphDataset
Parameters Parameters
----------- -----------
raw_dir : str raw_dir : str
Raw file directory to download/contains the input data directory. Raw file directory to download/contains the input data directory.
Default: ~/.dgl/ Default: ~/.dgl/
force_reload : bool force_reload : bool
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose : bool
Whether to print out progress information. Default: True. Whether to print out progress information. Default: True.
reverse_edge: bool reverse_edge : bool
Whether to add reverse edges in graph. Default: True. Whether to add reverse edges in graph. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Return Return
------- -------
PubmedGraphDataset PubmedGraphDataset
""" """
data = PubmedGraphDataset(raw_dir, force_reload, verbose, reverse_edge) data = PubmedGraphDataset(raw_dir, force_reload, verbose, reverse_edge, transform)
return data return data
class CoraBinary(DGLBuiltinDataset): class CoraBinary(DGLBuiltinDataset):
...@@ -798,15 +839,20 @@ class CoraBinary(DGLBuiltinDataset): ...@@ -798,15 +839,20 @@ class CoraBinary(DGLBuiltinDataset):
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose: bool
Whether to print out progress information. Default: True. Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
""" """
def __init__(self, raw_dir=None, force_reload=False, verbose=True): def __init__(self, raw_dir=None, force_reload=False, verbose=True, transform=None):
name = 'cora_binary' name = 'cora_binary'
url = _get_dgl_url('dataset/cora_binary.zip') url = _get_dgl_url('dataset/cora_binary.zip')
super(CoraBinary, self).__init__(name, super(CoraBinary, self).__init__(name,
url=url, url=url,
raw_dir=raw_dir, raw_dir=raw_dir,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose) verbose=verbose,
transform=transform)
def process(self): def process(self):
root = self.raw_path root = self.raw_path
...@@ -894,7 +940,11 @@ class CoraBinary(DGLBuiltinDataset): ...@@ -894,7 +940,11 @@ class CoraBinary(DGLBuiltinDataset):
(dgl.DGLGraph, scipy.sparse.coo_matrix, int) (dgl.DGLGraph, scipy.sparse.coo_matrix, int)
The graph, scipy sparse coo_matrix and its label. The graph, scipy sparse coo_matrix and its label.
""" """
return (self.graphs[i], self.pmpds[i], self.labels[i]) if self._transform is None:
g = self.graphs[i]
else:
g = self._transform(self.graphs[i])
return (g, self.pmpds[i], self.labels[i])
@property @property
def save_name(self): def save_name(self):
......
...@@ -33,6 +33,10 @@ class DGLCSVDataset(DGLDataset): ...@@ -33,6 +33,10 @@ class DGLCSVDataset(DGLDataset):
A callable object which is used to parse corresponding column graph A callable object which is used to parse corresponding column graph
data. Default: None. If None, a default data parser is applied data. Default: None. If None, a default data parser is applied
which load data directly and tries to convert list into array. which load data directly and tries to convert list into array.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes Attributes
---------- ----------
...@@ -46,7 +50,8 @@ class DGLCSVDataset(DGLDataset): ...@@ -46,7 +50,8 @@ class DGLCSVDataset(DGLDataset):
""" """
META_YAML_NAME = 'meta.yaml' META_YAML_NAME = 'meta.yaml'
def __init__(self, data_path, force_reload=False, verbose=True, node_data_parser=None, edge_data_parser=None, graph_data_parser=None): def __init__(self, data_path, force_reload=False, verbose=True, node_data_parser=None,
edge_data_parser=None, graph_data_parser=None, transform=None):
from .csv_dataset_base import load_yaml_with_sanity_check, DefaultDataParser from .csv_dataset_base import load_yaml_with_sanity_check, DefaultDataParser
self.graphs = None self.graphs = None
self.data = None self.data = None
...@@ -61,7 +66,7 @@ class DGLCSVDataset(DGLDataset): ...@@ -61,7 +66,7 @@ class DGLCSVDataset(DGLDataset):
self.meta_yaml = load_yaml_with_sanity_check(meta_yaml_path) self.meta_yaml = load_yaml_with_sanity_check(meta_yaml_path)
ds_name = self.meta_yaml.dataset_name ds_name = self.meta_yaml.dataset_name
super().__init__(ds_name, raw_dir=os.path.dirname( super().__init__(ds_name, raw_dir=os.path.dirname(
meta_yaml_path), force_reload=force_reload, verbose=verbose) meta_yaml_path), force_reload=force_reload, verbose=verbose, transform=transform)
def process(self): def process(self):
...@@ -122,10 +127,15 @@ class DGLCSVDataset(DGLDataset): ...@@ -122,10 +127,15 @@ class DGLCSVDataset(DGLDataset):
self.graphs, self.data = load_graphs(graph_path) self.graphs, self.data = load_graphs(graph_path)
def __getitem__(self, i): def __getitem__(self, i):
if self._transform is None:
g = self.graphs[i]
else:
g = self._transform(self.graphs[i])
if 'label' in self.data: if 'label' in self.data:
return self.graphs[i], self.data['label'][i] return g, self.data['label'][i]
else: else:
return self.graphs[i] return g
def __len__(self): def __len__(self):
return len(self.graphs) return len(self.graphs)
...@@ -49,6 +49,10 @@ class DGLDataset(object): ...@@ -49,6 +49,10 @@ class DGLDataset(object):
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose : bool verbose : bool
Whether to print out progress information Whether to print out progress information
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes Attributes
---------- ----------
...@@ -72,13 +76,14 @@ class DGLDataset(object): ...@@ -72,13 +76,14 @@ class DGLDataset(object):
Hash value for the dataset and the setting. Hash value for the dataset and the setting.
""" """
def __init__(self, name, url=None, raw_dir=None, save_dir=None, def __init__(self, name, url=None, raw_dir=None, save_dir=None,
hash_key=(), force_reload=False, verbose=False): hash_key=(), force_reload=False, verbose=False, transform=None):
self._name = name self._name = name
self._url = url self._url = url
self._force_reload = force_reload self._force_reload = force_reload
self._verbose = verbose self._verbose = verbose
self._hash_key = hash_key self._hash_key = hash_key
self._hash = self._get_hash() self._hash = self._get_hash()
self._transform = transform
# if no dir is provided, the default dgl download dir is used. # if no dir is provided, the default dgl download dir is used.
if raw_dir is None: if raw_dir is None:
...@@ -142,7 +147,7 @@ class DGLDataset(object): ...@@ -142,7 +147,7 @@ class DGLDataset(object):
def _download(self): def _download(self):
"""Download dataset by calling ``self.download()`` """Download dataset by calling ``self.download()``
if the dataset does not exist under ``self.raw_path``. if the dataset does not exist under ``self.raw_path``.
By default ``self.raw_path = os.path.join(self.raw_dir, self.name)`` By default ``self.raw_path = os.path.join(self.raw_dir, self.name)``
One can overwrite ``raw_path()`` function to change the path. One can overwrite ``raw_path()`` function to change the path.
""" """
...@@ -161,7 +166,7 @@ class DGLDataset(object): ...@@ -161,7 +166,7 @@ class DGLDataset(object):
- If loading process fails, re-download and process the dataset. - If loading process fails, re-download and process the dataset.
else: else:
- Download the dataset if needed. - Download the dataset if needed.
- Process the dataset and build the dgl graph. - Process the dataset and build the dgl graph.
- Save the processed dataset into files. - Save the processed dataset into files.
...@@ -287,17 +292,23 @@ class DGLBuiltinDataset(DGLDataset): ...@@ -287,17 +292,23 @@ class DGLBuiltinDataset(DGLDataset):
from the same dataset class by comparing the hash values. from the same dataset class by comparing the hash values.
force_reload : bool force_reload : bool
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose : bool
Whether to print out progress information. Default: False Whether to print out progress information. Default: False
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
""" """
def __init__(self, name, url, raw_dir=None, hash_key=(), force_reload=False, verbose=False): def __init__(self, name, url, raw_dir=None, hash_key=(),
force_reload=False, verbose=False, transform=None):
super(DGLBuiltinDataset, self).__init__(name, super(DGLBuiltinDataset, self).__init__(name,
url=url, url=url,
raw_dir=raw_dir, raw_dir=raw_dir,
save_dir=None, save_dir=None,
hash_key=hash_key, hash_key=hash_key,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose) verbose=verbose,
transform=transform)
def download(self): def download(self):
r""" Automatically download data and extract it. r""" Automatically download data and extract it.
......
...@@ -76,6 +76,10 @@ class FakeNewsDataset(DGLBuiltinDataset): ...@@ -76,6 +76,10 @@ class FakeNewsDataset(DGLBuiltinDataset):
downloaded data or the directory that downloaded data or the directory that
already stores the input data. already stores the input data.
Default: ~/.dgl/ Default: ~/.dgl/
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes Attributes
---------- ----------
...@@ -113,7 +117,7 @@ class FakeNewsDataset(DGLBuiltinDataset): ...@@ -113,7 +117,7 @@ class FakeNewsDataset(DGLBuiltinDataset):
'politifact': 'dataset/FakeNewsPOL.zip' 'politifact': 'dataset/FakeNewsPOL.zip'
} }
def __init__(self, name, feature_name, raw_dir=None): def __init__(self, name, feature_name, raw_dir=None, transform=None):
assert name in ['gossipcop', 'politifact'], \ assert name in ['gossipcop', 'politifact'], \
"Only supports 'gossipcop' or 'politifact'." "Only supports 'gossipcop' or 'politifact'."
url = _get_dgl_url(self.file_urls[name]) url = _get_dgl_url(self.file_urls[name])
...@@ -123,7 +127,8 @@ class FakeNewsDataset(DGLBuiltinDataset): ...@@ -123,7 +127,8 @@ class FakeNewsDataset(DGLBuiltinDataset):
self.feature_name = feature_name self.feature_name = feature_name
super(FakeNewsDataset, self).__init__(name=name, super(FakeNewsDataset, self).__init__(name=name,
url=url, url=url,
raw_dir=raw_dir) raw_dir=raw_dir,
transform=transform)
def process(self): def process(self):
"""process raw data to graph, labels and masks""" """process raw data to graph, labels and masks"""
...@@ -213,7 +218,11 @@ class FakeNewsDataset(DGLBuiltinDataset): ...@@ -213,7 +218,11 @@ class FakeNewsDataset(DGLBuiltinDataset):
------- -------
(:class:`dgl.DGLGraph`, Tensor) (:class:`dgl.DGLGraph`, Tensor)
""" """
return self.graphs[i], self.labels[i] if self._transform is None:
g = self.graphs[i]
else:
g = self._transform(self.graphs[i])
return g, self.labels[i]
def __len__(self): def __len__(self):
r"""Number of graphs in the dataset. r"""Number of graphs in the dataset.
......
...@@ -48,8 +48,12 @@ class FraudDataset(DGLBuiltinDataset): ...@@ -48,8 +48,12 @@ class FraudDataset(DGLBuiltinDataset):
Default: 0.1 Default: 0.1
force_reload : bool force_reload : bool
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose : bool
Whether to print out progress information. Default: True. Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes Attributes
---------- ----------
...@@ -88,9 +92,9 @@ class FraudDataset(DGLBuiltinDataset): ...@@ -88,9 +92,9 @@ class FraudDataset(DGLBuiltinDataset):
'yelp': 'review', 'yelp': 'review',
'amazon': 'user' 'amazon': 'user'
} }
def __init__(self, name, raw_dir=None, random_seed=717, train_size=0.7, def __init__(self, name, raw_dir=None, random_seed=717, train_size=0.7,
val_size=0.1, force_reload=False, verbose=True): val_size=0.1, force_reload=False, verbose=True, transform=None):
assert name in ['yelp', 'amazon'], "only supports 'yelp', or 'amazon'" assert name in ['yelp', 'amazon'], "only supports 'yelp', or 'amazon'"
url = _get_dgl_url(self.file_urls[name]) url = _get_dgl_url(self.file_urls[name])
self.seed = random_seed self.seed = random_seed
...@@ -101,30 +105,31 @@ class FraudDataset(DGLBuiltinDataset): ...@@ -101,30 +105,31 @@ class FraudDataset(DGLBuiltinDataset):
raw_dir=raw_dir, raw_dir=raw_dir,
hash_key=(random_seed, train_size, val_size), hash_key=(random_seed, train_size, val_size),
force_reload=force_reload, force_reload=force_reload,
verbose=verbose) verbose=verbose,
transform=transform)
def process(self): def process(self):
"""process raw data to graph, labels, splitting masks""" """process raw data to graph, labels, splitting masks"""
file_path = os.path.join(self.raw_path, self.file_names[self.name]) file_path = os.path.join(self.raw_path, self.file_names[self.name])
data = io.loadmat(file_path) data = io.loadmat(file_path)
node_features = data['features'].todense() node_features = data['features'].todense()
# remove additional dimension of length 1 in raw .mat file # remove additional dimension of length 1 in raw .mat file
node_labels = data['label'].squeeze() node_labels = data['label'].squeeze()
graph_data = {} graph_data = {}
for relation in self.relations[self.name]: for relation in self.relations[self.name]:
adj = data[relation].tocoo() adj = data[relation].tocoo()
row, col = adj.row, adj.col row, col = adj.row, adj.col
graph_data[(self.node_name[self.name], relation, self.node_name[self.name])] = (row, col) graph_data[(self.node_name[self.name], relation, self.node_name[self.name])] = (row, col)
g = heterograph(graph_data) g = heterograph(graph_data)
g.ndata['feature'] = F.tensor(node_features, dtype=F.data_type_dict['float32']) g.ndata['feature'] = F.tensor(node_features, dtype=F.data_type_dict['float32'])
g.ndata['label'] = F.tensor(node_labels, dtype=F.data_type_dict['int64']) g.ndata['label'] = F.tensor(node_labels, dtype=F.data_type_dict['int64'])
self.graph = g self.graph = g
self._random_split(g.ndata['feature'], self.seed, self.train_size, self.val_size) self._random_split(g.ndata['feature'], self.seed, self.train_size, self.val_size)
def __getitem__(self, idx): def __getitem__(self, idx):
r""" Get graph object r""" Get graph object
...@@ -145,12 +150,15 @@ class FraudDataset(DGLBuiltinDataset): ...@@ -145,12 +150,15 @@ class FraudDataset(DGLBuiltinDataset):
- ``ndata['test_mask']``: mask of testing set - ``ndata['test_mask']``: mask of testing set
""" """
assert idx == 0, "This dataset has only one graph" assert idx == 0, "This dataset has only one graph"
return self.graph if self._transform is None:
return self.graph
else:
return self._transform(self.graph)
def __len__(self): def __len__(self):
"""number of data examples""" """number of data examples"""
return len(self.graph) return len(self.graph)
@property @property
def num_classes(self): def num_classes(self):
"""Number of classes. """Number of classes.
...@@ -160,37 +168,37 @@ class FraudDataset(DGLBuiltinDataset): ...@@ -160,37 +168,37 @@ class FraudDataset(DGLBuiltinDataset):
int int
""" """
return 2 return 2
def save(self): def save(self):
"""save processed data to directory `self.save_path`""" """save processed data to directory `self.save_path`"""
graph_path = os.path.join(self.save_path, self.name + '_dgl_graph_{}.bin'.format(self.hash)) graph_path = os.path.join(self.save_path, self.name + '_dgl_graph_{}.bin'.format(self.hash))
save_graphs(str(graph_path), self.graph) save_graphs(str(graph_path), self.graph)
def load(self): def load(self):
"""load processed data from directory `self.save_path`""" """load processed data from directory `self.save_path`"""
graph_path = os.path.join(self.save_path, self.name + '_dgl_graph_{}.bin'.format(self.hash)) graph_path = os.path.join(self.save_path, self.name + '_dgl_graph_{}.bin'.format(self.hash))
graph_list, _ = load_graphs(str(graph_path)) graph_list, _ = load_graphs(str(graph_path))
g = graph_list[0] g = graph_list[0]
self.graph = g self.graph = g
def has_cache(self): def has_cache(self):
"""check whether there are processed data in `self.save_path`""" """check whether there are processed data in `self.save_path`"""
graph_path = os.path.join(self.save_path, self.name + '_dgl_graph_{}.bin'.format(self.hash)) graph_path = os.path.join(self.save_path, self.name + '_dgl_graph_{}.bin'.format(self.hash))
return os.path.exists(graph_path) return os.path.exists(graph_path)
def _random_split(self, x, seed=717, train_size=0.7, val_size=0.1): def _random_split(self, x, seed=717, train_size=0.7, val_size=0.1):
"""split the dataset into training set, validation set and testing set""" """split the dataset into training set, validation set and testing set"""
assert 0 <= train_size + val_size <= 1, \ assert 0 <= train_size + val_size <= 1, \
"The sum of valid training set size and validation set size " \ "The sum of valid training set size and validation set size " \
"must between 0 and 1 (inclusive)." "must between 0 and 1 (inclusive)."
N = x.shape[0] N = x.shape[0]
index = np.arange(N) index = np.arange(N)
if self.name == 'amazon': if self.name == 'amazon':
# 0-3304 are unlabeled nodes # 0-3304 are unlabeled nodes
index = np.arange(3305, N) index = np.arange(3305, N)
index = np.random.RandomState(seed).permutation(index) index = np.random.RandomState(seed).permutation(index)
train_idx = index[:int(train_size * len(index))] train_idx = index[:int(train_size * len(index))]
val_idx = index[len(index) - int(val_size * len(index)):] val_idx = index[len(index) - int(val_size * len(index)):]
...@@ -254,8 +262,12 @@ class FraudYelpDataset(FraudDataset): ...@@ -254,8 +262,12 @@ class FraudYelpDataset(FraudDataset):
Default: 0.1 Default: 0.1
force_reload : bool force_reload : bool
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose : bool
Whether to print out progress information. Default: True. Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Examples Examples
-------- --------
...@@ -265,16 +277,17 @@ class FraudYelpDataset(FraudDataset): ...@@ -265,16 +277,17 @@ class FraudYelpDataset(FraudDataset):
>>> feat = graph.ndata['feature'] >>> feat = graph.ndata['feature']
>>> label = graph.ndata['label'] >>> label = graph.ndata['label']
""" """
def __init__(self, raw_dir=None, random_seed=717, train_size=0.7, def __init__(self, raw_dir=None, random_seed=717, train_size=0.7,
val_size=0.1, force_reload=False, verbose=True): val_size=0.1, force_reload=False, verbose=True, transform=None):
super(FraudYelpDataset, self).__init__(name='yelp', super(FraudYelpDataset, self).__init__(name='yelp',
raw_dir=raw_dir, raw_dir=raw_dir,
random_seed=random_seed, random_seed=random_seed,
train_size=train_size, train_size=train_size,
val_size=val_size, val_size=val_size,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose) verbose=verbose,
transform=transform)
class FraudAmazonDataset(FraudDataset): class FraudAmazonDataset(FraudDataset):
...@@ -330,8 +343,12 @@ class FraudAmazonDataset(FraudDataset): ...@@ -330,8 +343,12 @@ class FraudAmazonDataset(FraudDataset):
Default: 0.1 Default: 0.1
force_reload : bool force_reload : bool
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose : bool
Whether to print out progress information. Default: True. Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Examples Examples
-------- --------
...@@ -341,13 +358,14 @@ class FraudAmazonDataset(FraudDataset): ...@@ -341,13 +358,14 @@ class FraudAmazonDataset(FraudDataset):
>>> feat = graph.ndata['feature'] >>> feat = graph.ndata['feature']
>>> label = graph.ndata['label'] >>> label = graph.ndata['label']
""" """
def __init__(self, raw_dir=None, random_seed=717, train_size=0.7, def __init__(self, raw_dir=None, random_seed=717, train_size=0.7,
val_size=0.1, force_reload=False, verbose=True): val_size=0.1, force_reload=False, verbose=True, transform=None):
super(FraudAmazonDataset, self).__init__(name='amazon', super(FraudAmazonDataset, self).__init__(name='amazon',
raw_dir=raw_dir, raw_dir=raw_dir,
random_seed=random_seed, random_seed=random_seed,
train_size=train_size, train_size=train_size,
val_size=val_size, val_size=val_size,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose) verbose=verbose,
transform=transform)
...@@ -37,8 +37,12 @@ class GDELTDataset(DGLBuiltinDataset): ...@@ -37,8 +37,12 @@ class GDELTDataset(DGLBuiltinDataset):
Default: ~/.dgl/ Default: ~/.dgl/
force_reload : bool force_reload : bool
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose : bool
Whether to print out progress information. Default: False. Whether to print out progress information. Default: False.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes Attributes
---------- ----------
...@@ -65,7 +69,8 @@ class GDELTDataset(DGLBuiltinDataset): ...@@ -65,7 +69,8 @@ class GDELTDataset(DGLBuiltinDataset):
.... ....
>>> >>>
""" """
def __init__(self, mode='train', raw_dir=None, force_reload=False, verbose=False): def __init__(self, mode='train', raw_dir=None,
force_reload=False, verbose=False, transform=None):
mode = mode.lower() mode = mode.lower()
assert mode in ['train', 'valid', 'test'], "Mode not valid." assert mode in ['train', 'valid', 'test'], "Mode not valid."
self.mode = mode self.mode = mode
...@@ -75,7 +80,8 @@ class GDELTDataset(DGLBuiltinDataset): ...@@ -75,7 +80,8 @@ class GDELTDataset(DGLBuiltinDataset):
url=_url, url=_url,
raw_dir=raw_dir, raw_dir=raw_dir,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose) verbose=verbose,
transform=transform)
def process(self): def process(self):
file_path = os.path.join(self.raw_path, self.mode + '.txt') file_path = os.path.join(self.raw_path, self.mode + '.txt')
...@@ -148,6 +154,8 @@ class GDELTDataset(DGLBuiltinDataset): ...@@ -148,6 +154,8 @@ class GDELTDataset(DGLBuiltinDataset):
rate = self.data[row_mask][:, 1] rate = self.data[row_mask][:, 1]
g = dgl_graph((edges[:, 0], edges[:, 1])) g = dgl_graph((edges[:, 0], edges[:, 1]))
g.edata['rel_type'] = F.tensor(rate.reshape(-1, 1), dtype=F.data_type_dict['int64']) g.edata['rel_type'] = F.tensor(rate.reshape(-1, 1), dtype=F.data_type_dict['int64'])
if self._transform is not None:
g = self._transform(g)
return g return g
def __len__(self): def __len__(self):
......
...@@ -18,9 +18,9 @@ from ..convert import graph as dgl_graph ...@@ -18,9 +18,9 @@ from ..convert import graph as dgl_graph
class GINDataset(DGLBuiltinDataset): class GINDataset(DGLBuiltinDataset):
"""Dataset Class for `How Powerful Are Graph Neural Networks? <https://arxiv.org/abs/1810.00826>`_. """Dataset Class for `How Powerful Are Graph Neural Networks? <https://arxiv.org/abs/1810.00826>`_.
This is adapted from `<https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip>`_. This is adapted from `<https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip>`_.
The class provides an interface for nine datasets used in the paper along with the paper-specific The class provides an interface for nine datasets used in the paper along with the paper-specific
settings. The datasets are ``'MUTAG'``, ``'COLLAB'``, ``'IMDBBINARY'``, ``'IMDBMULTI'``, settings. The datasets are ``'MUTAG'``, ``'COLLAB'``, ``'IMDBBINARY'``, ``'IMDBMULTI'``,
``'NCI1'``, ``'PROTEINS'``, ``'PTC'``, ``'REDDITBINARY'``, ``'REDDITMULTI5K'``. ``'NCI1'``, ``'PROTEINS'``, ``'PTC'``, ``'REDDITBINARY'``, ``'REDDITMULTI5K'``.
...@@ -44,6 +44,10 @@ class GINDataset(DGLBuiltinDataset): ...@@ -44,6 +44,10 @@ class GINDataset(DGLBuiltinDataset):
Add self-loop edges if True. Add self-loop edges if True.
degree_as_nlabel: bool degree_as_nlabel: bool
Use node degree as node label and feature if True. Use node degree as node label and feature if True.
transform: callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Examples Examples
-------- --------
...@@ -73,7 +77,7 @@ class GINDataset(DGLBuiltinDataset): ...@@ -73,7 +77,7 @@ class GINDataset(DGLBuiltinDataset):
""" """
def __init__(self, name, self_loop, degree_as_nlabel=False, def __init__(self, name, self_loop, degree_as_nlabel=False,
raw_dir=None, force_reload=False, verbose=False): raw_dir=None, force_reload=False, verbose=False, transform=None):
self._name = name # MUTAG self._name = name # MUTAG
gin_url = 'https://raw.githubusercontent.com/weihua916/powerful-gnns/master/dataset.zip' gin_url = 'https://raw.githubusercontent.com/weihua916/powerful-gnns/master/dataset.zip'
...@@ -106,7 +110,8 @@ class GINDataset(DGLBuiltinDataset): ...@@ -106,7 +110,8 @@ class GINDataset(DGLBuiltinDataset):
self.nlabels_flag = False self.nlabels_flag = False
super(GINDataset, self).__init__(name=name, url=gin_url, hash_key=(name, self_loop, degree_as_nlabel), super(GINDataset, self).__init__(name=name, url=gin_url, hash_key=(name, self_loop, degree_as_nlabel),
raw_dir=raw_dir, force_reload=force_reload, verbose=verbose) raw_dir=raw_dir, force_reload=force_reload,
verbose=verbose, transform=transform)
@property @property
def raw_path(self): def raw_path(self):
...@@ -136,7 +141,11 @@ class GINDataset(DGLBuiltinDataset): ...@@ -136,7 +141,11 @@ class GINDataset(DGLBuiltinDataset):
(:class:`dgl.DGLGraph`, Tensor) (:class:`dgl.DGLGraph`, Tensor)
The graph and its label. The graph and its label.
""" """
return self.graphs[idx], self.labels[idx] if self._transform is None:
g = self.graphs[idx]
else:
g = self._transform(self.graphs[idx])
return g, self.labels[idx]
def _file_path(self): def _file_path(self):
return os.path.join(self.raw_dir, "GINDataset", 'dataset', self.name, "{}.txt".format(self.name)) return os.path.join(self.raw_dir, "GINDataset", 'dataset', self.name, "{}.txt".format(self.name))
......
...@@ -27,13 +27,14 @@ class GNNBenchmarkDataset(DGLBuiltinDataset): ...@@ -27,13 +27,14 @@ class GNNBenchmarkDataset(DGLBuiltinDataset):
Reference: https://github.com/shchur/gnn-benchmark#datasets Reference: https://github.com/shchur/gnn-benchmark#datasets
""" """
def __init__(self, name, raw_dir=None, force_reload=False, verbose=False): def __init__(self, name, raw_dir=None, force_reload=False, verbose=False, transform=None):
_url = _get_dgl_url('dataset/' + name + '.zip') _url = _get_dgl_url('dataset/' + name + '.zip')
super(GNNBenchmarkDataset, self).__init__(name=name, super(GNNBenchmarkDataset, self).__init__(name=name,
url=_url, url=_url,
raw_dir=raw_dir, raw_dir=raw_dir,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose) verbose=verbose,
transform=transform)
def process(self): def process(self):
npz_path = os.path.join(self.raw_path, self.name + '.npz') npz_path = os.path.join(self.raw_path, self.name + '.npz')
...@@ -128,7 +129,10 @@ class GNNBenchmarkDataset(DGLBuiltinDataset): ...@@ -128,7 +129,10 @@ class GNNBenchmarkDataset(DGLBuiltinDataset):
- ``ndata['label']``: node labels - ``ndata['label']``: node labels
""" """
assert idx == 0, "This dataset has only one graph" assert idx == 0, "This dataset has only one graph"
return self._graph if self._transform is None:
return self._graph
else:
return self._transform(self._graph)
def __len__(self): def __len__(self):
r"""Number of graphs in the dataset""" r"""Number of graphs in the dataset"""
...@@ -164,8 +168,12 @@ class CoraFullDataset(GNNBenchmarkDataset): ...@@ -164,8 +168,12 @@ class CoraFullDataset(GNNBenchmarkDataset):
Default: ~/.dgl/ Default: ~/.dgl/
force_reload : bool force_reload : bool
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose : bool
Whether to print out progress information. Default: False. Whether to print out progress information. Default: False.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes Attributes
---------- ----------
...@@ -182,11 +190,12 @@ class CoraFullDataset(GNNBenchmarkDataset): ...@@ -182,11 +190,12 @@ class CoraFullDataset(GNNBenchmarkDataset):
>>> feat = g.ndata['feat'] # get node feature >>> feat = g.ndata['feat'] # get node feature
>>> label = g.ndata['label'] # get node labels >>> label = g.ndata['label'] # get node labels
""" """
def __init__(self, raw_dir=None, force_reload=False, verbose=False): def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
super(CoraFullDataset, self).__init__(name="cora_full", super(CoraFullDataset, self).__init__(name="cora_full",
raw_dir=raw_dir, raw_dir=raw_dir,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose) verbose=verbose,
transform=transform)
@property @property
def num_classes(self): def num_classes(self):
...@@ -231,8 +240,12 @@ class CoauthorCSDataset(GNNBenchmarkDataset): ...@@ -231,8 +240,12 @@ class CoauthorCSDataset(GNNBenchmarkDataset):
Default: ~/.dgl/ Default: ~/.dgl/
force_reload : bool force_reload : bool
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose : bool
Whether to print out progress information. Default: False. Whether to print out progress information. Default: False.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes Attributes
---------- ----------
...@@ -249,11 +262,12 @@ class CoauthorCSDataset(GNNBenchmarkDataset): ...@@ -249,11 +262,12 @@ class CoauthorCSDataset(GNNBenchmarkDataset):
>>> feat = g.ndata['feat'] # get node feature >>> feat = g.ndata['feat'] # get node feature
>>> label = g.ndata['label'] # get node labels >>> label = g.ndata['label'] # get node labels
""" """
def __init__(self, raw_dir=None, force_reload=False, verbose=False): def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
super(CoauthorCSDataset, self).__init__(name='coauthor_cs', super(CoauthorCSDataset, self).__init__(name='coauthor_cs',
raw_dir=raw_dir, raw_dir=raw_dir,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose) verbose=verbose,
transform=transform)
@property @property
def num_classes(self): def num_classes(self):
...@@ -298,8 +312,12 @@ class CoauthorPhysicsDataset(GNNBenchmarkDataset): ...@@ -298,8 +312,12 @@ class CoauthorPhysicsDataset(GNNBenchmarkDataset):
Default: ~/.dgl/ Default: ~/.dgl/
force_reload : bool force_reload : bool
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose : bool
Whether to print out progress information. Default: False. Whether to print out progress information. Default: False.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes Attributes
---------- ----------
...@@ -316,11 +334,12 @@ class CoauthorPhysicsDataset(GNNBenchmarkDataset): ...@@ -316,11 +334,12 @@ class CoauthorPhysicsDataset(GNNBenchmarkDataset):
>>> feat = g.ndata['feat'] # get node feature >>> feat = g.ndata['feat'] # get node feature
>>> label = g.ndata['label'] # get node labels >>> label = g.ndata['label'] # get node labels
""" """
def __init__(self, raw_dir=None, force_reload=False, verbose=False): def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
super(CoauthorPhysicsDataset, self).__init__(name='coauthor_physics', super(CoauthorPhysicsDataset, self).__init__(name='coauthor_physics',
raw_dir=raw_dir, raw_dir=raw_dir,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose) verbose=verbose,
transform=transform)
@property @property
def num_classes(self): def num_classes(self):
...@@ -364,8 +383,12 @@ class AmazonCoBuyComputerDataset(GNNBenchmarkDataset): ...@@ -364,8 +383,12 @@ class AmazonCoBuyComputerDataset(GNNBenchmarkDataset):
Default: ~/.dgl/ Default: ~/.dgl/
force_reload : bool force_reload : bool
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose : bool
Whether to print out progress information. Default: False. Whether to print out progress information. Default: False.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes Attributes
---------- ----------
...@@ -382,11 +405,12 @@ class AmazonCoBuyComputerDataset(GNNBenchmarkDataset): ...@@ -382,11 +405,12 @@ class AmazonCoBuyComputerDataset(GNNBenchmarkDataset):
>>> feat = g.ndata['feat'] # get node feature >>> feat = g.ndata['feat'] # get node feature
>>> label = g.ndata['label'] # get node labels >>> label = g.ndata['label'] # get node labels
""" """
def __init__(self, raw_dir=None, force_reload=False, verbose=False): def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
super(AmazonCoBuyComputerDataset, self).__init__(name='amazon_co_buy_computer', super(AmazonCoBuyComputerDataset, self).__init__(name='amazon_co_buy_computer',
raw_dir=raw_dir, raw_dir=raw_dir,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose) verbose=verbose,
transform=transform)
@property @property
def num_classes(self): def num_classes(self):
...@@ -430,8 +454,12 @@ class AmazonCoBuyPhotoDataset(GNNBenchmarkDataset): ...@@ -430,8 +454,12 @@ class AmazonCoBuyPhotoDataset(GNNBenchmarkDataset):
Default: ~/.dgl/ Default: ~/.dgl/
force_reload : bool force_reload : bool
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose : bool
Whether to print out progress information. Default: False. Whether to print out progress information. Default: False.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes Attributes
---------- ----------
...@@ -448,11 +476,12 @@ class AmazonCoBuyPhotoDataset(GNNBenchmarkDataset): ...@@ -448,11 +476,12 @@ class AmazonCoBuyPhotoDataset(GNNBenchmarkDataset):
>>> feat = g.ndata['feat'] # get node feature >>> feat = g.ndata['feat'] # get node feature
>>> label = g.ndata['label'] # get node labels >>> label = g.ndata['label'] # get node labels
""" """
def __init__(self, raw_dir=None, force_reload=False, verbose=False): def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
super(AmazonCoBuyPhotoDataset, self).__init__(name='amazon_co_buy_photo', super(AmazonCoBuyPhotoDataset, self).__init__(name='amazon_co_buy_photo',
raw_dir=raw_dir, raw_dir=raw_dir,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose) verbose=verbose,
transform=transform)
@property @property
def num_classes(self): def num_classes(self):
......
...@@ -39,8 +39,12 @@ class ICEWS18Dataset(DGLBuiltinDataset): ...@@ -39,8 +39,12 @@ class ICEWS18Dataset(DGLBuiltinDataset):
Default: ~/.dgl/ Default: ~/.dgl/
force_reload : bool force_reload : bool
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose : bool
Whether to print out progress information. Default: False. Whether to print out progress information. Default: False.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes Attributes
---------- ----------
...@@ -61,7 +65,7 @@ class ICEWS18Dataset(DGLBuiltinDataset): ...@@ -61,7 +65,7 @@ class ICEWS18Dataset(DGLBuiltinDataset):
.... ....
>>> >>>
""" """
def __init__(self, mode='train', raw_dir=None, force_reload=False, verbose=False): def __init__(self, mode='train', raw_dir=None, force_reload=False, verbose=False, transform=None):
mode = mode.lower() mode = mode.lower()
assert mode in ['train', 'valid', 'test'], "Mode not valid" assert mode in ['train', 'valid', 'test'], "Mode not valid"
self.mode = mode self.mode = mode
...@@ -70,7 +74,8 @@ class ICEWS18Dataset(DGLBuiltinDataset): ...@@ -70,7 +74,8 @@ class ICEWS18Dataset(DGLBuiltinDataset):
url=_url, url=_url,
raw_dir=raw_dir, raw_dir=raw_dir,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose) verbose=verbose,
transform=transform)
def process(self): def process(self):
data = loadtxt(os.path.join(self.save_path, '{}.txt'.format(self.mode)), data = loadtxt(os.path.join(self.save_path, '{}.txt'.format(self.mode)),
...@@ -118,7 +123,10 @@ class ICEWS18Dataset(DGLBuiltinDataset): ...@@ -118,7 +123,10 @@ class ICEWS18Dataset(DGLBuiltinDataset):
- ``edata['rel_type']``: edge type - ``edata['rel_type']``: edge type
""" """
return self._graphs[idx] if self._transform is None:
return self._graphs[idx]
else:
return self._transform(self._graphs[idx])
def __len__(self): def __len__(self):
r"""Number of graphs in the dataset. r"""Number of graphs in the dataset.
......
...@@ -34,6 +34,13 @@ class KarateClubDataset(DGLDataset): ...@@ -34,6 +34,13 @@ class KarateClubDataset(DGLDataset):
- Edges: 156 - Edges: 156
- Number of Classes: 2 - Number of Classes: 2
Parameters
----------
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes Attributes
---------- ----------
num_classes : int num_classes : int
...@@ -48,8 +55,8 @@ class KarateClubDataset(DGLDataset): ...@@ -48,8 +55,8 @@ class KarateClubDataset(DGLDataset):
>>> g = dataset[0] >>> g = dataset[0]
>>> labels = g.ndata['label'] >>> labels = g.ndata['label']
""" """
def __init__(self): def __init__(self, transform=None):
super(KarateClubDataset, self).__init__(name='karate_club') super(KarateClubDataset, self).__init__(name='karate_club', transform=transform)
def process(self): def process(self):
kc_graph = nx.karate_club_graph() kc_graph = nx.karate_club_graph()
...@@ -88,7 +95,10 @@ class KarateClubDataset(DGLDataset): ...@@ -88,7 +95,10 @@ class KarateClubDataset(DGLDataset):
- ``ndata['label']``: ground truth labels - ``ndata['label']``: ground truth labels
""" """
assert idx == 0, "This dataset has only one graph" assert idx == 0, "This dataset has only one graph"
return self._graph if self._transform is None:
return self._graph
else:
return self._transform(self._graph)
def __len__(self): def __len__(self):
r"""The number of graphs in the dataset.""" r"""The number of graphs in the dataset."""
......
...@@ -25,19 +25,24 @@ class KnowledgeGraphDataset(DGLBuiltinDataset): ...@@ -25,19 +25,24 @@ class KnowledgeGraphDataset(DGLBuiltinDataset):
Parameters Parameters
----------- -----------
name: str name : str
Name can be 'FB15k-237', 'FB15k' or 'wn18'. Name can be 'FB15k-237', 'FB15k' or 'wn18'.
reverse: bool reverse : bool
Whether to add reverse edges. Default: True. Whether to add reverse edges. Default: True.
raw_dir : str raw_dir : str
Raw file directory to download to, or that already contains, the input data. Raw file directory to download to, or that already contains, the input data.
Default: ~/.dgl/ Default: ~/.dgl/
force_reload : bool force_reload : bool
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose : bool
Whether to print out progress information. Default: True. Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
""" """
def __init__(self, name, reverse=True, raw_dir=None, force_reload=False, verbose=True): def __init__(self, name, reverse=True, raw_dir=None, force_reload=False,
verbose=True, transform=None):
self._name = name self._name = name
self.reverse = reverse self.reverse = reverse
url = _get_dgl_url('dataset/') + '{}.tgz'.format(name) url = _get_dgl_url('dataset/') + '{}.tgz'.format(name)
...@@ -45,7 +50,8 @@ class KnowledgeGraphDataset(DGLBuiltinDataset): ...@@ -45,7 +50,8 @@ class KnowledgeGraphDataset(DGLBuiltinDataset):
url=url, url=url,
raw_dir=raw_dir, raw_dir=raw_dir,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose) verbose=verbose,
transform=transform)
def download(self): def download(self):
r""" Automatically download data and extract it. r""" Automatically download data and extract it.
...@@ -112,7 +118,10 @@ class KnowledgeGraphDataset(DGLBuiltinDataset): ...@@ -112,7 +118,10 @@ class KnowledgeGraphDataset(DGLBuiltinDataset):
def __getitem__(self, idx): def __getitem__(self, idx):
assert idx == 0, "This dataset has only one graph" assert idx == 0, "This dataset has only one graph"
return self._g if self._transform is None:
return self._g
else:
return self._transform(self._g)
def __len__(self): def __len__(self):
return 1 return 1
...@@ -389,8 +398,12 @@ class FB15k237Dataset(KnowledgeGraphDataset): ...@@ -389,8 +398,12 @@ class FB15k237Dataset(KnowledgeGraphDataset):
Default: ~/.dgl/ Default: ~/.dgl/
force_reload : bool force_reload : bool
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose : bool
Whether to print out progress information. Default: True. Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes Attributes
---------- ----------
...@@ -433,9 +446,11 @@ class FB15k237Dataset(KnowledgeGraphDataset): ...@@ -433,9 +446,11 @@ class FB15k237Dataset(KnowledgeGraphDataset):
>>> >>>
>>> # Train, Validation and Test >>> # Train, Validation and Test
""" """
def __init__(self, reverse=True, raw_dir=None, force_reload=False, verbose=True): def __init__(self, reverse=True, raw_dir=None, force_reload=False,
verbose=True, transform=None):
name = 'FB15k-237' name = 'FB15k-237'
super(FB15k237Dataset, self).__init__(name, reverse, raw_dir, force_reload, verbose) super(FB15k237Dataset, self).__init__(name, reverse, raw_dir,
force_reload, verbose, transform)
def __getitem__(self, idx): def __getitem__(self, idx):
r"""Gets the graph object r"""Gets the graph object
...@@ -526,8 +541,12 @@ class FB15kDataset(KnowledgeGraphDataset): ...@@ -526,8 +541,12 @@ class FB15kDataset(KnowledgeGraphDataset):
Default: ~/.dgl/ Default: ~/.dgl/
force_reload : bool force_reload : bool
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose : bool
Whether to print out progress information. Default: True. Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes Attributes
---------- ----------
...@@ -570,9 +589,11 @@ class FB15kDataset(KnowledgeGraphDataset): ...@@ -570,9 +589,11 @@ class FB15kDataset(KnowledgeGraphDataset):
>>> # Train, Validation and Test >>> # Train, Validation and Test
>>> >>>
""" """
def __init__(self, reverse=True, raw_dir=None, force_reload=False, verbose=True): def __init__(self, reverse=True, raw_dir=None, force_reload=False,
verbose=True, transform=None):
name = 'FB15k' name = 'FB15k'
super(FB15kDataset, self).__init__(name, reverse, raw_dir, force_reload, verbose) super(FB15kDataset, self).__init__(name, reverse, raw_dir,
force_reload, verbose, transform)
def __getitem__(self, idx): def __getitem__(self, idx):
r"""Gets the graph object r"""Gets the graph object
...@@ -662,8 +683,12 @@ class WN18Dataset(KnowledgeGraphDataset): ...@@ -662,8 +683,12 @@ class WN18Dataset(KnowledgeGraphDataset):
Default: ~/.dgl/ Default: ~/.dgl/
force_reload : bool force_reload : bool
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose : bool
Whether to print out progress information. Default: True. Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes Attributes
---------- ----------
...@@ -706,9 +731,11 @@ class WN18Dataset(KnowledgeGraphDataset): ...@@ -706,9 +731,11 @@ class WN18Dataset(KnowledgeGraphDataset):
>>> # Train, Validation and Test >>> # Train, Validation and Test
>>> >>>
""" """
def __init__(self, reverse=True, raw_dir=None, force_reload=False, verbose=True): def __init__(self, reverse=True, raw_dir=None, force_reload=False,
verbose=True, transform=None):
name = 'wn18' name = 'wn18'
super(WN18Dataset, self).__init__(name, reverse, raw_dir, force_reload, verbose) super(WN18Dataset, self).__init__(name, reverse, raw_dir,
force_reload, verbose, transform)
def __getitem__(self, idx): def __getitem__(self, idx):
r"""Gets the graph object r"""Gets the graph object
......
...@@ -33,8 +33,12 @@ class MiniGCDataset(DGLDataset): ...@@ -33,8 +33,12 @@ class MiniGCDataset(DGLDataset):
Minimum number of nodes for graphs Minimum number of nodes for graphs
max_num_v: int max_num_v: int
Maximum number of nodes for graphs Maximum number of nodes for graphs
seed : int, default is 0 seed: int, default is 0
Random seed for data generation Random seed for data generation
transform: callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes Attributes
---------- ----------
...@@ -75,7 +79,7 @@ class MiniGCDataset(DGLDataset): ...@@ -75,7 +79,7 @@ class MiniGCDataset(DGLDataset):
""" """
def __init__(self, num_graphs, min_num_v, max_num_v, seed=0, def __init__(self, num_graphs, min_num_v, max_num_v, seed=0,
save_graph=True, force_reload=False, verbose=False): save_graph=True, force_reload=False, verbose=False, transform=None):
self.num_graphs = num_graphs self.num_graphs = num_graphs
self.min_num_v = min_num_v self.min_num_v = min_num_v
self.max_num_v = max_num_v self.max_num_v = max_num_v
...@@ -84,7 +88,7 @@ class MiniGCDataset(DGLDataset): ...@@ -84,7 +88,7 @@ class MiniGCDataset(DGLDataset):
super(MiniGCDataset, self).__init__(name="minigc", hash_key=(num_graphs, min_num_v, max_num_v, seed), super(MiniGCDataset, self).__init__(name="minigc", hash_key=(num_graphs, min_num_v, max_num_v, seed),
force_reload=force_reload, force_reload=force_reload,
verbose=verbose) verbose=verbose, transform=transform)
def process(self): def process(self):
self.graphs = [] self.graphs = []
...@@ -108,7 +112,11 @@ class MiniGCDataset(DGLDataset): ...@@ -108,7 +112,11 @@ class MiniGCDataset(DGLDataset):
(:class:`dgl.DGLGraph`, Tensor) (:class:`dgl.DGLGraph`, Tensor)
The graph and its label. The graph and its label.
""" """
return self.graphs[idx], self.labels[idx] if self._transform is None:
g = self.graphs[idx]
else:
g = self._transform(self.graphs[idx])
return g, self.labels[idx]
def has_cache(self): def has_cache(self):
graph_path = os.path.join(self.save_path, 'dgl_graph_{}.bin'.format(self.hash)) graph_path = os.path.join(self.save_path, 'dgl_graph_{}.bin'.format(self.hash))
......
...@@ -56,9 +56,13 @@ class PPIDataset(DGLBuiltinDataset): ...@@ -56,9 +56,13 @@ class PPIDataset(DGLBuiltinDataset):
force_reload : bool force_reload : bool
Whether to reload the dataset. Whether to reload the dataset.
Default: False Default: False
verbose: bool verbose : bool
Whether to print out progress information. Whether to print out progress information.
Default: False. Default: False.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes Attributes
---------- ----------
...@@ -79,7 +83,8 @@ class PPIDataset(DGLBuiltinDataset): ...@@ -79,7 +83,8 @@ class PPIDataset(DGLBuiltinDataset):
.... # your code here .... # your code here
>>> >>>
""" """
def __init__(self, mode='train', raw_dir=None, force_reload=False, verbose=False): def __init__(self, mode='train', raw_dir=None, force_reload=False,
verbose=False, transform=None):
assert mode in ['train', 'valid', 'test'] assert mode in ['train', 'valid', 'test']
self.mode = mode self.mode = mode
_url = _get_dgl_url('dataset/ppi.zip') _url = _get_dgl_url('dataset/ppi.zip')
...@@ -87,7 +92,8 @@ class PPIDataset(DGLBuiltinDataset): ...@@ -87,7 +92,8 @@ class PPIDataset(DGLBuiltinDataset):
url=_url, url=_url,
raw_dir=raw_dir, raw_dir=raw_dir,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose) verbose=verbose,
transform=transform)
def process(self): def process(self):
graph_file = os.path.join(self.save_path, '{}_graph.json'.format(self.mode)) graph_file = os.path.join(self.save_path, '{}_graph.json'.format(self.mode))
...@@ -178,7 +184,10 @@ class PPIDataset(DGLBuiltinDataset): ...@@ -178,7 +184,10 @@ class PPIDataset(DGLBuiltinDataset):
- ``ndata['feat']``: node features - ``ndata['feat']``: node features
- ``ndata['label']``: node labels - ``ndata['label']``: node labels
""" """
return self.graphs[item] if self._transform is None:
return self.graphs[item]
else:
return self._transform(self.graphs[item])
class LegacyPPIDataset(PPIDataset): class LegacyPPIDataset(PPIDataset):
...@@ -198,5 +207,8 @@ class LegacyPPIDataset(PPIDataset): ...@@ -198,5 +207,8 @@ class LegacyPPIDataset(PPIDataset):
(dgl.DGLGraph, Tensor, Tensor) (dgl.DGLGraph, Tensor, Tensor)
The graph, features and its label. The graph, features and its label.
""" """
if self._transform is None:
return self.graphs[item], self.graphs[item].ndata['feat'], self.graphs[item].ndata['label'] g = self.graphs[item]
else:
g = self._transform(self.graphs[item])
return g, g.ndata['feat'], g.ndata['label']
...@@ -34,8 +34,12 @@ class QM7bDataset(DGLDataset): ...@@ -34,8 +34,12 @@ class QM7bDataset(DGLDataset):
Default: ~/.dgl/ Default: ~/.dgl/
force_reload : bool force_reload : bool
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose : bool
Whether to print out progress information. Default: False. Whether to print out progress information. Default: False.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes Attributes
---------- ----------
...@@ -65,12 +69,13 @@ class QM7bDataset(DGLDataset): ...@@ -65,12 +69,13 @@ class QM7bDataset(DGLDataset):
'datasets/qm7b.mat' 'datasets/qm7b.mat'
_sha1_str = '4102c744bb9d6fd7b40ac67a300e49cd87e28392' _sha1_str = '4102c744bb9d6fd7b40ac67a300e49cd87e28392'
def __init__(self, raw_dir=None, force_reload=False, verbose=False): def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
super(QM7bDataset, self).__init__(name='qm7b', super(QM7bDataset, self).__init__(name='qm7b',
url=self._url, url=self._url,
raw_dir=raw_dir, raw_dir=raw_dir,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose) verbose=verbose,
transform=transform)
def process(self): def process(self):
mat_path = self.raw_path + '.mat' mat_path = self.raw_path + '.mat'
...@@ -129,7 +134,11 @@ class QM7bDataset(DGLDataset): ...@@ -129,7 +134,11 @@ class QM7bDataset(DGLDataset):
------- -------
(:class:`dgl.DGLGraph`, Tensor) (:class:`dgl.DGLGraph`, Tensor)
""" """
return self.graphs[idx], self.label[idx] if self._transform is None:
g = self.graphs[idx]
else:
g = self._transform(self.graphs[idx])
return g, self.label[idx]
def __len__(self): def __len__(self):
r"""Number of graphs in the dataset. r"""Number of graphs in the dataset.
......
...@@ -20,11 +20,11 @@ class QM9Dataset(DGLDataset): ...@@ -20,11 +20,11 @@ class QM9Dataset(DGLDataset):
2. It only provides atoms' coordinates and atomic numbers as node features 2. It only provides atoms' coordinates and atomic numbers as node features
3. It only provides 12 regression targets. 3. It only provides 12 regression targets.
Reference: Reference:
- `"Quantum-Machine.org" <http://quantum-machine.org/datasets/>`_, - `"Quantum-Machine.org" <http://quantum-machine.org/datasets/>`_,
- `"Directional Message Passing for Molecular Graphs" <https://arxiv.org/abs/2003.03123>`_ - `"Directional Message Passing for Molecular Graphs" <https://arxiv.org/abs/2003.03123>`_
Statistics: Statistics:
- Number of graphs: 130,831 - Number of graphs: 130,831
...@@ -60,9 +60,9 @@ class QM9Dataset(DGLDataset): ...@@ -60,9 +60,9 @@ class QM9Dataset(DGLDataset):
Parameters Parameters
---------- ----------
label_keys: list label_keys : list
Names of the regression property, which should be a subset of the keys in the table above. Names of the regression property, which should be a subset of the keys in the table above.
cutoff: float cutoff : float
Cutoff distance for interatomic interactions, i.e. two atoms are connected in the corresponding graph if the distance between them is no larger than this. Cutoff distance for interatomic interactions, i.e. two atoms are connected in the corresponding graph if the distance between them is no larger than this.
Default: 5.0 Angstrom Default: 5.0 Angstrom
raw_dir : str raw_dir : str
...@@ -70,8 +70,12 @@ class QM9Dataset(DGLDataset): ...@@ -70,8 +70,12 @@ class QM9Dataset(DGLDataset):
Default: ~/.dgl/ Default: ~/.dgl/
force_reload : bool force_reload : bool
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose : bool
Whether to print out progress information. Default: False. Whether to print out progress information. Default: False.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes Attributes
---------- ----------
...@@ -82,7 +86,7 @@ class QM9Dataset(DGLDataset): ...@@ -82,7 +86,7 @@ class QM9Dataset(DGLDataset):
------ ------
UserWarning UserWarning
If the raw data is changed in the remote server by the author. If the raw data is changed in the remote server by the author.
Examples Examples
-------- --------
>>> data = QM9Dataset(label_keys=['mu', 'gap'], cutoff=5.0) >>> data = QM9Dataset(label_keys=['mu', 'gap'], cutoff=5.0)
...@@ -102,8 +106,9 @@ class QM9Dataset(DGLDataset): ...@@ -102,8 +106,9 @@ class QM9Dataset(DGLDataset):
cutoff=5.0, cutoff=5.0,
raw_dir=None, raw_dir=None,
force_reload=False, force_reload=False,
verbose=False): verbose=False,
transform=None):
self.cutoff = cutoff self.cutoff = cutoff
self.label_keys = label_keys self.label_keys = label_keys
self._url = _get_dgl_url('dataset/qm9_eV.npz') self._url = _get_dgl_url('dataset/qm9_eV.npz')
...@@ -112,7 +117,8 @@ class QM9Dataset(DGLDataset): ...@@ -112,7 +117,8 @@ class QM9Dataset(DGLDataset):
url=self._url, url=self._url,
raw_dir=raw_dir, raw_dir=raw_dir,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose) verbose=verbose,
transform=transform)
def process(self): def process(self):
npz_path = f'{self.raw_dir}/qm9_eV.npz' npz_path = f'{self.raw_dir}/qm9_eV.npz'
...@@ -148,7 +154,7 @@ class QM9Dataset(DGLDataset): ...@@ -148,7 +154,7 @@ class QM9Dataset(DGLDataset):
---------- ----------
idx : int idx : int
Item index Item index
Returns Returns
------- -------
dgl.DGLGraph dgl.DGLGraph
...@@ -170,8 +176,12 @@ class QM9Dataset(DGLDataset): ...@@ -170,8 +176,12 @@ class QM9Dataset(DGLDataset):
g = dgl_graph((u, v)) g = dgl_graph((u, v))
g = to_bidirected(g) g = to_bidirected(g)
g.ndata['R'] = F.tensor(R, dtype=F.data_type_dict['float32']) g.ndata['R'] = F.tensor(R, dtype=F.data_type_dict['float32'])
g.ndata['Z'] = F.tensor(self.Z[self.N_cumsum[idx]:self.N_cumsum[idx + 1]], g.ndata['Z'] = F.tensor(self.Z[self.N_cumsum[idx]:self.N_cumsum[idx + 1]],
dtype=F.data_type_dict['int64']) dtype=F.data_type_dict['int64'])
if self._transform is not None:
g = self._transform(g)
return g, label return g, label
def __len__(self): def __len__(self):
......
...@@ -14,34 +14,34 @@ class QM9EdgeDataset(DGLDataset): ...@@ -14,34 +14,34 @@ class QM9EdgeDataset(DGLDataset):
This dataset consists of 130,831 molecules with 19 regression targets. This dataset consists of 130,831 molecules with 19 regression targets.
Nodes correspond to atoms and edges correspond to bonds. Nodes correspond to atoms and edges correspond to bonds.
This dataset differs from :class:`~dgl.data.QM9Dataset` in the following aspects: This dataset differs from :class:`~dgl.data.QM9Dataset` in the following aspects:
1. It includes the bonds in a molecule in the edges of the corresponding graph while the edges in :class:`~dgl.data.QM9Dataset` are purely distance-based. 1. It includes the bonds in a molecule in the edges of the corresponding graph while the edges in :class:`~dgl.data.QM9Dataset` are purely distance-based.
2. It provides edge features, and node features in addition to the atoms' coordinates and atomic numbers. 2. It provides edge features, and node features in addition to the atoms' coordinates and atomic numbers.
3. It provides another 7 regression tasks (from 12 to 19). 3. It provides another 7 regression tasks (from 12 to 19).
This class is built based on a preprocessed version of the dataset, and we provide the preprocessing details `here <https://gist.github.com/hengruizhang98/a2da30213b2356fff18b25385c9d3cd2>`_. This class is built based on a preprocessed version of the dataset, and we provide the preprocessing details `here <https://gist.github.com/hengruizhang98/a2da30213b2356fff18b25385c9d3cd2>`_.
Reference: Reference:
- `"MoleculeNet: A Benchmark for Molecular Machine Learning" <https://arxiv.org/abs/1703.00564>`_ - `"MoleculeNet: A Benchmark for Molecular Machine Learning" <https://arxiv.org/abs/1703.00564>`_
- `"Neural Message Passing for Quantum Chemistry" <https://arxiv.org/abs/1704.01212>`_ - `"Neural Message Passing for Quantum Chemistry" <https://arxiv.org/abs/1704.01212>`_
For For
Statistics: Statistics:
- Number of graphs: 130,831. - Number of graphs: 130,831.
- Number of regression targets: 19. - Number of regression targets: 19.
Node attributes: Node attributes:
- pos: the 3D coordinates of each atom. - pos: the 3D coordinates of each atom.
- attr: the 11D atom features. - attr: the 11D atom features.
Edge attributes: Edge attributes:
- edge_attr: the 4D bond features. - edge_attr: the 4D bond features.
Regression targets: Regression targets:
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
...@@ -85,10 +85,10 @@ class QM9EdgeDataset(DGLDataset): ...@@ -85,10 +85,10 @@ class QM9EdgeDataset(DGLDataset):
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| C | :math:`C` | Rotational constant | :math:`\textrm{GHz}` | | C | :math:`C` | Rotational constant | :math:`\textrm{GHz}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
Parameters Parameters
---------- ----------
label_keys: list label_keys : list
Names of the regression property, which should be a subset of the keys in the table above. Names of the regression property, which should be a subset of the keys in the table above.
If not provided, it will load all the labels. If not provided, it will load all the labels.
raw_dir : str raw_dir : str
...@@ -96,8 +96,12 @@ class QM9EdgeDataset(DGLDataset): ...@@ -96,8 +96,12 @@ class QM9EdgeDataset(DGLDataset):
Default: ~/.dgl/ Default: ~/.dgl/
force_reload : bool force_reload : bool
Whether to reload the dataset. Default: False. Whether to reload the dataset. Default: False.
verbose: bool verbose : bool
Whether to print out progress information. Default: True. Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes Attributes
---------- ----------
...@@ -108,13 +112,13 @@ class QM9EdgeDataset(DGLDataset): ...@@ -108,13 +112,13 @@ class QM9EdgeDataset(DGLDataset):
------ ------
UserWarning UserWarning
If the raw data is changed in the remote server by the author. If the raw data is changed in the remote server by the author.
Examples Examples
-------- --------
>>> data = QM9EdgeDataset(label_keys=['mu', 'alpha']) >>> data = QM9EdgeDataset(label_keys=['mu', 'alpha'])
>>> data.num_labels >>> data.num_labels
2 2
>>> # iterate over the dataset >>> # iterate over the dataset
>>> for graph, labels in data: >>> for graph, labels in data:
... print(graph) # get information of each graph ... print(graph) # get information of each graph
...@@ -122,47 +126,49 @@ class QM9EdgeDataset(DGLDataset): ...@@ -122,47 +126,49 @@ class QM9EdgeDataset(DGLDataset):
... # your code here... ... # your code here...
>>> >>>
""" """
keys = ['mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv', 'U0_atom', keys = ['mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv', 'U0_atom',
'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C'] 'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C']
map_dict = {} map_dict = {}
for i, key in enumerate(keys): for i, key in enumerate(keys):
map_dict[key] = i map_dict[key] = i
def __init__(self, def __init__(self,
label_keys=None, label_keys=None,
raw_dir=None, raw_dir=None,
force_reload=False, force_reload=False,
verbose=True): verbose=True,
transform=None):
if label_keys is None: if label_keys is None:
self.label_keys = None self.label_keys = None
self.num_labels = 19 self.num_labels = 19
else: else:
self.label_keys = [self.map_dict[i] for i in label_keys] self.label_keys = [self.map_dict[i] for i in label_keys]
self.num_labels = len(label_keys) self.num_labels = len(label_keys)
self._url = _get_dgl_url('dataset/qm9_edge.npz') self._url = _get_dgl_url('dataset/qm9_edge.npz')
super(QM9EdgeDataset, self).__init__(name='qm9Edge', super(QM9EdgeDataset, self).__init__(name='qm9Edge',
raw_dir=raw_dir, raw_dir=raw_dir,
url=self._url, url=self._url,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose) verbose=verbose,
transform=transform)
def download(self): def download(self):
file_path = f'{self.raw_dir}/qm9_edge.npz' file_path = f'{self.raw_dir}/qm9_edge.npz'
if not os.path.exists(file_path): if not os.path.exists(file_path):
download(self._url, path=file_path) download(self._url, path=file_path)
def process(self): def process(self):
self.load() self.load()
def has_cache(self): def has_cache(self):
npz_path = f'{self.raw_dir}/qm9_edge.npz' npz_path = f'{self.raw_dir}/qm9_edge.npz'
return os.path.exists(npz_path) return os.path.exists(npz_path)
def save(self): def save(self):
np.savez_compressed(f'{self.raw_dir}/qm9_edge.npz', np.savez_compressed(f'{self.raw_dir}/qm9_edge.npz',
n_node=self.n_node, n_node=self.n_node,
...@@ -171,7 +177,7 @@ class QM9EdgeDataset(DGLDataset): ...@@ -171,7 +177,7 @@ class QM9EdgeDataset(DGLDataset):
node_pos=self.node_pos, node_pos=self.node_pos,
edge_attr=self.edge_attr, edge_attr=self.edge_attr,
src=self.src, src=self.src,
dst=self.dst, dst=self.dst,
targets=self.targets) targets=self.targets)
def load(self): def load(self):
...@@ -184,52 +190,55 @@ class QM9EdgeDataset(DGLDataset): ...@@ -184,52 +190,55 @@ class QM9EdgeDataset(DGLDataset):
self.node_pos = data_dict['node_pos'] self.node_pos = data_dict['node_pos']
self.edge_attr = data_dict['edge_attr'] self.edge_attr = data_dict['edge_attr']
self.targets = data_dict['targets'] self.targets = data_dict['targets']
self.src = data_dict['src'] self.src = data_dict['src']
self.dst = data_dict['dst'] self.dst = data_dict['dst']
self.n_cumsum = np.concatenate([[0], np.cumsum(self.n_node)]) self.n_cumsum = np.concatenate([[0], np.cumsum(self.n_node)])
self.ne_cumsum = np.concatenate([[0], np.cumsum(self.n_edge)]) self.ne_cumsum = np.concatenate([[0], np.cumsum(self.n_edge)])
def __getitem__(self, idx): def __getitem__(self, idx):
r""" Get graph and label by index r""" Get graph and label by index
Parameters Parameters
---------- ----------
idx : int idx : int
Item index Item index
Returns Returns
------- -------
dgl.DGLGraph dgl.DGLGraph
The graph contains: The graph contains:
- ``ndata['pos']``: the coordinates of each atom - ``ndata['pos']``: the coordinates of each atom
- ``ndata['attr']``: the features of each atom - ``ndata['attr']``: the features of each atom
- ``edata['edge_attr']``: the features of each bond - ``edata['edge_attr']``: the features of each bond
Tensor Tensor
Property values of molecular graphs Property values of molecular graphs
""" """
pos = self.node_pos[self.n_cumsum[idx]:self.n_cumsum[idx+1]] pos = self.node_pos[self.n_cumsum[idx]:self.n_cumsum[idx+1]]
src = self.src[self.ne_cumsum[idx]:self.ne_cumsum[idx+1]] src = self.src[self.ne_cumsum[idx]:self.ne_cumsum[idx+1]]
dst = self.dst[self.ne_cumsum[idx]:self.ne_cumsum[idx+1]] dst = self.dst[self.ne_cumsum[idx]:self.ne_cumsum[idx+1]]
g = dgl_graph((src, dst)) g = dgl_graph((src, dst))
g.ndata['pos'] = F.tensor(pos, dtype=F.data_type_dict['float32']) g.ndata['pos'] = F.tensor(pos, dtype=F.data_type_dict['float32'])
g.ndata['attr'] = F.tensor(self.node_attr[self.n_cumsum[idx]:self.n_cumsum[idx+1]], dtype=F.data_type_dict['float32']) g.ndata['attr'] = F.tensor(self.node_attr[self.n_cumsum[idx]:self.n_cumsum[idx+1]], dtype=F.data_type_dict['float32'])
g.edata['edge_attr'] = F.tensor(self.edge_attr[self.ne_cumsum[idx]:self.ne_cumsum[idx+1]], dtype=F.data_type_dict['float32']) g.edata['edge_attr'] = F.tensor(self.edge_attr[self.ne_cumsum[idx]:self.ne_cumsum[idx+1]], dtype=F.data_type_dict['float32'])
label = F.tensor(self.targets[idx][self.label_keys], dtype=F.data_type_dict['float32']) label = F.tensor(self.targets[idx][self.label_keys], dtype=F.data_type_dict['float32'])
if self._transform is not None:
g = self._transform(g)
return g, label return g, label
def __len__(self): def __len__(self):
r""" Number of graphs in the dataset. r""" Number of graphs in the dataset.
Returns Returns
------- -------
int int
......
...@@ -94,15 +94,20 @@ class RDFGraphDataset(DGLBuiltinDataset): ...@@ -94,15 +94,20 @@ class RDFGraphDataset(DGLBuiltinDataset):
Default: ~/.dgl/ Default: ~/.dgl/
force_reload : bool, optional force_reload : bool, optional
If true, force load and process from raw data. Ignore cached pre-processed data. If true, force load and process from raw data. Ignore cached pre-processed data.
verbose: bool verbose : bool
Whether to print out progress information. Default: True. Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
""" """
def __init__(self, name, url, predict_category, def __init__(self, name, url, predict_category,
print_every=10000, print_every=10000,
insert_reverse=True, insert_reverse=True,
raw_dir=None, raw_dir=None,
force_reload=False, force_reload=False,
verbose=True): verbose=True,
transform=None):
self._insert_reverse = insert_reverse self._insert_reverse = insert_reverse
self._print_every = print_every self._print_every = print_every
self._predict_category = predict_category self._predict_category = predict_category
...@@ -110,7 +115,8 @@ class RDFGraphDataset(DGLBuiltinDataset): ...@@ -110,7 +115,8 @@ class RDFGraphDataset(DGLBuiltinDataset):
super(RDFGraphDataset, self).__init__(name, url, super(RDFGraphDataset, self).__init__(name, url,
raw_dir=raw_dir, raw_dir=raw_dir,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose) verbose=verbose,
transform=transform)
def process(self): def process(self):
raw_tuples = self.load_raw_tuples(self.raw_path) raw_tuples = self.load_raw_tuples(self.raw_path)
...@@ -409,6 +415,8 @@ class RDFGraphDataset(DGLBuiltinDataset): ...@@ -409,6 +415,8 @@ class RDFGraphDataset(DGLBuiltinDataset):
r"""Gets the graph object r"""Gets the graph object
""" """
g = self._hg g = self._hg
if self._transform is not None:
g = self._transform(g)
return g return g
def __len__(self): def __len__(self):
...@@ -523,17 +531,21 @@ class AIFBDataset(RDFGraphDataset): ...@@ -523,17 +531,21 @@ class AIFBDataset(RDFGraphDataset):
Parameters Parameters
----------- -----------
print_every: int print_every : int
Preprocessing log for every X tuples. Default: 10000. Preprocessing log for every X tuples. Default: 10000.
insert_reverse: bool insert_reverse : bool
If true, add reverse edge and reverse relations to the final graph. Default: True. If true, add reverse edge and reverse relations to the final graph. Default: True.
raw_dir : str raw_dir : str
Directory to download the raw data into, or that already contains the input data. Directory to download the raw data into, or that already contains the input data.
Default: ~/.dgl/ Default: ~/.dgl/
force_reload : bool force_reload : bool
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose : bool
Whether to print out progress information. Default: True. Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes Attributes
---------- ----------
...@@ -562,7 +574,8 @@ class AIFBDataset(RDFGraphDataset): ...@@ -562,7 +574,8 @@ class AIFBDataset(RDFGraphDataset):
insert_reverse=True, insert_reverse=True,
raw_dir=None, raw_dir=None,
force_reload=False, force_reload=False,
verbose=True): verbose=True,
transform=None):
import rdflib as rdf import rdflib as rdf
self.employs = rdf.term.URIRef("http://swrc.ontoware.org/ontology#employs") self.employs = rdf.term.URIRef("http://swrc.ontoware.org/ontology#employs")
self.affiliation = rdf.term.URIRef("http://swrc.ontoware.org/ontology#affiliation") self.affiliation = rdf.term.URIRef("http://swrc.ontoware.org/ontology#affiliation")
...@@ -574,7 +587,8 @@ class AIFBDataset(RDFGraphDataset): ...@@ -574,7 +587,8 @@ class AIFBDataset(RDFGraphDataset):
insert_reverse=insert_reverse, insert_reverse=insert_reverse,
raw_dir=raw_dir, raw_dir=raw_dir,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose) verbose=verbose,
transform=transform)
def __getitem__(self, idx): def __getitem__(self, idx):
r"""Gets the graph object r"""Gets the graph object
...@@ -653,17 +667,21 @@ class MUTAGDataset(RDFGraphDataset): ...@@ -653,17 +667,21 @@ class MUTAGDataset(RDFGraphDataset):
Parameters Parameters
----------- -----------
print_every: int print_every : int
Preprocessing log for every X tuples. Default: 10000. Preprocessing log for every X tuples. Default: 10000.
insert_reverse: bool insert_reverse : bool
If true, add reverse edge and reverse relations to the final graph. Default: True. If true, add reverse edge and reverse relations to the final graph. Default: True.
raw_dir : str raw_dir : str
Directory to download the raw data into, or that already contains the input data. Directory to download the raw data into, or that already contains the input data.
Default: ~/.dgl/ Default: ~/.dgl/
force_reload : bool force_reload : bool
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose : bool
Whether to print out progress information. Default: True. Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes Attributes
---------- ----------
...@@ -697,7 +715,8 @@ class MUTAGDataset(RDFGraphDataset): ...@@ -697,7 +715,8 @@ class MUTAGDataset(RDFGraphDataset):
insert_reverse=True, insert_reverse=True,
raw_dir=None, raw_dir=None,
force_reload=False, force_reload=False,
verbose=True): verbose=True,
transform=None):
import rdflib as rdf import rdflib as rdf
self.is_mutagenic = rdf.term.URIRef("http://dl-learner.org/carcinogenesis#isMutagenic") self.is_mutagenic = rdf.term.URIRef("http://dl-learner.org/carcinogenesis#isMutagenic")
self.rdf_type = rdf.term.URIRef("http://www.w3.org/1999/02/22-rdf-syntax-ns#type") self.rdf_type = rdf.term.URIRef("http://www.w3.org/1999/02/22-rdf-syntax-ns#type")
...@@ -712,7 +731,8 @@ class MUTAGDataset(RDFGraphDataset): ...@@ -712,7 +731,8 @@ class MUTAGDataset(RDFGraphDataset):
insert_reverse=insert_reverse, insert_reverse=insert_reverse,
raw_dir=raw_dir, raw_dir=raw_dir,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose) verbose=verbose,
transform=transform)
def __getitem__(self, idx): def __getitem__(self, idx):
r"""Gets the graph object r"""Gets the graph object
...@@ -814,17 +834,21 @@ class BGSDataset(RDFGraphDataset): ...@@ -814,17 +834,21 @@ class BGSDataset(RDFGraphDataset):
Parameters Parameters
----------- -----------
print_every: int print_every : int
Preprocessing log for every X tuples. Default: 10000. Preprocessing log for every X tuples. Default: 10000.
insert_reverse: bool insert_reverse : bool
If true, add reverse edge and reverse relations to the final graph. Default: True. If true, add reverse edge and reverse relations to the final graph. Default: True.
raw_dir : str raw_dir : str
Directory to download the raw data into, or that already contains the input data. Directory to download the raw data into, or that already contains the input data.
Default: ~/.dgl/ Default: ~/.dgl/
force_reload : bool force_reload : bool
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose : bool
Whether to print out progress information. Default: True. Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes Attributes
---------- ----------
...@@ -854,7 +878,8 @@ class BGSDataset(RDFGraphDataset): ...@@ -854,7 +878,8 @@ class BGSDataset(RDFGraphDataset):
insert_reverse=True, insert_reverse=True,
raw_dir=None, raw_dir=None,
force_reload=False, force_reload=False,
verbose=True): verbose=True,
transform=None):
import rdflib as rdf import rdflib as rdf
url = _get_dgl_url('dataset/rdf/bgs-hetero.zip') url = _get_dgl_url('dataset/rdf/bgs-hetero.zip')
name = 'bgs-hetero' name = 'bgs-hetero'
...@@ -865,7 +890,8 @@ class BGSDataset(RDFGraphDataset): ...@@ -865,7 +890,8 @@ class BGSDataset(RDFGraphDataset):
insert_reverse=insert_reverse, insert_reverse=insert_reverse,
raw_dir=raw_dir, raw_dir=raw_dir,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose) verbose=verbose,
transform=transform)
def __getitem__(self, idx): def __getitem__(self, idx):
r"""Gets the graph object r"""Gets the graph object
...@@ -964,17 +990,21 @@ class AMDataset(RDFGraphDataset): ...@@ -964,17 +990,21 @@ class AMDataset(RDFGraphDataset):
Parameters Parameters
----------- -----------
print_every: int print_every : int
Preprocessing log for every X tuples. Default: 10000. Preprocessing log for every X tuples. Default: 10000.
insert_reverse: bool insert_reverse : bool
If true, add reverse edge and reverse relations to the final graph. Default: True. If true, add reverse edge and reverse relations to the final graph. Default: True.
raw_dir : str raw_dir : str
Directory to download the raw data into, or that already contains the input data. Directory to download the raw data into, or that already contains the input data.
Default: ~/.dgl/ Default: ~/.dgl/
force_reload : bool force_reload : bool
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose : bool
Whether to print out progress information. Default: True. Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes Attributes
---------- ----------
...@@ -1003,7 +1033,8 @@ class AMDataset(RDFGraphDataset): ...@@ -1003,7 +1033,8 @@ class AMDataset(RDFGraphDataset):
insert_reverse=True, insert_reverse=True,
raw_dir=None, raw_dir=None,
force_reload=False, force_reload=False,
verbose=True): verbose=True,
transform=None):
import rdflib as rdf import rdflib as rdf
self.objectCategory = rdf.term.URIRef("http://purl.org/collections/nl/am/objectCategory") self.objectCategory = rdf.term.URIRef("http://purl.org/collections/nl/am/objectCategory")
self.material = rdf.term.URIRef("http://purl.org/collections/nl/am/material") self.material = rdf.term.URIRef("http://purl.org/collections/nl/am/material")
...@@ -1015,7 +1046,8 @@ class AMDataset(RDFGraphDataset): ...@@ -1015,7 +1046,8 @@ class AMDataset(RDFGraphDataset):
insert_reverse=insert_reverse, insert_reverse=insert_reverse,
raw_dir=raw_dir, raw_dir=raw_dir,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose) verbose=verbose,
transform=transform)
def __getitem__(self, idx): def __getitem__(self, idx):
r"""Gets the graph object r"""Gets the graph object
......
...@@ -84,8 +84,12 @@ class RedditDataset(DGLBuiltinDataset): ...@@ -84,8 +84,12 @@ class RedditDataset(DGLBuiltinDataset):
Default: ~/.dgl/ Default: ~/.dgl/
force_reload : bool force_reload : bool
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose : bool
Whether to print out progress information. Default: False. Whether to print out progress information. Default: False.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes Attributes
---------- ----------
...@@ -125,7 +129,8 @@ class RedditDataset(DGLBuiltinDataset): ...@@ -125,7 +129,8 @@ class RedditDataset(DGLBuiltinDataset):
>>> >>>
>>> # Train, Validation and Test >>> # Train, Validation and Test
""" """
def __init__(self, self_loop=False, raw_dir=None, force_reload=False, verbose=False): def __init__(self, self_loop=False, raw_dir=None, force_reload=False,
verbose=False, transform=None):
self_loop_str = "" self_loop_str = ""
if self_loop: if self_loop:
self_loop_str = "_self_loop" self_loop_str = "_self_loop"
...@@ -135,7 +140,8 @@ class RedditDataset(DGLBuiltinDataset): ...@@ -135,7 +140,8 @@ class RedditDataset(DGLBuiltinDataset):
url=_url, url=_url,
raw_dir=raw_dir, raw_dir=raw_dir,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose) verbose=verbose,
transform=transform)
def process(self): def process(self):
# graph # graph
...@@ -251,7 +257,10 @@ class RedditDataset(DGLBuiltinDataset): ...@@ -251,7 +257,10 @@ class RedditDataset(DGLBuiltinDataset):
- ``ndata['test_mask']:`` mask for test node set - ``ndata['test_mask']:`` mask for test node set
""" """
assert idx == 0, "Reddit Dataset only has one graph" assert idx == 0, "Reddit Dataset only has one graph"
return self._graph if self._transform is None:
return self._graph
else:
return self._transform(self._graph)
def __len__(self): def __len__(self):
r"""Number of graphs in the dataset""" r"""Number of graphs in the dataset"""
......
...@@ -23,7 +23,7 @@ class SSTDataset(DGLBuiltinDataset): ...@@ -23,7 +23,7 @@ class SSTDataset(DGLBuiltinDataset):
r"""Stanford Sentiment Treebank dataset. r"""Stanford Sentiment Treebank dataset.
.. deprecated:: 0.5.0 .. deprecated:: 0.5.0
- ``trees`` is deprecated, it is replaced by: - ``trees`` is deprecated, it is replaced by:
>>> dataset = SSTDataset() >>> dataset = SSTDataset()
...@@ -63,8 +63,12 @@ class SSTDataset(DGLBuiltinDataset): ...@@ -63,8 +63,12 @@ class SSTDataset(DGLBuiltinDataset):
Default: ~/.dgl/ Default: ~/.dgl/
force_reload : bool force_reload : bool
Whether to reload the dataset. Default: False Whether to reload the dataset. Default: False
verbose: bool verbose : bool
Whether to print out progress information. Default: False. Whether to print out progress information. Default: False.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes Attributes
---------- ----------
...@@ -120,7 +124,8 @@ class SSTDataset(DGLBuiltinDataset): ...@@ -120,7 +124,8 @@ class SSTDataset(DGLBuiltinDataset):
vocab_file=None, vocab_file=None,
raw_dir=None, raw_dir=None,
force_reload=False, force_reload=False,
verbose=False): verbose=False,
transform=None):
assert mode in ['train', 'dev', 'test', 'tiny'] assert mode in ['train', 'dev', 'test', 'tiny']
_url = _get_dgl_url('dataset/sst.zip') _url = _get_dgl_url('dataset/sst.zip')
self._glove_embed_file = glove_embed_file if mode == 'train' else None self._glove_embed_file = glove_embed_file if mode == 'train' else None
...@@ -130,7 +135,8 @@ class SSTDataset(DGLBuiltinDataset): ...@@ -130,7 +135,8 @@ class SSTDataset(DGLBuiltinDataset):
url=_url, url=_url,
raw_dir=raw_dir, raw_dir=raw_dir,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose) verbose=verbose,
transform=transform)
def process(self): def process(self):
from nltk.corpus.reader import BracketParseCorpusReader from nltk.corpus.reader import BracketParseCorpusReader
...@@ -255,7 +261,10 @@ class SSTDataset(DGLBuiltinDataset): ...@@ -255,7 +261,10 @@ class SSTDataset(DGLBuiltinDataset):
- ``ndata['y']:`` label of the node - ``ndata['y']:`` label of the node
- ``ndata['mask']``: 1 if the node is a leaf, otherwise 0 - ``ndata['mask']``: 1 if the node is a leaf, otherwise 0
""" """
return self._trees[idx] if self._transform is None:
return self._trees[idx]
else:
return self._transform(self._trees[idx])
def __len__(self): def __len__(self):
r"""Number of graphs in the dataset.""" r"""Number of graphs in the dataset."""
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment