Unverified commit ab812179, authored by Hongzhi (Steve) Chen, committed by GitHub.

[Misc] Auto reformat data/. (#5318)



* data

* lintrunner

---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent f0759a96
@@ -5,7 +5,7 @@ import numpy as np
from .. import backend as F
from ..base import DGLError
from .dgl_dataset import DGLDataset
from .utils import load_graphs, save_graphs, Subset


class CSVDataset(DGLDataset):
...
@@ -8,7 +8,7 @@ import pydantic as dt
import yaml

from .. import backend as F
from ..base import dgl_warning, DGLError
from ..convert import heterograph as dgl_heterograph
...
@@ -99,7 +99,6 @@ class GINDataset(DGLBuiltinDataset):
        verbose=False,
        transform=None,
    ):
        self._name = name  # MUTAG
        gin_url = "https://raw.githubusercontent.com/weihua916/powerful-gnns/master/dataset.zip"
        self.ds_name = "nig"
...
"""GNN Benchmark datasets for node classification.""" """GNN Benchmark datasets for node classification."""
import scipy.sparse as sp
import numpy as np
import os import os
from .dgl_dataset import DGLBuiltinDataset import numpy as np
from .utils import save_graphs, load_graphs, _get_dgl_url, deprecate_property, deprecate_class import scipy.sparse as sp
from .. import backend as F, transforms
from ..convert import graph as dgl_graph from ..convert import graph as dgl_graph
from .. import backend as F
from .. import transforms
__all__ = ["AmazonCoBuyComputerDataset", "AmazonCoBuyPhotoDataset", "CoauthorPhysicsDataset", "CoauthorCSDataset", from .dgl_dataset import DGLBuiltinDataset
"CoraFullDataset", "AmazonCoBuy", "Coauthor", "CoraFull"] from .utils import (
_get_dgl_url,
deprecate_class,
deprecate_property,
load_graphs,
save_graphs,
)
__all__ = [
"AmazonCoBuyComputerDataset",
"AmazonCoBuyPhotoDataset",
"CoauthorPhysicsDataset",
"CoauthorCSDataset",
"CoraFullDataset",
"AmazonCoBuy",
"Coauthor",
"CoraFull",
]
def eliminate_self_loops(A):
@@ -27,36 +42,50 @@ class GNNBenchmarkDataset(DGLBuiltinDataset):
    Reference: https://github.com/shchur/gnn-benchmark#datasets
    """

    def __init__(
        self,
        name,
        raw_dir=None,
        force_reload=False,
        verbose=False,
        transform=None,
    ):
        _url = _get_dgl_url("dataset/" + name + ".zip")
        super(GNNBenchmarkDataset, self).__init__(
            name=name,
            url=_url,
            raw_dir=raw_dir,
            force_reload=force_reload,
            verbose=verbose,
            transform=transform,
        )

    def process(self):
        npz_path = os.path.join(self.raw_path, self.name + ".npz")
        g = self._load_npz(npz_path)
        g = transforms.reorder_graph(
            g,
            node_permute_algo="rcmk",
            edge_permute_algo="dst",
            store_ids=False,
        )
        self._graph = g
        self._data = [g]
        self._print_info()

    def has_cache(self):
        graph_path = os.path.join(self.save_path, "dgl_graph_v1.bin")
        if os.path.exists(graph_path):
            return True
        return False

    def save(self):
        graph_path = os.path.join(self.save_path, "dgl_graph_v1.bin")
        save_graphs(graph_path, self._graph)

    def load(self):
        graph_path = os.path.join(self.save_path, "dgl_graph_v1.bin")
        graphs, _ = load_graphs(graph_path)
        self._graph = graphs[0]
        self._data = [graphs[0]]
@@ -64,41 +93,59 @@ class GNNBenchmarkDataset(DGLBuiltinDataset):
    def _print_info(self):
        if self.verbose:
            print("  NumNodes: {}".format(self._graph.number_of_nodes()))
            print("  NumEdges: {}".format(self._graph.number_of_edges()))
            print("  NumFeats: {}".format(self._graph.ndata["feat"].shape[-1]))
            print("  NumClasses: {}".format(self.num_classes))

    def _load_npz(self, file_name):
        with np.load(file_name, allow_pickle=True) as loader:
            loader = dict(loader)
            num_nodes = loader["adj_shape"][0]
            adj_matrix = sp.csr_matrix(
                (
                    loader["adj_data"],
                    loader["adj_indices"],
                    loader["adj_indptr"],
                ),
                shape=loader["adj_shape"],
            ).tocoo()

            if "attr_data" in loader:
                # Attributes are stored as a sparse CSR matrix
                attr_matrix = sp.csr_matrix(
                    (
                        loader["attr_data"],
                        loader["attr_indices"],
                        loader["attr_indptr"],
                    ),
                    shape=loader["attr_shape"],
                ).todense()
            elif "attr_matrix" in loader:
                # Attributes are stored as a (dense) np.ndarray
                attr_matrix = loader["attr_matrix"]
            else:
                attr_matrix = None

            if "labels_data" in loader:
                # Labels are stored as a CSR matrix
                labels = sp.csr_matrix(
                    (
                        loader["labels_data"],
                        loader["labels_indices"],
                        loader["labels_indptr"],
                    ),
                    shape=loader["labels_shape"],
                ).todense()
            elif "labels" in loader:
                # Labels are stored as a numpy array
                labels = loader["labels"]
            else:
                labels = None

        g = dgl_graph((adj_matrix.row, adj_matrix.col))
        g = transforms.to_bidirected(g)
        g.ndata["feat"] = F.tensor(attr_matrix, F.data_type_dict["float32"])
        g.ndata["label"] = F.tensor(labels, F.data_type_dict["int64"])
        return g
    @property
@@ -107,7 +154,7 @@ class GNNBenchmarkDataset(DGLBuiltinDataset):
        raise NotImplementedError

    def __getitem__(self, idx):
        r"""Get graph by index

        Parameters
        ----------
@@ -176,12 +223,17 @@ class CoraFullDataset(GNNBenchmarkDataset):
    >>> feat = g.ndata['feat']  # get node feature
    >>> label = g.ndata['label']  # get node labels
    """

    def __init__(
        self, raw_dir=None, force_reload=False, verbose=False, transform=None
    ):
        super(CoraFullDataset, self).__init__(
            name="cora_full",
            raw_dir=raw_dir,
            force_reload=force_reload,
            verbose=verbose,
            transform=transform,
        )

    @property
    def num_classes(self):
@@ -195,7 +247,7 @@ class CoraFullDataset(GNNBenchmarkDataset):
class CoauthorCSDataset(GNNBenchmarkDataset):
    r"""'Computer Science (CS)' part of the Coauthor dataset for node classification task.

    Coauthor CS and Coauthor Physics are co-authorship graphs based on the Microsoft Academic Graph
    from the KDD Cup 2016 challenge. Here, nodes are authors that are connected by an edge if they
@@ -239,12 +291,17 @@ class CoauthorCSDataset(GNNBenchmarkDataset):
    >>> feat = g.ndata['feat']  # get node feature
    >>> label = g.ndata['label']  # get node labels
    """

    def __init__(
        self, raw_dir=None, force_reload=False, verbose=False, transform=None
    ):
        super(CoauthorCSDataset, self).__init__(
            name="coauthor_cs",
            raw_dir=raw_dir,
            force_reload=force_reload,
            verbose=verbose,
            transform=transform,
        )

    @property
    def num_classes(self):
class CoauthorPhysicsDataset(GNNBenchmarkDataset): class CoauthorPhysicsDataset(GNNBenchmarkDataset):
r""" 'Physics' part of the Coauthor dataset for node classification task. r"""'Physics' part of the Coauthor dataset for node classification task.
Coauthor CS and Coauthor Physics are co-authorship graphs based on the Microsoft Academic Graph Coauthor CS and Coauthor Physics are co-authorship graphs based on the Microsoft Academic Graph
from the KDD Cup 2016 challenge. Here, nodes are authors, that are connected by an edge if they from the KDD Cup 2016 challenge. Here, nodes are authors, that are connected by an edge if they
...@@ -302,12 +359,17 @@ class CoauthorPhysicsDataset(GNNBenchmarkDataset): ...@@ -302,12 +359,17 @@ class CoauthorPhysicsDataset(GNNBenchmarkDataset):
>>> feat = g.ndata['feat'] # get node feature >>> feat = g.ndata['feat'] # get node feature
>>> label = g.ndata['label'] # get node labels >>> label = g.ndata['label'] # get node labels
""" """
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
super(CoauthorPhysicsDataset, self).__init__(name='coauthor_physics', def __init__(
raw_dir=raw_dir, self, raw_dir=None, force_reload=False, verbose=False, transform=None
force_reload=force_reload, ):
verbose=verbose, super(CoauthorPhysicsDataset, self).__init__(
transform=transform) name="coauthor_physics",
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
@property @property
def num_classes(self): def num_classes(self):
...@@ -321,7 +383,7 @@ class CoauthorPhysicsDataset(GNNBenchmarkDataset): ...@@ -321,7 +383,7 @@ class CoauthorPhysicsDataset(GNNBenchmarkDataset):
class AmazonCoBuyComputerDataset(GNNBenchmarkDataset): class AmazonCoBuyComputerDataset(GNNBenchmarkDataset):
r""" 'Computer' part of the AmazonCoBuy dataset for node classification task. r"""'Computer' part of the AmazonCoBuy dataset for node classification task.
Amazon Computers and Amazon Photo are segments of the Amazon co-purchase graph [McAuley et al., 2015], Amazon Computers and Amazon Photo are segments of the Amazon co-purchase graph [McAuley et al., 2015],
where nodes represent goods, edges indicate that two goods are frequently bought together, node where nodes represent goods, edges indicate that two goods are frequently bought together, node
...@@ -364,12 +426,17 @@ class AmazonCoBuyComputerDataset(GNNBenchmarkDataset): ...@@ -364,12 +426,17 @@ class AmazonCoBuyComputerDataset(GNNBenchmarkDataset):
>>> feat = g.ndata['feat'] # get node feature >>> feat = g.ndata['feat'] # get node feature
>>> label = g.ndata['label'] # get node labels >>> label = g.ndata['label'] # get node labels
""" """
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
super(AmazonCoBuyComputerDataset, self).__init__(name='amazon_co_buy_computer', def __init__(
raw_dir=raw_dir, self, raw_dir=None, force_reload=False, verbose=False, transform=None
force_reload=force_reload, ):
verbose=verbose, super(AmazonCoBuyComputerDataset, self).__init__(
transform=transform) name="amazon_co_buy_computer",
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
@property @property
def num_classes(self): def num_classes(self):
@@ -426,12 +493,17 @@ class AmazonCoBuyPhotoDataset(GNNBenchmarkDataset):
    >>> feat = g.ndata['feat']  # get node feature
    >>> label = g.ndata['label']  # get node labels
    """

    def __init__(
        self, raw_dir=None, force_reload=False, verbose=False, transform=None
    ):
        super(AmazonCoBuyPhotoDataset, self).__init__(
            name="amazon_co_buy_photo",
            raw_dir=raw_dir,
            force_reload=force_reload,
            verbose=verbose,
            transform=transform,
        )

    @property
    def num_classes(self):
@@ -446,27 +518,27 @@ class AmazonCoBuyPhotoDataset(GNNBenchmarkDataset):
class CoraFull(CoraFullDataset):
    def __init__(self, **kwargs):
        deprecate_class("CoraFull", "CoraFullDataset")
        super(CoraFull, self).__init__(**kwargs)


def AmazonCoBuy(name):
    if name == "computers":
        deprecate_class("AmazonCoBuy", "AmazonCoBuyComputerDataset")
        return AmazonCoBuyComputerDataset()
    elif name == "photo":
        deprecate_class("AmazonCoBuy", "AmazonCoBuyPhotoDataset")
        return AmazonCoBuyPhotoDataset()
    else:
        raise ValueError('Dataset name should be "computers" or "photo".')


def Coauthor(name):
    if name == "cs":
        deprecate_class("Coauthor", "CoauthorCSDataset")
        return CoauthorCSDataset()
    elif name == "physics":
        deprecate_class("Coauthor", "CoauthorPhysicsDataset")
        return CoauthorPhysicsDataset()
    else:
        raise ValueError('Dataset name should be "cs" or "physics".')
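
For orientation, a minimal usage sketch of the dataset classes reformatted above (a sketch only, assuming a working DGL install; the access pattern mirrors the docstring examples and is not part of the commit itself):

    from dgl.data import CoauthorCSDataset

    # First access downloads and processes the raw .npz archive, then caches
    # a binary graph that has_cache()/load() reuse on later runs.
    dataset = CoauthorCSDataset()
    g = dataset[0]                 # each GNNBenchmarkDataset holds one graph
    feat = g.ndata["feat"]         # node feature matrix
    label = g.ndata["label"]       # per-node class labels
    print(g.number_of_nodes(), g.number_of_edges(), dataset.num_classes)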
"""For Graph Serialization""" """For Graph Serialization"""
from __future__ import absolute_import from __future__ import absolute_import
import os import os
from .. import backend as F from .. import backend as F
from .._ffi.function import _init_api from .._ffi.function import _init_api
from .._ffi.object import ObjectBase, register_object from .._ffi.object import ObjectBase, register_object
from ..base import DGLError, dgl_warning from ..base import dgl_warning, DGLError
from ..heterograph import DGLGraph from ..heterograph import DGLGraph
from .heterograph_serialize import save_heterographs from .heterograph_serialize import save_heterographs
_init_api("dgl.data.graph_serialize") _init_api("dgl.data.graph_serialize")
__all__ = ["save_graphs", "load_graphs", "load_labels"] __all__ = ["save_graphs", "load_graphs", "load_labels"]
@register_object("graph_serialize.StorageMetaData") @register_object("graph_serialize.StorageMetaData")
class StorageMetaData(ObjectBase): class StorageMetaData(ObjectBase):
"""StorageMetaData Object """StorageMetaData Object
attributes available: attributes available:
num_graph [int]: return numbers of graphs num_graph [int]: return numbers of graphs
nodes_num_list Value of NDArray: return number of nodes for each graph nodes_num_list Value of NDArray: return number of nodes for each graph
edges_num_list Value of NDArray: return number of edges for each graph edges_num_list Value of NDArray: return number of edges for each graph
labels [dict of backend tensors]: return dict of labels labels [dict of backend tensors]: return dict of labels
graph_data [list of GraphData]: return list of GraphData Object graph_data [list of GraphData]: return list of GraphData Object
""" """
def is_local_path(filepath):
    return not (
        filepath.startswith("hdfs://")
        or filepath.startswith("viewfs://")
        or filepath.startswith("s3://")
    )


def check_local_file_exists(filename):
    if is_local_path(filename) and not os.path.exists(filename):
        raise DGLError("File {} does not exist.".format(filename))
@register_object("graph_serialize.GraphData")
class GraphData(ObjectBase):
    """GraphData Object"""

    @staticmethod
    def create(g):
        """Create GraphData"""
        # TODO(zihao): support serialize batched graph in the future.
        assert (
            g.batch_size == 1
        ), "Batched DGLGraph is not supported for serialization"
        ghandle = g._graph
        if len(g.ndata) != 0:
            node_tensors = dict()
            for key, value in g.ndata.items():
                node_tensors[key] = F.zerocopy_to_dgl_ndarray(value)
        else:
            node_tensors = None

        if len(g.edata) != 0:
            edge_tensors = dict()
            for key, value in g.edata.items():
                edge_tensors[key] = F.zerocopy_to_dgl_ndarray(value)
        else:
            edge_tensors = None

        return _CAPI_MakeGraphData(ghandle, node_tensors, edge_tensors)

    def get_graph(self):
        """Get DGLGraph from GraphData"""
        ghandle = _CAPI_GDataGraphHandle(self)
        hgi = _CAPI_DGLAsHeteroGraph(ghandle)
        g = DGLGraph(hgi, ["_U"], ["_E"])
        node_tensors_items = _CAPI_GDataNodeTensors(self).items()
        edge_tensors_items = _CAPI_GDataEdgeTensors(self).items()
        for k, v in node_tensors_items:
            g.ndata[k] = F.zerocopy_from_dgl_ndarray(v)
        for k, v in edge_tensors_items:
            g.edata[k] = F.zerocopy_from_dgl_ndarray(v)
        return g
def save_graphs(filename, g_list, labels=None, formats=None):
    r"""Save graphs and optionally their labels to file.

    Besides saving to local files, DGL supports writing the graphs directly
    to S3 (by providing a ``"s3://..."`` path) or to HDFS (by providing
    an ``"hdfs://..."`` path).

    The function saves both the graph structure and node/edge features to file
    in DGL's own binary format. For graph-level features, pass them via
    the :attr:`labels` argument.

    Parameters
    ----------
    filename : str
        The file name to store the graphs and labels.
    g_list: list
        The graphs to be saved.
    labels: dict[str, Tensor]
        labels should be dict of tensors, with str as keys
    formats: str or list[str]
        Save graph in specified formats. It could be any combination of
        ``coo``, ``csc`` and ``csr``. If not specified, save one format
        only according to what format is available. If multiple formats
        are available, selection priority from high to low is ``coo``,
        ``csc``, ``csr``.

    Examples
    --------

    >>> import dgl
    >>> import torch as th

    Create :class:`DGLGraph` objects and initialize node
    and edge features.

    >>> g1 = dgl.graph(([0, 1, 2], [1, 2, 3]))
    >>> g2 = dgl.graph(([0, 2], [2, 3]))
    >>> g2.edata["e"] = th.ones(2, 4)

    Save Graphs into file

    >>> from dgl.data.utils import save_graphs
    >>> graph_labels = {"glabel": th.tensor([0, 1])}
    >>> save_graphs("./data.bin", [g1, g2], graph_labels)

    See Also
    --------
    load_graphs
    """
    # if it is local file, do some sanity check
    if is_local_path(filename):
        if os.path.isdir(filename):
            raise DGLError(
                "Filename {} is an existing directory.".format(filename)
            )
        f_path = os.path.dirname(filename)
        if f_path and not os.path.exists(f_path):
            os.makedirs(f_path)
    g_sample = g_list[0] if isinstance(g_list, list) else g_list
    if type(g_sample) == DGLGraph:  # Doesn't support DGLGraph's derived class
        save_heterographs(filename, g_list, labels, formats)
    else:
        raise DGLError(
            "Invalid argument g_list. Must be a DGLGraph or a list of DGLGraphs."
        )


def load_graphs(filename, idx_list=None):
    """Load graphs and optionally their labels from file saved by :func:`save_graphs`.

    Besides loading from local files, DGL supports loading the graphs directly
    from S3 (by providing a ``"s3://..."`` path) or from HDFS (by providing
    an ``"hdfs://..."`` path).

    Parameters
    ----------
    filename: str
        The file name to load graphs from.
    idx_list: list[int], optional
        The indices of the graphs to be loaded if the file contains multiple graphs.
        Default is loading all the graphs stored in the file.

    Returns
    -------
    graph_list: list[DGLGraph]
        The loaded graphs.
    labels: dict[str, Tensor]
        The graph labels stored in file. If no label is stored, the dictionary is empty.
        Regardless of whether the ``idx_list`` argument is given or not,
        the returned dictionary always contains the labels of all the graphs.

    Examples
    --------
    Following the example in :func:`save_graphs`.

    >>> from dgl.data.utils import load_graphs
    >>> glist, label_dict = load_graphs("./data.bin")  # glist will be [g1, g2]
    >>> glist, label_dict = load_graphs("./data.bin", [0])  # glist will be [g1]

    See Also
    --------
    save_graphs
    """
    # if it is local file, do some sanity check
    check_local_file_exists(filename)
    version = _CAPI_GetFileVersion(filename)
    if version == 1:
        dgl_warning(
            "You are loading a graph file saved by old version of dgl.  \
            Please consider saving it again with the current format."
        )
        return load_graph_v1(filename, idx_list)
    elif version == 2:
        return load_graph_v2(filename, idx_list)
    else:
        raise DGLError("Invalid DGL Version Number.")
def load_graph_v2(filename, idx_list=None):
    """Internal functions for loading DGLGraphs."""
    if idx_list is None:
        idx_list = []
    assert isinstance(idx_list, list)
    heterograph_list = _CAPI_LoadGraphFiles_V2(filename, idx_list)
    label_dict = load_labels_v2(filename)
    return [gdata.get_graph() for gdata in heterograph_list], label_dict
def load_graph_v1(filename, idx_list=None):
    """Internal functions for loading DGLGraphs (V0)."""
    if idx_list is None:
        idx_list = []
    assert isinstance(idx_list, list)
    metadata = _CAPI_LoadGraphFiles_V1(filename, idx_list, False)
    label_dict = {}
    for k, v in metadata.labels.items():
        label_dict[k] = F.zerocopy_from_dgl_ndarray(v)
    return [gdata.get_graph() for gdata in metadata.graph_data], label_dict
def load_labels(filename):
    """
    Load label dict from file

    Parameters
    ----------
    filename: str
        filename to load DGLGraphs

    Returns
    -------
    labels: dict
        dict of labels stored in file (empty dict returned if no
        label stored)

    Examples
    --------
    Following the example in save_graphs.

    >>> from dgl.data.utils import load_labels
    >>> label_dict = load_labels("./data.bin")

    """
    # if it is local file, do some sanity check
    check_local_file_exists(filename)

    version = _CAPI_GetFileVersion(filename)
    if version == 1:
        return load_labels_v1(filename)
    elif version == 2:
        return load_labels_v2(filename)
    else:
        raise Exception("Invalid DGL Version Number")
def load_labels_v2(filename):
    """Internal functions for loading labels from V2 format"""
    label_dict = {}
    nd_dict = _CAPI_LoadLabels_V2(filename)
    for k, v in nd_dict.items():
        label_dict[k] = F.zerocopy_from_dgl_ndarray(v)
    return label_dict
def load_labels_v1(filename):
    """Internal functions for loading labels from V1 format"""
    metadata = _CAPI_LoadGraphFiles_V1(filename, [], True)
    label_dict = {}
    for k, v in metadata.labels.items():
        label_dict[k] = F.zerocopy_from_dgl_ndarray(v)
    return label_dict
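
A short round-trip sketch of the serialization API above (illustrative path and tensors; the calls follow the docstrings of save_graphs, load_graphs and load_labels):

    import dgl
    import torch as th
    from dgl.data.utils import load_graphs, load_labels, save_graphs

    g1 = dgl.graph(([0, 1, 2], [1, 2, 3]))
    g2 = dgl.graph(([0, 2], [2, 3]))
    g2.edata["e"] = th.ones(2, 4)

    # Save both graphs together with a graph-level label tensor.
    save_graphs("./data.bin", [g1, g2], {"glabel": th.tensor([0, 1])})

    # Load everything back; label_dict always covers all stored graphs.
    glist, label_dict = load_graphs("./data.bin")
    # Load a subset by index.
    (first,), _ = load_graphs("./data.bin", [0])
    # Load the labels alone, without materializing the graphs.
    labels_only = load_labels("./data.bin")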
@@ -37,6 +37,7 @@ def save_heterographs(filename, g_list, labels, formats):
        filename, gdata_list, tensor_dict_to_ndarray_dict(labels), formats
    )


@register_object("heterograph_serialize.HeteroGraphData")
class HeteroGraphData(ObjectBase):
    """Object to hold the data to be stored for DGLGraph"""
...
"""KarateClub Dataset """KarateClub Dataset
""" """
import numpy as np
import networkx as nx import networkx as nx
import numpy as np
from .. import backend as F from .. import backend as F
from ..convert import from_networkx
from .dgl_dataset import DGLDataset from .dgl_dataset import DGLDataset
from .utils import deprecate_property from .utils import deprecate_property
from ..convert import from_networkx
__all__ = ['KarateClubDataset', 'KarateClub'] __all__ = ["KarateClubDataset", "KarateClub"]
class KarateClubDataset(DGLDataset): class KarateClubDataset(DGLDataset):
r""" Karate Club dataset for Node Classification r"""Karate Club dataset for Node Classification
Zachary's karate club is a social network of a university Zachary's karate club is a social network of a university
karate club, described in the paper "An Information Flow karate club, described in the paper "An Information Flow
...@@ -46,16 +46,20 @@ class KarateClubDataset(DGLDataset): ...@@ -46,16 +46,20 @@ class KarateClubDataset(DGLDataset):
>>> g = dataset[0] >>> g = dataset[0]
>>> labels = g.ndata['label'] >>> labels = g.ndata['label']
""" """
def __init__(self, transform=None): def __init__(self, transform=None):
super(KarateClubDataset, self).__init__(name='karate_club', transform=transform) super(KarateClubDataset, self).__init__(
name="karate_club", transform=transform
)
def process(self): def process(self):
kc_graph = nx.karate_club_graph() kc_graph = nx.karate_club_graph()
label = np.asarray( label = np.asarray(
[kc_graph.nodes[i]['club'] != 'Mr. Hi' for i in kc_graph.nodes]).astype(np.int64) [kc_graph.nodes[i]["club"] != "Mr. Hi" for i in kc_graph.nodes]
).astype(np.int64)
label = F.tensor(label) label = F.tensor(label)
g = from_networkx(kc_graph) g = from_networkx(kc_graph)
g.ndata['label'] = label g.ndata["label"] = label
self._graph = g self._graph = g
self._data = [g] self._data = [g]
...@@ -65,7 +69,7 @@ class KarateClubDataset(DGLDataset): ...@@ -65,7 +69,7 @@ class KarateClubDataset(DGLDataset):
return 2 return 2
def __getitem__(self, idx): def __getitem__(self, idx):
r""" Get graph object r"""Get graph object
Parameters Parameters
---------- ----------
......
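
A small sketch of the reformatted KarateClubDataset in use (mirrors the docstring example above; num_classes is fixed at 2 in the class):

    from dgl.data import KarateClubDataset

    dataset = KarateClubDataset()   # process() builds the graph via networkx
    g = dataset[0]
    labels = g.ndata["label"]       # 0/1 club membership per node
    print(g.number_of_nodes(), dataset.num_classes)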
from __future__ import absolute_import

import os, sys
import pickle as pkl

import networkx as nx
import numpy as np
import scipy.sparse as sp

from .. import backend as F
from ..convert import graph as dgl_graph
from ..utils import retry_method_with_fix
from .dgl_dataset import DGLBuiltinDataset
from .utils import (
    _get_dgl_url,
    deprecate_function,
    deprecate_property,
    download,
    extract_archive,
    generate_mask_tensor,
    get_download_dir,
    load_graphs,
    load_info,
    makedirs,
    save_graphs,
    save_info,
)
class KnowledgeGraphDataset(DGLBuiltinDataset):
    """KnowledgeGraph link prediction dataset
@@ -41,22 +55,31 @@ class KnowledgeGraphDataset(DGLBuiltinDataset):
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access.
    """

    def __init__(
        self,
        name,
        reverse=True,
        raw_dir=None,
        force_reload=False,
        verbose=True,
        transform=None,
    ):
        self._name = name
        self.reverse = reverse
        url = _get_dgl_url("dataset/") + "{}.tgz".format(name)
        super(KnowledgeGraphDataset, self).__init__(
            name,
            url=url,
            raw_dir=raw_dir,
            force_reload=force_reload,
            verbose=verbose,
            transform=transform,
        )

    def download(self):
        r"""Automatically download data and extract it."""
        tgz_path = os.path.join(self.raw_dir, self.name + ".tgz")
        download(self.url, path=tgz_path)
        extract_archive(tgz_path, self.raw_path)
@@ -66,16 +89,22 @@ class KnowledgeGraphDataset(DGLBuiltinDataset):
        This function will parse these triplets and build the DGLGraph.
        """
        root_path = self.raw_path
        entity_path = os.path.join(root_path, "entities.dict")
        relation_path = os.path.join(root_path, "relations.dict")
        train_path = os.path.join(root_path, "train.txt")
        valid_path = os.path.join(root_path, "valid.txt")
        test_path = os.path.join(root_path, "test.txt")
        entity_dict = _read_dictionary(entity_path)
        relation_dict = _read_dictionary(relation_path)
        train = np.asarray(
            _read_triplets_as_list(train_path, entity_dict, relation_dict)
        )
        valid = np.asarray(
            _read_triplets_as_list(valid_path, entity_dict, relation_dict)
        )
        test = np.asarray(
            _read_triplets_as_list(test_path, entity_dict, relation_dict)
        )
        num_nodes = len(entity_dict)
        num_rels = len(relation_dict)
        if self.verbose:
@@ -93,25 +122,33 @@ class KnowledgeGraphDataset(DGLBuiltinDataset):
        self._num_nodes = num_nodes
        self._num_rels = num_rels
        # build graph
        g, data = build_knowledge_graph(
            num_nodes, num_rels, train, valid, test, reverse=self.reverse
        )
        (
            etype,
            ntype,
            train_edge_mask,
            valid_edge_mask,
            test_edge_mask,
            train_mask,
            val_mask,
            test_mask,
        ) = data
        g.edata["train_edge_mask"] = train_edge_mask
        g.edata["valid_edge_mask"] = valid_edge_mask
        g.edata["test_edge_mask"] = test_edge_mask
        g.edata["train_mask"] = train_mask
        g.edata["val_mask"] = val_mask
        g.edata["test_mask"] = test_mask
        g.edata["etype"] = etype
        g.ndata["ntype"] = ntype
        self._g = g
    def has_cache(self):
        graph_path = os.path.join(self.save_path, self.save_name + ".bin")
        info_path = os.path.join(self.save_path, self.save_name + ".pkl")
        if os.path.exists(graph_path) and os.path.exists(info_path):
            return True
        return False
@@ -128,49 +165,65 @@ class KnowledgeGraphDataset(DGLBuiltinDataset):
    def save(self):
        """save the graph list and the labels"""
        graph_path = os.path.join(self.save_path, self.save_name + ".bin")
        info_path = os.path.join(self.save_path, self.save_name + ".pkl")
        save_graphs(str(graph_path), self._g)
        save_info(
            str(info_path),
            {"num_nodes": self.num_nodes, "num_rels": self.num_rels},
        )

    def load(self):
        graph_path = os.path.join(self.save_path, self.save_name + ".bin")
        info_path = os.path.join(self.save_path, self.save_name + ".pkl")
        graphs, _ = load_graphs(str(graph_path))
        info = load_info(str(info_path))
        self._num_nodes = info["num_nodes"]
        self._num_rels = info["num_rels"]
        self._g = graphs[0]
        train_mask = self._g.edata["train_edge_mask"].numpy()
        val_mask = self._g.edata["valid_edge_mask"].numpy()
        test_mask = self._g.edata["test_edge_mask"].numpy()
        # convert mask tensor into bool tensor if possible
        self._g.edata["train_edge_mask"] = generate_mask_tensor(
            self._g.edata["train_edge_mask"].numpy()
        )
        self._g.edata["valid_edge_mask"] = generate_mask_tensor(
            self._g.edata["valid_edge_mask"].numpy()
        )
        self._g.edata["test_edge_mask"] = generate_mask_tensor(
            self._g.edata["test_edge_mask"].numpy()
        )
        self._g.edata["train_mask"] = generate_mask_tensor(
            self._g.edata["train_mask"].numpy()
        )
        self._g.edata["val_mask"] = generate_mask_tensor(
            self._g.edata["val_mask"].numpy()
        )
        self._g.edata["test_mask"] = generate_mask_tensor(
            self._g.edata["test_mask"].numpy()
        )
        # for compatibility (with 0.4.x) generate train_idx, valid_idx and test_idx
        etype = self._g.edata["etype"].numpy()
        self._etype = etype
        u, v = self._g.all_edges(form="uv")
        u = u.numpy()
        v = v.numpy()
        train_idx = np.nonzero(train_mask == 1)
        self._train = np.column_stack(
            (u[train_idx], etype[train_idx], v[train_idx])
        )
        valid_idx = np.nonzero(val_mask == 1)
        self._valid = np.column_stack(
            (u[valid_idx], etype[valid_idx], v[valid_idx])
        )
        test_idx = np.nonzero(test_mask == 1)
        self._test = np.column_stack(
            (u[test_idx], etype[test_idx], v[test_idx])
        )

        if self.verbose:
            print("# entities: {}".format(self.num_nodes))
@@ -189,22 +242,25 @@ class KnowledgeGraphDataset(DGLBuiltinDataset):
    @property
    def save_name(self):
        return self.name + "_dgl_graph"
def _read_dictionary(filename):
    d = {}
    with open(filename, "r+") as f:
        for line in f:
            line = line.strip().split("\t")
            d[line[1]] = int(line[0])
    return d


def _read_triplets(filename):
    with open(filename, "r+") as f:
        for line in f:
            processed_line = line.strip().split("\t")
            yield processed_line


def _read_triplets_as_list(filename, entity_dict, relation_dict):
    l = []
    for triplet in _read_triplets(filename):
@@ -214,9 +270,11 @@ def _read_triplets_as_list(filename, entity_dict, relation_dict):
        l.append([s, r, o])
    return l
def build_knowledge_graph(
    num_nodes, num_rels, train, valid, test, reverse=True
):
    """Create a DGL Homogeneous graph with heterograph info stored as node or edge features."""
    src = []
    rel = []
    dst = []
@@ -258,17 +316,17 @@ def build_knowledge_graph(num_nodes, num_rels, train, valid, test, reverse=True)
    for edge in train:
        s, r, d = edge
        assert r < num_rels
        add_edge(s, r, d, reverse, 1)  # train set

    for edge in valid:
        s, r, d = edge
        assert r < num_rels
        add_edge(s, r, d, reverse, 2)  # valid set

    for edge in test:
        s, r, d = edge
        assert r < num_rels
        add_edge(s, r, d, reverse, 3)  # test set

    subg = []
    fg_s = []
@@ -315,16 +373,40 @@ def build_knowledge_graph(num_nodes, num_rels, train, valid, test, reverse=True)
    g = dgl_graph((s, d), num_nodes=num_nodes)
    etype = np.concatenate(fg_etype)
    settype = np.concatenate(fg_settype)
    etype = F.tensor(etype, dtype=F.data_type_dict["int64"])
    train_edge_mask = train_edge_mask
    valid_edge_mask = valid_edge_mask
    test_edge_mask = test_edge_mask
    train_mask = (
        generate_mask_tensor(settype == 1)
        if reverse is True
        else train_edge_mask
    )
    valid_mask = (
        generate_mask_tensor(settype == 2)
        if reverse is True
        else valid_edge_mask
    )
    test_mask = (
        generate_mask_tensor(settype == 3)
        if reverse is True
        else test_edge_mask
    )
    ntype = F.full_1d(
        num_nodes, 0, dtype=F.data_type_dict["int64"], ctx=F.cpu()
    )
    return g, (
        etype,
        ntype,
        train_edge_mask,
        valid_edge_mask,
        test_edge_mask,
        train_mask,
        valid_mask,
        test_mask,
    )
class FB15k237Dataset(KnowledgeGraphDataset):
    r"""FB15k237 link prediction dataset.
@@ -396,11 +478,19 @@ class FB15k237Dataset(KnowledgeGraphDataset):
    >>>
    >>> # Train, Validation and Test
    """

    def __init__(
        self,
        reverse=True,
        raw_dir=None,
        force_reload=False,
        verbose=True,
        transform=None,
    ):
        name = "FB15k-237"
        super(FB15k237Dataset, self).__init__(
            name, reverse, raw_dir, force_reload, verbose, transform
        )

    def __getitem__(self, idx):
        r"""Gets the graph object
@@ -431,6 +521,7 @@ class FB15k237Dataset(KnowledgeGraphDataset):
        r"""The number of graphs in the dataset."""
        return super(FB15k237Dataset, self).__len__()
class FB15kDataset(KnowledgeGraphDataset):
    r"""FB15k link prediction dataset.
@@ -504,11 +595,19 @@ class FB15kDataset(KnowledgeGraphDataset):
    >>> # Train, Validation and Test
    >>>
    """

    def __init__(
        self,
        reverse=True,
        raw_dir=None,
        force_reload=False,
        verbose=True,
        transform=None,
    ):
        name = "FB15k"
        super(FB15kDataset, self).__init__(
            name, reverse, raw_dir, force_reload, verbose, transform
        )

    def __getitem__(self, idx):
        r"""Gets the graph object
@@ -539,8 +638,9 @@ class FB15kDataset(KnowledgeGraphDataset):
        r"""The number of graphs in the dataset."""
        return super(FB15kDataset, self).__len__()
class WN18Dataset(KnowledgeGraphDataset):
    r"""WN18 link prediction dataset.

    The WN18 dataset was introduced in `Translating Embeddings for Modeling
    Multi-relational Data <http://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data.pdf>`_.
@@ -611,11 +711,19 @@ class WN18Dataset(KnowledgeGraphDataset):
    >>> # Train, Validation and Test
    >>>
    """

    def __init__(
        self,
        reverse=True,
        raw_dir=None,
        force_reload=False,
        verbose=True,
        transform=None,
    ):
        name = "wn18"
        super(WN18Dataset, self).__init__(
            name, reverse, raw_dir, force_reload, verbose, transform
        )

    def __getitem__(self, idx):
        r"""Gets the graph object
@@ -646,6 +754,7 @@ class WN18Dataset(KnowledgeGraphDataset):
        r"""The number of graphs in the dataset."""
        return super(WN18Dataset, self).__len__()
def load_data(dataset):
    r"""Load knowledge graph dataset for RGCN link prediction tasks
@@ -660,9 +769,9 @@ def load_data(dataset):
    ------
    The dataset object.
    """
    if dataset == "wn18":
        return WN18Dataset()
    elif dataset == "FB15k":
        return FB15kDataset()
    elif dataset == "FB15k-237":
        return FB15k237Dataset()
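
A usage sketch for the knowledge-graph datasets above (a sketch only; dataset names are those accepted by load_data, and the edge-mask fields are the ones populated in process()):

    from dgl.data import FB15k237Dataset

    dataset = FB15k237Dataset()          # equivalent to load_data("FB15k-237")
    g = dataset[0]
    etype = g.edata["etype"]             # relation id per edge
    train_mask = g.edata["train_mask"]   # boolean edge masks set in process()
    val_mask = g.edata["val_mask"]
    test_mask = g.edata["test_mask"]
    print(dataset.num_nodes, dataset.num_rels)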
"""QM7b dataset for graph property prediction (regression).""" """QM7b dataset for graph property prediction (regression)."""
from scipy import io
import numpy as np
import os import os
from .dgl_dataset import DGLDataset import numpy as np
from .utils import download, save_graphs, load_graphs, \ from scipy import io
check_sha1, deprecate_property
from .. import backend as F from .. import backend as F
from ..convert import graph as dgl_graph from ..convert import graph as dgl_graph
from .dgl_dataset import DGLDataset
from .utils import (
check_sha1,
deprecate_property,
download,
load_graphs,
save_graphs,
)
class QM7bDataset(DGLDataset): class QM7bDataset(DGLDataset):
r"""QM7b dataset for graph property prediction (regression) r"""QM7b dataset for graph property prediction (regression)
...@@ -67,57 +74,69 @@ class QM7bDataset(DGLDataset): ...@@ -67,57 +74,69 @@ class QM7bDataset(DGLDataset):
>>> >>>
""" """
_url = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/' \ _url = (
'datasets/qm7b.mat' "http://deepchem.io.s3-website-us-west-1.amazonaws.com/"
_sha1_str = '4102c744bb9d6fd7b40ac67a300e49cd87e28392' "datasets/qm7b.mat"
)
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None): _sha1_str = "4102c744bb9d6fd7b40ac67a300e49cd87e28392"
super(QM7bDataset, self).__init__(name='qm7b',
url=self._url, def __init__(
raw_dir=raw_dir, self, raw_dir=None, force_reload=False, verbose=False, transform=None
force_reload=force_reload, ):
verbose=verbose, super(QM7bDataset, self).__init__(
transform=transform) name="qm7b",
url=self._url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
def process(self): def process(self):
mat_path = self.raw_path + '.mat' mat_path = self.raw_path + ".mat"
self.graphs, self.label = self._load_graph(mat_path) self.graphs, self.label = self._load_graph(mat_path)
def _load_graph(self, filename): def _load_graph(self, filename):
data = io.loadmat(filename) data = io.loadmat(filename)
labels = F.tensor(data['T'], dtype=F.data_type_dict['float32']) labels = F.tensor(data["T"], dtype=F.data_type_dict["float32"])
feats = data['X'] feats = data["X"]
num_graphs = labels.shape[0] num_graphs = labels.shape[0]
graphs = [] graphs = []
for i in range(num_graphs): for i in range(num_graphs):
edge_list = feats[i].nonzero() edge_list = feats[i].nonzero()
g = dgl_graph(edge_list) g = dgl_graph(edge_list)
g.edata['h'] = F.tensor(feats[i][edge_list[0], edge_list[1]].reshape(-1, 1), g.edata["h"] = F.tensor(
dtype=F.data_type_dict['float32']) feats[i][edge_list[0], edge_list[1]].reshape(-1, 1),
dtype=F.data_type_dict["float32"],
)
graphs.append(g) graphs.append(g)
return graphs, labels return graphs, labels
def save(self): def save(self):
"""save the graph list and the labels""" """save the graph list and the labels"""
graph_path = os.path.join(self.save_path, 'dgl_graph.bin') graph_path = os.path.join(self.save_path, "dgl_graph.bin")
save_graphs(str(graph_path), self.graphs, {'labels': self.label}) save_graphs(str(graph_path), self.graphs, {"labels": self.label})
def has_cache(self): def has_cache(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin') graph_path = os.path.join(self.save_path, "dgl_graph.bin")
return os.path.exists(graph_path) return os.path.exists(graph_path)
def load(self): def load(self):
graphs, label_dict = load_graphs(os.path.join(self.save_path, 'dgl_graph.bin')) graphs, label_dict = load_graphs(
os.path.join(self.save_path, "dgl_graph.bin")
)
self.graphs = graphs self.graphs = graphs
self.label = label_dict['labels'] self.label = label_dict["labels"]
def download(self): def download(self):
file_path = os.path.join(self.raw_dir, self.name + '.mat') file_path = os.path.join(self.raw_dir, self.name + ".mat")
download(self.url, path=file_path) download(self.url, path=file_path)
if not check_sha1(file_path, self._sha1_str): if not check_sha1(file_path, self._sha1_str):
raise UserWarning('File {} is downloaded but the content hash does not match.' raise UserWarning(
'The repo may be outdated or download may be incomplete. ' "File {} is downloaded but the content hash does not match."
'Otherwise you can create an issue for it.'.format(self.name)) "The repo may be outdated or download may be incomplete. "
"Otherwise you can create an issue for it.".format(self.name)
)
@property @property
def num_tasks(self): def num_tasks(self):
...@@ -130,7 +149,7 @@ class QM7bDataset(DGLDataset): ...@@ -130,7 +149,7 @@ class QM7bDataset(DGLDataset):
return 14 return 14
def __getitem__(self, idx): def __getitem__(self, idx):
r""" Get graph and label by index r"""Get graph and label by index
Parameters Parameters
---------- ----------
......
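The QM7b class above is easiest to read next to a consumer. A hedged sketch using DGL's GraphDataLoader; the batch size is an arbitrary choice, everything else is the API shown above:

from dgl.data import QM7bDataset
from dgl.dataloading import GraphDataLoader

data = QM7bDataset()
loader = GraphDataLoader(data, batch_size=32, shuffle=True)
for batched_graph, labels in loader:
    # labels carry the 14 regression targets per molecule
    print(batched_graph.batch_size, labels.shape)
    break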
"""QM9 dataset for graph property prediction (regression).""" """QM9 dataset for graph property prediction (regression)."""
import os import os
import numpy as np import numpy as np
import scipy.sparse as sp import scipy.sparse as sp
from .dgl_dataset import DGLDataset from .. import backend as F
from .utils import download, _get_dgl_url
from ..convert import graph as dgl_graph from ..convert import graph as dgl_graph
from ..transforms import to_bidirected from ..transforms import to_bidirected
from .. import backend as F
from .dgl_dataset import DGLDataset
from .utils import _get_dgl_url, download
class QM9Dataset(DGLDataset): class QM9Dataset(DGLDataset):
r"""QM9 dataset for graph property prediction (regression) r"""QM9 dataset for graph property prediction (regression)
...@@ -103,39 +106,44 @@ class QM9Dataset(DGLDataset): ...@@ -103,39 +106,44 @@ class QM9Dataset(DGLDataset):
>>> >>>
""" """
def __init__(self, def __init__(
label_keys, self,
cutoff=5.0, label_keys,
raw_dir=None, cutoff=5.0,
force_reload=False, raw_dir=None,
verbose=False, force_reload=False,
transform=None): verbose=False,
transform=None,
):
self.cutoff = cutoff self.cutoff = cutoff
self.label_keys = label_keys self.label_keys = label_keys
self._url = _get_dgl_url('dataset/qm9_eV.npz') self._url = _get_dgl_url("dataset/qm9_eV.npz")
super(QM9Dataset, self).__init__(name='qm9', super(QM9Dataset, self).__init__(
url=self._url, name="qm9",
raw_dir=raw_dir, url=self._url,
force_reload=force_reload, raw_dir=raw_dir,
verbose=verbose, force_reload=force_reload,
transform=transform) verbose=verbose,
transform=transform,
)
def process(self): def process(self):
npz_path = f'{self.raw_dir}/qm9_eV.npz' npz_path = f"{self.raw_dir}/qm9_eV.npz"
data_dict = np.load(npz_path, allow_pickle=True) data_dict = np.load(npz_path, allow_pickle=True)
# data_dict['N'] contains the number of atoms in each molecule. # data_dict['N'] contains the number of atoms in each molecule.
# Atomic properties (Z and R) of all molecules are concatenated as single tensors, # Atomic properties (Z and R) of all molecules are concatenated as single tensors,
# so you need this value to select the correct atoms for each molecule. # so you need this value to select the correct atoms for each molecule.
self.N = data_dict['N'] self.N = data_dict["N"]
self.R = data_dict['R'] self.R = data_dict["R"]
self.Z = data_dict['Z'] self.Z = data_dict["Z"]
self.label = np.stack([data_dict[key] for key in self.label_keys], axis=1) self.label = np.stack(
[data_dict[key] for key in self.label_keys], axis=1
)
self.N_cumsum = np.concatenate([[0], np.cumsum(self.N)]) self.N_cumsum = np.concatenate([[0], np.cumsum(self.N)])
def download(self): def download(self):
file_path = f'{self.raw_dir}/qm9_eV.npz' file_path = f"{self.raw_dir}/qm9_eV.npz"
if not os.path.exists(file_path): if not os.path.exists(file_path):
download(self._url, path=file_path) download(self._url, path=file_path)
...@@ -160,7 +168,7 @@ class QM9Dataset(DGLDataset): ...@@ -160,7 +168,7 @@ class QM9Dataset(DGLDataset):
return self.label.shape[1] return self.label.shape[1]
def __getitem__(self, idx): def __getitem__(self, idx):
r""" Get graph and label by index r"""Get graph and label by index
Parameters Parameters
---------- ----------
...@@ -178,18 +186,22 @@ class QM9Dataset(DGLDataset): ...@@ -178,18 +186,22 @@ class QM9Dataset(DGLDataset):
Tensor Tensor
Property values of molecular graphs Property values of molecular graphs
""" """
label = F.tensor(self.label[idx], dtype=F.data_type_dict['float32']) label = F.tensor(self.label[idx], dtype=F.data_type_dict["float32"])
n_atoms = self.N[idx] n_atoms = self.N[idx]
R = self.R[self.N_cumsum[idx]:self.N_cumsum[idx + 1]] R = self.R[self.N_cumsum[idx] : self.N_cumsum[idx + 1]]
dist = np.linalg.norm(R[:, None, :] - R[None, :, :], axis=-1) dist = np.linalg.norm(R[:, None, :] - R[None, :, :], axis=-1)
adj = sp.csr_matrix(dist <= self.cutoff) - sp.eye(n_atoms, dtype=np.bool_) adj = sp.csr_matrix(dist <= self.cutoff) - sp.eye(
n_atoms, dtype=np.bool_
)
adj = adj.tocoo() adj = adj.tocoo()
u, v = F.tensor(adj.row), F.tensor(adj.col) u, v = F.tensor(adj.row), F.tensor(adj.col)
g = dgl_graph((u, v)) g = dgl_graph((u, v))
g = to_bidirected(g) g = to_bidirected(g)
g.ndata['R'] = F.tensor(R, dtype=F.data_type_dict['float32']) g.ndata["R"] = F.tensor(R, dtype=F.data_type_dict["float32"])
g.ndata['Z'] = F.tensor(self.Z[self.N_cumsum[idx]:self.N_cumsum[idx + 1]], g.ndata["Z"] = F.tensor(
dtype=F.data_type_dict['int64']) self.Z[self.N_cumsum[idx] : self.N_cumsum[idx + 1]],
dtype=F.data_type_dict["int64"],
)
if self._transform is not None: if self._transform is not None:
g = self._transform(g) g = self._transform(g)
...@@ -205,4 +217,5 @@ class QM9Dataset(DGLDataset): ...@@ -205,4 +217,5 @@ class QM9Dataset(DGLDataset):
""" """
return self.label.shape[0] return self.label.shape[0]
QM9 = QM9Dataset QM9 = QM9Dataset
\ No newline at end of file
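The __getitem__ above builds a radius graph from atom coordinates: pairwise distances, a cutoff test, and self-loop removal. A dense re-implementation of that step on made-up coordinates, kept as a sketch rather than the library code:

import numpy as np

R = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [10.0, 0.0, 0.0]])
cutoff = 5.0  # the constructor default shown above
dist = np.linalg.norm(R[:, None, :] - R[None, :, :], axis=-1)
adj = (dist <= cutoff) & ~np.eye(len(R), dtype=bool)  # drop self loops
u, v = np.nonzero(adj)
print(u, v)  # [0 1] [1 0]: only the atom pair within the cutoff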
""" QM9 dataset for graph property prediction (regression) """ """ QM9 dataset for graph property prediction (regression) """
import os import os
import numpy as np import numpy as np
from .dgl_dataset import DGLDataset
from .utils import download, extract_archive, _get_dgl_url
from ..convert import graph as dgl_graph
from .. import backend as F from .. import backend as F
from ..convert import graph as dgl_graph
from .dgl_dataset import DGLDataset
from .utils import _get_dgl_url, download, extract_archive
class QM9EdgeDataset(DGLDataset): class QM9EdgeDataset(DGLDataset):
...@@ -129,19 +131,40 @@ class QM9EdgeDataset(DGLDataset): ...@@ -129,19 +131,40 @@ class QM9EdgeDataset(DGLDataset):
>>> >>>
""" """
keys = ['mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv', 'U0_atom', keys = [
'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C'] "mu",
"alpha",
"homo",
"lumo",
"gap",
"r2",
"zpve",
"U0",
"U",
"H",
"G",
"Cv",
"U0_atom",
"U_atom",
"H_atom",
"G_atom",
"A",
"B",
"C",
]
map_dict = {} map_dict = {}
for i, key in enumerate(keys): for i, key in enumerate(keys):
map_dict[key] = i map_dict[key] = i
def __init__(self, def __init__(
label_keys=None, self,
raw_dir=None, label_keys=None,
force_reload=False, raw_dir=None,
verbose=True, force_reload=False,
transform=None): verbose=True,
transform=None,
):
if label_keys is None: if label_keys is None:
self.label_keys = None self.label_keys = None
self.num_labels = 19 self.num_labels = 19
...@@ -149,18 +172,19 @@ class QM9EdgeDataset(DGLDataset): ...@@ -149,18 +172,19 @@ class QM9EdgeDataset(DGLDataset):
self.label_keys = [self.map_dict[i] for i in label_keys] self.label_keys = [self.map_dict[i] for i in label_keys]
self.num_labels = len(label_keys) self.num_labels = len(label_keys)
self._url = _get_dgl_url("dataset/qm9_edge.npz")
self._url = _get_dgl_url('dataset/qm9_edge.npz') super(QM9EdgeDataset, self).__init__(
name="qm9Edge",
super(QM9EdgeDataset, self).__init__(name='qm9Edge', raw_dir=raw_dir,
raw_dir=raw_dir, url=self._url,
url=self._url, force_reload=force_reload,
force_reload=force_reload, verbose=verbose,
verbose=verbose, transform=transform,
transform=transform) )
def download(self): def download(self):
file_path = f'{self.raw_dir}/qm9_edge.npz' file_path = f"{self.raw_dir}/qm9_edge.npz"
if not os.path.exists(file_path): if not os.path.exists(file_path):
download(self._url, path=file_path) download(self._url, path=file_path)
...@@ -168,39 +192,41 @@ class QM9EdgeDataset(DGLDataset): ...@@ -168,39 +192,41 @@ class QM9EdgeDataset(DGLDataset):
self.load() self.load()
def has_cache(self): def has_cache(self):
npz_path = f'{self.raw_dir}/qm9_edge.npz' npz_path = f"{self.raw_dir}/qm9_edge.npz"
return os.path.exists(npz_path) return os.path.exists(npz_path)
def save(self): def save(self):
np.savez_compressed(f'{self.raw_dir}/qm9_edge.npz', np.savez_compressed(
n_node=self.n_node, f"{self.raw_dir}/qm9_edge.npz",
n_edge=self.n_edge, n_node=self.n_node,
node_attr=self.node_attr, n_edge=self.n_edge,
node_pos=self.node_pos, node_attr=self.node_attr,
edge_attr=self.edge_attr, node_pos=self.node_pos,
src=self.src, edge_attr=self.edge_attr,
dst=self.dst, src=self.src,
targets=self.targets) dst=self.dst,
targets=self.targets,
)
def load(self): def load(self):
npz_path = f'{self.raw_dir}/qm9_edge.npz' npz_path = f"{self.raw_dir}/qm9_edge.npz"
data_dict = np.load(npz_path, allow_pickle=True) data_dict = np.load(npz_path, allow_pickle=True)
self.n_node = data_dict['n_node'] self.n_node = data_dict["n_node"]
self.n_edge = data_dict['n_edge'] self.n_edge = data_dict["n_edge"]
self.node_attr = data_dict['node_attr'] self.node_attr = data_dict["node_attr"]
self.node_pos = data_dict['node_pos'] self.node_pos = data_dict["node_pos"]
self.edge_attr = data_dict['edge_attr'] self.edge_attr = data_dict["edge_attr"]
self.targets = data_dict['targets'] self.targets = data_dict["targets"]
self.src = data_dict['src'] self.src = data_dict["src"]
self.dst = data_dict['dst'] self.dst = data_dict["dst"]
self.n_cumsum = np.concatenate([[0], np.cumsum(self.n_node)]) self.n_cumsum = np.concatenate([[0], np.cumsum(self.n_node)])
self.ne_cumsum = np.concatenate([[0], np.cumsum(self.n_edge)]) self.ne_cumsum = np.concatenate([[0], np.cumsum(self.n_edge)])
def __getitem__(self, idx): def __getitem__(self, idx):
r""" Get graph and label by index r"""Get graph and label by index
Parameters Parameters
---------- ----------
...@@ -220,18 +246,26 @@ class QM9EdgeDataset(DGLDataset): ...@@ -220,18 +246,26 @@ class QM9EdgeDataset(DGLDataset):
Property values of molecular graphs Property values of molecular graphs
""" """
pos = self.node_pos[self.n_cumsum[idx]:self.n_cumsum[idx+1]] pos = self.node_pos[self.n_cumsum[idx] : self.n_cumsum[idx + 1]]
src = self.src[self.ne_cumsum[idx]:self.ne_cumsum[idx+1]] src = self.src[self.ne_cumsum[idx] : self.ne_cumsum[idx + 1]]
dst = self.dst[self.ne_cumsum[idx]:self.ne_cumsum[idx+1]] dst = self.dst[self.ne_cumsum[idx] : self.ne_cumsum[idx + 1]]
g = dgl_graph((src, dst)) g = dgl_graph((src, dst))
g.ndata['pos'] = F.tensor(pos, dtype=F.data_type_dict['float32']) g.ndata["pos"] = F.tensor(pos, dtype=F.data_type_dict["float32"])
g.ndata['attr'] = F.tensor(self.node_attr[self.n_cumsum[idx]:self.n_cumsum[idx+1]], dtype=F.data_type_dict['float32']) g.ndata["attr"] = F.tensor(
g.edata['edge_attr'] = F.tensor(self.edge_attr[self.ne_cumsum[idx]:self.ne_cumsum[idx+1]], dtype=F.data_type_dict['float32']) self.node_attr[self.n_cumsum[idx] : self.n_cumsum[idx + 1]],
dtype=F.data_type_dict["float32"],
)
label = F.tensor(self.targets[idx][self.label_keys], dtype=F.data_type_dict['float32']) g.edata["edge_attr"] = F.tensor(
self.edge_attr[self.ne_cumsum[idx] : self.ne_cumsum[idx + 1]],
dtype=F.data_type_dict["float32"],
)
label = F.tensor(
self.targets[idx][self.label_keys],
dtype=F.data_type_dict["float32"],
)
if self._transform is not None: if self._transform is not None:
g = self._transform(g) g = self._transform(g)
...@@ -239,7 +273,7 @@ class QM9EdgeDataset(DGLDataset): ...@@ -239,7 +273,7 @@ class QM9EdgeDataset(DGLDataset):
return g, label return g, label
def __len__(self): def __len__(self):
r""" Number of graphs in the dataset. r"""Number of graphs in the dataset.
Returns Returns
------- -------
...@@ -258,4 +292,4 @@ class QM9EdgeDataset(DGLDataset): ...@@ -258,4 +292,4 @@ class QM9EdgeDataset(DGLDataset):
return self.num_labels return self.num_labels
QM9Edge = QM9EdgeDataset QM9Edge = QM9EdgeDataset
\ No newline at end of file
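Both QM9Dataset.process and the load()/__getitem__ pair above rely on the same cumulative-sum trick to slice per-molecule rows out of the flat, concatenated arrays. A self-contained illustration with made-up sizes:

import numpy as np

n_node = np.array([3, 5, 2])  # made-up molecule sizes
n_cumsum = np.concatenate([[0], np.cumsum(n_node)])  # [0, 3, 8, 10]
node_attr = np.arange(10)  # stand-in for the concatenated node features
idx = 1
print(node_attr[n_cumsum[idx]:n_cumsum[idx + 1]])  # graph 1 owns rows 3..7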
""" Reddit dataset for community detection """ """ Reddit dataset for community detection """
from __future__ import absolute_import from __future__ import absolute_import
import scipy.sparse as sp
import numpy as np
import os import os
from .dgl_dataset import DGLBuiltinDataset import numpy as np
from .utils import _get_dgl_url, generate_mask_tensor, load_graphs, save_graphs, deprecate_property
import scipy.sparse as sp
from .. import backend as F from .. import backend as F
from ..convert import from_scipy from ..convert import from_scipy
from ..transforms import reorder_graph from ..transforms import reorder_graph
from .dgl_dataset import DGLBuiltinDataset
from .utils import (
_get_dgl_url,
deprecate_property,
generate_mask_tensor,
load_graphs,
save_graphs,
)
class RedditDataset(DGLBuiltinDataset): class RedditDataset(DGLBuiltinDataset):
r""" Reddit dataset for community detection (node classification) r"""Reddit dataset for community detection (node classification)
This is a graph dataset from Reddit posts made in the month of September, 2014. This is a graph dataset from Reddit posts made in the month of September, 2014.
The node label in this case is the community, or “subreddit”, that a post belongs to. The node label in this case is the community, or “subreddit”, that a post belongs to.
...@@ -73,24 +82,36 @@ class RedditDataset(DGLBuiltinDataset): ...@@ -73,24 +82,36 @@ class RedditDataset(DGLBuiltinDataset):
>>> >>>
>>> # Train, Validation and Test >>> # Train, Validation and Test
""" """
def __init__(self, self_loop=False, raw_dir=None, force_reload=False,
verbose=False, transform=None): def __init__(
self,
self_loop=False,
raw_dir=None,
force_reload=False,
verbose=False,
transform=None,
):
self_loop_str = "" self_loop_str = ""
if self_loop: if self_loop:
self_loop_str = "_self_loop" self_loop_str = "_self_loop"
_url = _get_dgl_url("dataset/reddit{}.zip".format(self_loop_str)) _url = _get_dgl_url("dataset/reddit{}.zip".format(self_loop_str))
self._self_loop_str = self_loop_str self._self_loop_str = self_loop_str
super(RedditDataset, self).__init__(name='reddit{}'.format(self_loop_str), super(RedditDataset, self).__init__(
url=_url, name="reddit{}".format(self_loop_str),
raw_dir=raw_dir, url=_url,
force_reload=force_reload, raw_dir=raw_dir,
verbose=verbose, force_reload=force_reload,
transform=transform) verbose=verbose,
transform=transform,
)
def process(self): def process(self):
# graph # graph
coo_adj = sp.load_npz(os.path.join( coo_adj = sp.load_npz(
self.raw_path, "reddit{}_graph.npz".format(self._self_loop_str))) os.path.join(
self.raw_path, "reddit{}_graph.npz".format(self._self_loop_str)
)
)
self._graph = from_scipy(coo_adj) self._graph = from_scipy(coo_adj)
# features and labels # features and labels
reddit_data = np.load(os.path.join(self.raw_path, "reddit_data.npz")) reddit_data = np.load(os.path.join(self.raw_path, "reddit_data.npz"))
...@@ -98,48 +119,74 @@ class RedditDataset(DGLBuiltinDataset): ...@@ -98,48 +119,74 @@ class RedditDataset(DGLBuiltinDataset):
labels = reddit_data["label"] labels = reddit_data["label"]
# train/val/test indices # train/val/test indices
node_types = reddit_data["node_types"] node_types = reddit_data["node_types"]
train_mask = (node_types == 1) train_mask = node_types == 1
val_mask = (node_types == 2) val_mask = node_types == 2
test_mask = (node_types == 3) test_mask = node_types == 3
self._graph.ndata['train_mask'] = generate_mask_tensor(train_mask) self._graph.ndata["train_mask"] = generate_mask_tensor(train_mask)
self._graph.ndata['val_mask'] = generate_mask_tensor(val_mask) self._graph.ndata["val_mask"] = generate_mask_tensor(val_mask)
self._graph.ndata['test_mask'] = generate_mask_tensor(test_mask) self._graph.ndata["test_mask"] = generate_mask_tensor(test_mask)
self._graph.ndata['feat'] = F.tensor(features, dtype=F.data_type_dict['float32']) self._graph.ndata["feat"] = F.tensor(
self._graph.ndata['label'] = F.tensor(labels, dtype=F.data_type_dict['int64']) features, dtype=F.data_type_dict["float32"]
)
self._graph.ndata["label"] = F.tensor(
labels, dtype=F.data_type_dict["int64"]
)
self._graph = reorder_graph( self._graph = reorder_graph(
self._graph, node_permute_algo='rcmk', edge_permute_algo='dst', store_ids=False) self._graph,
node_permute_algo="rcmk",
edge_permute_algo="dst",
store_ids=False,
)
self._print_info() self._print_info()
def has_cache(self): def has_cache(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin') graph_path = os.path.join(self.save_path, "dgl_graph.bin")
if os.path.exists(graph_path): if os.path.exists(graph_path):
return True return True
return False return False
def save(self): def save(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin') graph_path = os.path.join(self.save_path, "dgl_graph.bin")
save_graphs(graph_path, self._graph) save_graphs(graph_path, self._graph)
def load(self): def load(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin') graph_path = os.path.join(self.save_path, "dgl_graph.bin")
graphs, _ = load_graphs(graph_path) graphs, _ = load_graphs(graph_path)
self._graph = graphs[0] self._graph = graphs[0]
self._graph.ndata['train_mask'] = generate_mask_tensor(self._graph.ndata['train_mask'].numpy()) self._graph.ndata["train_mask"] = generate_mask_tensor(
self._graph.ndata['val_mask'] = generate_mask_tensor(self._graph.ndata['val_mask'].numpy()) self._graph.ndata["train_mask"].numpy()
self._graph.ndata['test_mask'] = generate_mask_tensor(self._graph.ndata['test_mask'].numpy()) )
self._graph.ndata["val_mask"] = generate_mask_tensor(
self._graph.ndata["val_mask"].numpy()
)
self._graph.ndata["test_mask"] = generate_mask_tensor(
self._graph.ndata["test_mask"].numpy()
)
self._print_info() self._print_info()
def _print_info(self): def _print_info(self):
if self.verbose: if self.verbose:
print('Finished data loading.') print("Finished data loading.")
print(' NumNodes: {}'.format(self._graph.number_of_nodes())) print(" NumNodes: {}".format(self._graph.number_of_nodes()))
print(' NumEdges: {}'.format(self._graph.number_of_edges())) print(" NumEdges: {}".format(self._graph.number_of_edges()))
print(' NumFeats: {}'.format(self._graph.ndata['feat'].shape[1])) print(" NumFeats: {}".format(self._graph.ndata["feat"].shape[1]))
print(' NumClasses: {}'.format(self.num_classes)) print(" NumClasses: {}".format(self.num_classes))
print(' NumTrainingSamples: {}'.format(F.nonzero_1d(self._graph.ndata['train_mask']).shape[0])) print(
print(' NumValidationSamples: {}'.format(F.nonzero_1d(self._graph.ndata['val_mask']).shape[0])) " NumTrainingSamples: {}".format(
print(' NumTestSamples: {}'.format(F.nonzero_1d(self._graph.ndata['test_mask']).shape[0])) F.nonzero_1d(self._graph.ndata["train_mask"]).shape[0]
)
)
print(
" NumValidationSamples: {}".format(
F.nonzero_1d(self._graph.ndata["val_mask"]).shape[0]
)
)
print(
" NumTestSamples: {}".format(
F.nonzero_1d(self._graph.ndata["test_mask"]).shape[0]
)
)
@property @property
def num_classes(self): def num_classes(self):
...@@ -147,7 +194,7 @@ class RedditDataset(DGLBuiltinDataset): ...@@ -147,7 +194,7 @@ class RedditDataset(DGLBuiltinDataset):
return 41 return 41
def __getitem__(self, idx): def __getitem__(self, idx):
r""" Get graph by index r"""Get graph by index
Parameters Parameters
---------- ----------
......
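The split construction in process() above is plain integer comparison on the node_types array. A dense illustration, with node_types as a made-up stand-in for reddit_data["node_types"]:

import numpy as np

node_types = np.array([1, 2, 3, 1, 1, 3])
train_mask = node_types == 1  # one boolean entry per node
val_mask = node_types == 2
test_mask = node_types == 3
print(train_mask.sum(), val_mask.sum(), test_mask.sum())  # 3 1 2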
...@@ -4,19 +4,28 @@ Including: ...@@ -4,19 +4,28 @@ Including:
""" """
from __future__ import absolute_import from __future__ import absolute_import
import os
from collections import OrderedDict from collections import OrderedDict
import networkx as nx import networkx as nx
import numpy as np import numpy as np
import os
from .dgl_dataset import DGLBuiltinDataset
from .. import backend as F from .. import backend as F
from .utils import _get_dgl_url, save_graphs, save_info, load_graphs, \
load_info, deprecate_property
from ..convert import from_networkx from ..convert import from_networkx
__all__ = ['SST', 'SSTDataset'] from .dgl_dataset import DGLBuiltinDataset
from .utils import (
_get_dgl_url,
deprecate_property,
load_graphs,
load_info,
save_graphs,
save_info,
)
__all__ = ["SST", "SSTDataset"]
class SSTDataset(DGLBuiltinDataset): class SSTDataset(DGLBuiltinDataset):
...@@ -104,45 +113,58 @@ class SSTDataset(DGLBuiltinDataset): ...@@ -104,45 +113,58 @@ class SSTDataset(DGLBuiltinDataset):
PAD_WORD = -1 # special pad word id PAD_WORD = -1 # special pad word id
UNK_WORD = -1 # out-of-vocabulary word id UNK_WORD = -1 # out-of-vocabulary word id
def __init__(self, def __init__(
mode='train', self,
glove_embed_file=None, mode="train",
vocab_file=None, glove_embed_file=None,
raw_dir=None, vocab_file=None,
force_reload=False, raw_dir=None,
verbose=False, force_reload=False,
transform=None): verbose=False,
assert mode in ['train', 'dev', 'test', 'tiny'] transform=None,
_url = _get_dgl_url('dataset/sst.zip') ):
self._glove_embed_file = glove_embed_file if mode == 'train' else None assert mode in ["train", "dev", "test", "tiny"]
_url = _get_dgl_url("dataset/sst.zip")
self._glove_embed_file = glove_embed_file if mode == "train" else None
self.mode = mode self.mode = mode
self._vocab_file = vocab_file self._vocab_file = vocab_file
super(SSTDataset, self).__init__(name='sst', super(SSTDataset, self).__init__(
url=_url, name="sst",
raw_dir=raw_dir, url=_url,
force_reload=force_reload, raw_dir=raw_dir,
verbose=verbose, force_reload=force_reload,
transform=transform) verbose=verbose,
transform=transform,
)
def process(self): def process(self):
from nltk.corpus.reader import BracketParseCorpusReader from nltk.corpus.reader import BracketParseCorpusReader
# load vocab file # load vocab file
self._vocab = OrderedDict() self._vocab = OrderedDict()
vocab_file = self._vocab_file if self._vocab_file is not None else os.path.join(self.raw_path, 'vocab.txt') vocab_file = (
with open(vocab_file, encoding='utf-8') as vf: self._vocab_file
if self._vocab_file is not None
else os.path.join(self.raw_path, "vocab.txt")
)
with open(vocab_file, encoding="utf-8") as vf:
for line in vf.readlines(): for line in vf.readlines():
line = line.strip() line = line.strip()
self._vocab[line] = len(self._vocab) self._vocab[line] = len(self._vocab)
# filter glove # filter glove
if self._glove_embed_file is not None and os.path.exists(self._glove_embed_file): if self._glove_embed_file is not None and os.path.exists(
self._glove_embed_file
):
glove_emb = {} glove_emb = {}
with open(self._glove_embed_file, 'r', encoding='utf-8') as pf: with open(self._glove_embed_file, "r", encoding="utf-8") as pf:
for line in pf.readlines(): for line in pf.readlines():
sp = line.split(' ') sp = line.split(" ")
if sp[0].lower() in self._vocab: if sp[0].lower() in self._vocab:
glove_emb[sp[0].lower()] = np.asarray([float(x) for x in sp[1:]]) glove_emb[sp[0].lower()] = np.asarray(
files = ['{}.txt'.format(self.mode)] [float(x) for x in sp[1:]]
)
files = ["{}.txt".format(self.mode)]
corpus = BracketParseCorpusReader(self.raw_path, files) corpus = BracketParseCorpusReader(self.raw_path, files)
sents = corpus.parsed_sents(files[0]) sents = corpus.parsed_sents(files[0])
...@@ -150,15 +172,27 @@ class SSTDataset(DGLBuiltinDataset): ...@@ -150,15 +172,27 @@ class SSTDataset(DGLBuiltinDataset):
pretrained_emb = [] pretrained_emb = []
fail_cnt = 0 fail_cnt = 0
for line in self._vocab.keys(): for line in self._vocab.keys():
if self._glove_embed_file is not None and os.path.exists(self._glove_embed_file): if self._glove_embed_file is not None and os.path.exists(
self._glove_embed_file
):
if not line.lower() in glove_emb: if not line.lower() in glove_emb:
fail_cnt += 1 fail_cnt += 1
pretrained_emb.append(glove_emb.get(line.lower(), np.random.uniform(-0.05, 0.05, 300))) pretrained_emb.append(
glove_emb.get(
line.lower(), np.random.uniform(-0.05, 0.05, 300)
)
)
self._pretrained_emb = None self._pretrained_emb = None
if self._glove_embed_file is not None and os.path.exists(self._glove_embed_file): if self._glove_embed_file is not None and os.path.exists(
self._glove_embed_file
):
self._pretrained_emb = F.tensor(np.stack(pretrained_emb, 0)) self._pretrained_emb = F.tensor(np.stack(pretrained_emb, 0))
print('Miss word in GloVe {0:.4f}'.format(1.0 * fail_cnt / len(self._pretrained_emb))) print(
"Miss word in GloVe {0:.4f}".format(
1.0 * fail_cnt / len(self._pretrained_emb)
)
)
# build trees # build trees
self._trees = [] self._trees = []
for sent in sents: for sent in sents:
...@@ -175,44 +209,46 @@ class SSTDataset(DGLBuiltinDataset): ...@@ -175,44 +209,46 @@ class SSTDataset(DGLBuiltinDataset):
word = self.vocab.get(child[0].lower(), self.UNK_WORD) word = self.vocab.get(child[0].lower(), self.UNK_WORD)
g.add_node(cid, x=word, y=int(child.label()), mask=1) g.add_node(cid, x=word, y=int(child.label()), mask=1)
else: else:
g.add_node(cid, x=SSTDataset.PAD_WORD, y=int(child.label()), mask=0) g.add_node(
cid, x=SSTDataset.PAD_WORD, y=int(child.label()), mask=0
)
_rec_build(cid, child) _rec_build(cid, child)
g.add_edge(cid, nid) g.add_edge(cid, nid)
# add root # add root
g.add_node(0, x=SSTDataset.PAD_WORD, y=int(root.label()), mask=0) g.add_node(0, x=SSTDataset.PAD_WORD, y=int(root.label()), mask=0)
_rec_build(0, root) _rec_build(0, root)
ret = from_networkx(g, node_attrs=['x', 'y', 'mask']) ret = from_networkx(g, node_attrs=["x", "y", "mask"])
return ret return ret
def has_cache(self): def has_cache(self):
graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin') graph_path = os.path.join(self.save_path, self.mode + "_dgl_graph.bin")
vocab_path = os.path.join(self.save_path, 'vocab.pkl') vocab_path = os.path.join(self.save_path, "vocab.pkl")
return os.path.exists(graph_path) and os.path.exists(vocab_path) return os.path.exists(graph_path) and os.path.exists(vocab_path)
def save(self): def save(self):
graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin') graph_path = os.path.join(self.save_path, self.mode + "_dgl_graph.bin")
save_graphs(graph_path, self._trees) save_graphs(graph_path, self._trees)
vocab_path = os.path.join(self.save_path, 'vocab.pkl') vocab_path = os.path.join(self.save_path, "vocab.pkl")
save_info(vocab_path, {'vocab': self.vocab}) save_info(vocab_path, {"vocab": self.vocab})
if self.pretrained_emb: if self.pretrained_emb:
emb_path = os.path.join(self.save_path, 'emb.pkl') emb_path = os.path.join(self.save_path, "emb.pkl")
save_info(emb_path, {'embed': self.pretrained_emb}) save_info(emb_path, {"embed": self.pretrained_emb})
def load(self): def load(self):
graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin') graph_path = os.path.join(self.save_path, self.mode + "_dgl_graph.bin")
vocab_path = os.path.join(self.save_path, 'vocab.pkl') vocab_path = os.path.join(self.save_path, "vocab.pkl")
emb_path = os.path.join(self.save_path, 'emb.pkl') emb_path = os.path.join(self.save_path, "emb.pkl")
self._trees = load_graphs(graph_path)[0] self._trees = load_graphs(graph_path)[0]
self._vocab = load_info(vocab_path)['vocab'] self._vocab = load_info(vocab_path)["vocab"]
self._pretrained_emb = None self._pretrained_emb = None
if os.path.exists(emb_path): if os.path.exists(emb_path):
self._pretrained_emb = load_info(emb_path)['embed'] self._pretrained_emb = load_info(emb_path)["embed"]
@property @property
def vocab(self): def vocab(self):
r""" Vocabulary r"""Vocabulary
Returns Returns
------- -------
...@@ -226,7 +262,7 @@ class SSTDataset(DGLBuiltinDataset): ...@@ -226,7 +262,7 @@ class SSTDataset(DGLBuiltinDataset):
return self._pretrained_emb return self._pretrained_emb
def __getitem__(self, idx): def __getitem__(self, idx):
r""" Get graph by index r"""Get graph by index
Parameters Parameters
---------- ----------
......
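A hedged usage sketch for SSTDataset; "tiny" is the smallest mode accepted by the assertion in the constructor above, the GloVe and vocab paths stay at their optional defaults, and the optional nltk dependency used by process() is assumed installed:

from dgl.data import SSTDataset

trainset = SSTDataset(mode="tiny")
tree = trainset[0]
# per-node word ids, sentiment labels, and leaf masks, as built by _rec_build
print(tree.num_nodes(), tree.ndata["x"], tree.ndata["y"], tree.ndata["mask"])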
from __future__ import absolute_import from __future__ import absolute_import
import numpy as np
import os import os
from .dgl_dataset import DGLBuiltinDataset import numpy as np
from .utils import loadtxt, save_graphs, load_graphs, save_info, load_info
from .. import backend as F from .. import backend as F
from ..convert import graph as dgl_graph from ..convert import graph as dgl_graph
from .dgl_dataset import DGLBuiltinDataset
from .utils import load_graphs, load_info, loadtxt, save_graphs, save_info
class LegacyTUDataset(DGLBuiltinDataset): class LegacyTUDataset(DGLBuiltinDataset):
r"""LegacyTUDataset contains lots of graph kernel datasets for graph classification. r"""LegacyTUDataset contains lots of graph kernel datasets for graph classification.
...@@ -77,38 +81,60 @@ class LegacyTUDataset(DGLBuiltinDataset): ...@@ -77,38 +81,60 @@ class LegacyTUDataset(DGLBuiltinDataset):
_url = r"https://www.chrsmrrs.com/graphkerneldatasets/{}.zip" _url = r"https://www.chrsmrrs.com/graphkerneldatasets/{}.zip"
def __init__(self, name, use_pandas=False, def __init__(
hidden_size=10, max_allow_node=None, self,
raw_dir=None, force_reload=False, verbose=False, transform=None): name,
use_pandas=False,
hidden_size=10,
max_allow_node=None,
raw_dir=None,
force_reload=False,
verbose=False,
transform=None,
):
url = self._url.format(name) url = self._url.format(name)
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.max_allow_node = max_allow_node self.max_allow_node = max_allow_node
self.use_pandas = use_pandas self.use_pandas = use_pandas
super(LegacyTUDataset, self).__init__(name=name, url=url, raw_dir=raw_dir, super(LegacyTUDataset, self).__init__(
hash_key=(name, use_pandas, hidden_size, max_allow_node), name=name,
force_reload=force_reload, verbose=verbose, transform=transform) url=url,
raw_dir=raw_dir,
hash_key=(name, use_pandas, hidden_size, max_allow_node),
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
def process(self): def process(self):
self.data_mode = None self.data_mode = None
if self.use_pandas: if self.use_pandas:
import pandas as pd import pandas as pd
DS_edge_list = self._idx_from_zero( DS_edge_list = self._idx_from_zero(
pd.read_csv(self._file_path("A"), delimiter=",", dtype=int, header=None).values) pd.read_csv(
self._file_path("A"), delimiter=",", dtype=int, header=None
).values
)
else: else:
DS_edge_list = self._idx_from_zero( DS_edge_list = self._idx_from_zero(
np.genfromtxt(self._file_path("A"), delimiter=",", dtype=int)) np.genfromtxt(self._file_path("A"), delimiter=",", dtype=int)
)
DS_indicator = self._idx_from_zero( DS_indicator = self._idx_from_zero(
np.genfromtxt(self._file_path("graph_indicator"), dtype=int)) np.genfromtxt(self._file_path("graph_indicator"), dtype=int)
)
if os.path.exists(self._file_path("graph_labels")): if os.path.exists(self._file_path("graph_labels")):
DS_graph_labels = self._idx_from_zero( DS_graph_labels = self._idx_from_zero(
np.genfromtxt(self._file_path("graph_labels"), dtype=int)) np.genfromtxt(self._file_path("graph_labels"), dtype=int)
)
self.num_labels = max(DS_graph_labels) + 1 self.num_labels = max(DS_graph_labels) + 1
self.graph_labels = DS_graph_labels self.graph_labels = DS_graph_labels
elif os.path.exists(self._file_path("graph_attributes")): elif os.path.exists(self._file_path("graph_attributes")):
DS_graph_labels = np.genfromtxt(self._file_path("graph_attributes"), dtype=float) DS_graph_labels = np.genfromtxt(
self._file_path("graph_attributes"), dtype=float
)
self.num_labels = None self.num_labels = None
self.graph_labels = DS_graph_labels self.graph_labels = DS_graph_labels
else: else:
...@@ -130,33 +156,42 @@ class LegacyTUDataset(DGLBuiltinDataset): ...@@ -130,33 +156,42 @@ class LegacyTUDataset(DGLBuiltinDataset):
try: try:
DS_node_labels = self._idx_from_zero( DS_node_labels = self._idx_from_zero(
np.loadtxt(self._file_path("node_labels"), dtype=int)) np.loadtxt(self._file_path("node_labels"), dtype=int)
g.ndata['node_label'] = F.tensor(DS_node_labels) )
g.ndata["node_label"] = F.tensor(DS_node_labels)
one_hot_node_labels = self._to_onehot(DS_node_labels) one_hot_node_labels = self._to_onehot(DS_node_labels)
for idxs, g in zip(node_idx_list, self.graph_lists): for idxs, g in zip(node_idx_list, self.graph_lists):
g.ndata['feat'] = F.tensor(one_hot_node_labels[idxs, :], F.float32) g.ndata["feat"] = F.tensor(
one_hot_node_labels[idxs, :], F.float32
)
self.data_mode = "node_label" self.data_mode = "node_label"
except IOError: except IOError:
print("No Node Label Data") print("No Node Label Data")
try: try:
DS_node_attr = np.loadtxt( DS_node_attr = np.loadtxt(
self._file_path("node_attributes"), delimiter=",") self._file_path("node_attributes"), delimiter=","
)
if DS_node_attr.ndim == 1: if DS_node_attr.ndim == 1:
DS_node_attr = np.expand_dims(DS_node_attr, -1) DS_node_attr = np.expand_dims(DS_node_attr, -1)
for idxs, g in zip(node_idx_list, self.graph_lists): for idxs, g in zip(node_idx_list, self.graph_lists):
g.ndata['feat'] = F.tensor(DS_node_attr[idxs, :], F.float32) g.ndata["feat"] = F.tensor(DS_node_attr[idxs, :], F.float32)
self.data_mode = "node_attr" self.data_mode = "node_attr"
except IOError: except IOError:
print("No Node Attribute Data") print("No Node Attribute Data")
if 'feat' not in g.ndata.keys(): if "feat" not in g.ndata.keys():
for idxs, g in zip(node_idx_list, self.graph_lists): for idxs, g in zip(node_idx_list, self.graph_lists):
g.ndata['feat'] = F.ones((g.number_of_nodes(), self.hidden_size), g.ndata["feat"] = F.ones(
F.float32, F.cpu()) (g.number_of_nodes(), self.hidden_size), F.float32, F.cpu()
)
self.data_mode = "constant" self.data_mode = "constant"
if self.verbose: if self.verbose:
print("Use Constant one as Feature with hidden size {}".format(self.hidden_size)) print(
"Use Constant one as Feature with hidden size {}".format(
self.hidden_size
)
)
# remove graphs that are too large by a user-given standard # remove graphs that are too large by a user-given standard
# optional pre-processing step in conformity with Rex Ying's original # optional pre-processing step in conformity with Rex Ying's original
...@@ -165,39 +200,56 @@ class LegacyTUDataset(DGLBuiltinDataset): ...@@ -165,39 +200,56 @@ class LegacyTUDataset(DGLBuiltinDataset):
preserve_idx = [] preserve_idx = []
if self.verbose: if self.verbose:
print("original dataset length : ", len(self.graph_lists)) print("original dataset length : ", len(self.graph_lists))
for (i, g) in enumerate(self.graph_lists): for i, g in enumerate(self.graph_lists):
if g.number_of_nodes() <= self.max_allow_node: if g.number_of_nodes() <= self.max_allow_node:
preserve_idx.append(i) preserve_idx.append(i)
self.graph_lists = [self.graph_lists[i] for i in preserve_idx] self.graph_lists = [self.graph_lists[i] for i in preserve_idx]
if self.verbose: if self.verbose:
print("after pruning graphs that are too big : ", len(self.graph_lists)) print(
"after pruning graphs that are too big : ",
len(self.graph_lists),
)
self.graph_labels = [self.graph_labels[i] for i in preserve_idx] self.graph_labels = [self.graph_labels[i] for i in preserve_idx]
self.max_num_node = self.max_allow_node self.max_num_node = self.max_allow_node
self.graph_labels = F.tensor(self.graph_labels) self.graph_labels = F.tensor(self.graph_labels)
def save(self): def save(self):
graph_path = os.path.join(self.save_path, 'legacy_tu_{}_{}.bin'.format(self.name, self.hash)) graph_path = os.path.join(
info_path = os.path.join(self.save_path, 'legacy_tu_{}_{}.pkl'.format(self.name, self.hash)) self.save_path, "legacy_tu_{}_{}.bin".format(self.name, self.hash)
label_dict = {'labels': self.graph_labels} )
info_dict = {'max_num_node': self.max_num_node, info_path = os.path.join(
'num_labels': self.num_labels} self.save_path, "legacy_tu_{}_{}.pkl".format(self.name, self.hash)
)
label_dict = {"labels": self.graph_labels}
info_dict = {
"max_num_node": self.max_num_node,
"num_labels": self.num_labels,
}
save_graphs(str(graph_path), self.graph_lists, label_dict) save_graphs(str(graph_path), self.graph_lists, label_dict)
save_info(str(info_path), info_dict) save_info(str(info_path), info_dict)
def load(self): def load(self):
graph_path = os.path.join(self.save_path, 'legacy_tu_{}_{}.bin'.format(self.name, self.hash)) graph_path = os.path.join(
info_path = os.path.join(self.save_path, 'legacy_tu_{}_{}.pkl'.format(self.name, self.hash)) self.save_path, "legacy_tu_{}_{}.bin".format(self.name, self.hash)
)
info_path = os.path.join(
self.save_path, "legacy_tu_{}_{}.pkl".format(self.name, self.hash)
)
graphs, label_dict = load_graphs(str(graph_path)) graphs, label_dict = load_graphs(str(graph_path))
info_dict = load_info(str(info_path)) info_dict = load_info(str(info_path))
self.graph_lists = graphs self.graph_lists = graphs
self.graph_labels = label_dict['labels'] self.graph_labels = label_dict["labels"]
self.max_num_node = info_dict['max_num_node'] self.max_num_node = info_dict["max_num_node"]
self.num_labels = info_dict['num_labels'] self.num_labels = info_dict["num_labels"]
def has_cache(self): def has_cache(self):
graph_path = os.path.join(self.save_path, 'legacy_tu_{}_{}.bin'.format(self.name, self.hash)) graph_path = os.path.join(
info_path = os.path.join(self.save_path, 'legacy_tu_{}_{}.pkl'.format(self.name, self.hash)) self.save_path, "legacy_tu_{}_{}.bin".format(self.name, self.hash)
)
info_path = os.path.join(
self.save_path, "legacy_tu_{}_{}.pkl".format(self.name, self.hash)
)
if os.path.exists(graph_path) and os.path.exists(info_path): if os.path.exists(graph_path) and os.path.exists(info_path):
return True return True
return False return False
...@@ -226,8 +278,9 @@ class LegacyTUDataset(DGLBuiltinDataset): ...@@ -226,8 +278,9 @@ class LegacyTUDataset(DGLBuiltinDataset):
return len(self.graph_lists) return len(self.graph_lists)
def _file_path(self, category): def _file_path(self, category):
return os.path.join(self.raw_path, self.name, return os.path.join(
"{}_{}.txt".format(self.name, category)) self.raw_path, self.name, "{}_{}.txt".format(self.name, category)
)
@staticmethod @staticmethod
def _idx_from_zero(idx_tensor): def _idx_from_zero(idx_tensor):
...@@ -242,14 +295,17 @@ class LegacyTUDataset(DGLBuiltinDataset): ...@@ -242,14 +295,17 @@ class LegacyTUDataset(DGLBuiltinDataset):
return one_hot_tensor return one_hot_tensor
def statistics(self): def statistics(self):
return self.graph_lists[0].ndata['feat'].shape[1],\ return (
self.num_labels,\ self.graph_lists[0].ndata["feat"].shape[1],
self.max_num_node self.num_labels,
self.max_num_node,
)
@property @property
def num_classes(self): def num_classes(self):
return int(self.num_labels) return int(self.num_labels)
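A usage sketch for LegacyTUDataset before the TUDataset class below; "ENZYMES" is one of the graph-kernel datasets served by the URL template above, and statistics() is the helper shown a few hunks up:

from dgl.data import LegacyTUDataset

data = LegacyTUDataset("ENZYMES")
feat_dim, num_classes, max_num_node = data.statistics()
g, label = data[0]  # __getitem__ yields (graph, label) pairs
print(feat_dim, num_classes, max_num_node, g.num_nodes())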
class TUDataset(DGLBuiltinDataset): class TUDataset(DGLBuiltinDataset):
r""" r"""
TUDataset contains lots of graph kernel datasets for graph classification. TUDataset contains lots of graph kernel datasets for graph classification.
...@@ -322,25 +378,46 @@ class TUDataset(DGLBuiltinDataset): ...@@ -322,25 +378,46 @@ class TUDataset(DGLBuiltinDataset):
_url = r"https://www.chrsmrrs.com/graphkerneldatasets/{}.zip" _url = r"https://www.chrsmrrs.com/graphkerneldatasets/{}.zip"
def __init__(self, name, raw_dir=None, force_reload=False, verbose=False, transform=None): def __init__(
self,
name,
raw_dir=None,
force_reload=False,
verbose=False,
transform=None,
):
url = self._url.format(name) url = self._url.format(name)
super(TUDataset, self).__init__(name=name, url=url, super(TUDataset, self).__init__(
raw_dir=raw_dir, force_reload=force_reload, name=name,
verbose=verbose, transform=transform) url=url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
def process(self): def process(self):
DS_edge_list = self._idx_from_zero( DS_edge_list = self._idx_from_zero(
loadtxt(self._file_path("A"), delimiter=",").astype(int)) loadtxt(self._file_path("A"), delimiter=",").astype(int)
)
DS_indicator = self._idx_from_zero( DS_indicator = self._idx_from_zero(
loadtxt(self._file_path("graph_indicator"), delimiter=",").astype(int)) loadtxt(self._file_path("graph_indicator"), delimiter=",").astype(
int
)
)
if os.path.exists(self._file_path("graph_labels")): if os.path.exists(self._file_path("graph_labels")):
DS_graph_labels = self._idx_reset( DS_graph_labels = self._idx_reset(
loadtxt(self._file_path("graph_labels"), delimiter=",").astype(int)) loadtxt(self._file_path("graph_labels"), delimiter=",").astype(
int
)
)
self.num_labels = int(max(DS_graph_labels) + 1) self.num_labels = int(max(DS_graph_labels) + 1)
self.graph_labels = F.tensor(DS_graph_labels) self.graph_labels = F.tensor(DS_graph_labels)
elif os.path.exists(self._file_path("graph_attributes")): elif os.path.exists(self._file_path("graph_attributes")):
DS_graph_labels = loadtxt(self._file_path("graph_attributes"), delimiter=",").astype(float) DS_graph_labels = loadtxt(
self._file_path("graph_attributes"), delimiter=","
).astype(float)
self.num_labels = None self.num_labels = None
self.graph_labels = F.tensor(DS_graph_labels) self.graph_labels = F.tensor(DS_graph_labels)
else: else:
...@@ -358,19 +435,17 @@ class TUDataset(DGLBuiltinDataset): ...@@ -358,19 +435,17 @@ class TUDataset(DGLBuiltinDataset):
if len(node_idx[0]) > self.max_num_node: if len(node_idx[0]) > self.max_num_node:
self.max_num_node = len(node_idx[0]) self.max_num_node = len(node_idx[0])
self.attr_dict = { self.attr_dict = {
'node_labels': ('ndata', 'node_labels'), "node_labels": ("ndata", "node_labels"),
'node_attributes': ('ndata', 'node_attr'), "node_attributes": ("ndata", "node_attr"),
'edge_labels': ('edata', 'edge_labels'), "edge_labels": ("edata", "edge_labels"),
'edge_attributes': ('edata', 'node_labels'), "edge_attributes": ("edata", "node_labels"),
} }
for filename, field_name in self.attr_dict.items(): for filename, field_name in self.attr_dict.items():
try: try:
data = loadtxt(self._file_path(filename), data = loadtxt(self._file_path(filename), delimiter=",")
delimiter=',') if "label" in filename:
if 'label' in filename:
data = F.tensor(self._idx_from_zero(data)) data = F.tensor(self._idx_from_zero(data))
else: else:
data = F.tensor(data) data = F.tensor(data)
...@@ -381,28 +456,30 @@ class TUDataset(DGLBuiltinDataset): ...@@ -381,28 +456,30 @@ class TUDataset(DGLBuiltinDataset):
self.graph_lists = [g.subgraph(node_idx) for node_idx in node_idx_list] self.graph_lists = [g.subgraph(node_idx) for node_idx in node_idx_list]
def save(self): def save(self):
graph_path = os.path.join(self.save_path, 'tu_{}.bin'.format(self.name)) graph_path = os.path.join(self.save_path, "tu_{}.bin".format(self.name))
info_path = os.path.join(self.save_path, 'tu_{}.pkl'.format(self.name)) info_path = os.path.join(self.save_path, "tu_{}.pkl".format(self.name))
label_dict = {'labels': self.graph_labels} label_dict = {"labels": self.graph_labels}
info_dict = {'max_num_node': self.max_num_node, info_dict = {
'num_labels': self.num_labels} "max_num_node": self.max_num_node,
"num_labels": self.num_labels,
}
save_graphs(str(graph_path), self.graph_lists, label_dict) save_graphs(str(graph_path), self.graph_lists, label_dict)
save_info(str(info_path), info_dict) save_info(str(info_path), info_dict)
def load(self): def load(self):
graph_path = os.path.join(self.save_path, 'tu_{}.bin'.format(self.name)) graph_path = os.path.join(self.save_path, "tu_{}.bin".format(self.name))
info_path = os.path.join(self.save_path, 'tu_{}.pkl'.format(self.name)) info_path = os.path.join(self.save_path, "tu_{}.pkl".format(self.name))
graphs, label_dict = load_graphs(str(graph_path)) graphs, label_dict = load_graphs(str(graph_path))
info_dict = load_info(str(info_path)) info_dict = load_info(str(info_path))
self.graph_lists = graphs self.graph_lists = graphs
self.graph_labels = label_dict['labels'] self.graph_labels = label_dict["labels"]
self.max_num_node = info_dict['max_num_node'] self.max_num_node = info_dict["max_num_node"]
self.num_labels = info_dict['num_labels'] self.num_labels = info_dict["num_labels"]
def has_cache(self): def has_cache(self):
graph_path = os.path.join(self.save_path, 'tu_{}.bin'.format(self.name)) graph_path = os.path.join(self.save_path, "tu_{}.bin".format(self.name))
info_path = os.path.join(self.save_path, 'tu_{}.pkl'.format(self.name)) info_path = os.path.join(self.save_path, "tu_{}.pkl".format(self.name))
if os.path.exists(graph_path) and os.path.exists(info_path): if os.path.exists(graph_path) and os.path.exists(info_path):
return True return True
return False return False
...@@ -431,8 +508,9 @@ class TUDataset(DGLBuiltinDataset): ...@@ -431,8 +508,9 @@ class TUDataset(DGLBuiltinDataset):
return len(self.graph_lists) return len(self.graph_lists)
def _file_path(self, category): def _file_path(self, category):
return os.path.join(self.raw_path, self.name, return os.path.join(
"{}_{}.txt".format(self.name, category)) self.raw_path, self.name, "{}_{}.txt".format(self.name, category)
)
@staticmethod @staticmethod
def _idx_from_zero(idx_tensor): def _idx_from_zero(idx_tensor):
...@@ -447,9 +525,11 @@ class TUDataset(DGLBuiltinDataset): ...@@ -447,9 +525,11 @@ class TUDataset(DGLBuiltinDataset):
return new_idx_tensor return new_idx_tensor
def statistics(self): def statistics(self):
return self.graph_lists[0].ndata['feat'].shape[1], \ return (
self.num_labels, \ self.graph_lists[0].ndata["feat"].shape[1],
self.max_num_node self.num_labels,
self.max_num_node,
)
@property @property
def num_classes(self): def num_classes(self):
......
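The two static helpers referenced above normalize raw ids: _idx_from_zero shifts 1-based ids down to 0-based, and _idx_reset densifies arbitrary label values into consecutive integers. An equivalent sketch (the searchsorted step is a re-implementation under those assumptions, not the library's exact code):

import numpy as np

ids = np.array([1, 2, 5])
print(ids - np.min(ids))  # _idx_from_zero: [0 1 4]

labels = np.array([-1, 1, 1, -1])
uniq = np.unique(labels)  # sorted unique label values
print(np.searchsorted(uniq, labels))  # _idx_reset: [0 1 1 0]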