"src/vscode:/vscode.git/clone" did not exist on "80fb4dbe2675adfb2bd469260e20facdaae0631d"
Unverified commit ab812179, authored by Hongzhi (Steve) Chen, committed by GitHub

[Misc] Auto reformat data/. (#5318)



* data

* lintrunner

---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent f0759a96
...
@@ -5,7 +5,7 @@ import numpy as np
 from .. import backend as F
 from ..base import DGLError
 from .dgl_dataset import DGLDataset
-from .utils import Subset, load_graphs, save_graphs
+from .utils import load_graphs, save_graphs, Subset

 class CSVDataset(DGLDataset):
...
@@ -8,7 +8,7 @@ import pydantic as dt
 import yaml

 from .. import backend as F
-from ..base import DGLError, dgl_warning
+from ..base import dgl_warning, DGLError
 from ..convert import heterograph as dgl_heterograph
...
@@ -99,7 +99,6 @@ class GINDataset(DGLBuiltinDataset):
         verbose=False,
         transform=None,
     ):
-
         self._name = name  # MUTAG
         gin_url = "https://raw.githubusercontent.com/weihua916/powerful-gnns/master/dataset.zip"
         self.ds_name = "nig"
...
"""GNN Benchmark datasets for node classification.""" """GNN Benchmark datasets for node classification."""
import scipy.sparse as sp
import numpy as np
import os import os
from .dgl_dataset import DGLBuiltinDataset import numpy as np
from .utils import save_graphs, load_graphs, _get_dgl_url, deprecate_property, deprecate_class import scipy.sparse as sp
from .. import backend as F, transforms
from ..convert import graph as dgl_graph from ..convert import graph as dgl_graph
from .. import backend as F
from .. import transforms
__all__ = ["AmazonCoBuyComputerDataset", "AmazonCoBuyPhotoDataset", "CoauthorPhysicsDataset", "CoauthorCSDataset", from .dgl_dataset import DGLBuiltinDataset
"CoraFullDataset", "AmazonCoBuy", "Coauthor", "CoraFull"] from .utils import (
_get_dgl_url,
deprecate_class,
deprecate_property,
load_graphs,
save_graphs,
)
__all__ = [
"AmazonCoBuyComputerDataset",
"AmazonCoBuyPhotoDataset",
"CoauthorPhysicsDataset",
"CoauthorCSDataset",
"CoraFullDataset",
"AmazonCoBuy",
"Coauthor",
"CoraFull",
]
def eliminate_self_loops(A): def eliminate_self_loops(A):
...@@ -27,36 +42,50 @@ class GNNBenchmarkDataset(DGLBuiltinDataset): ...@@ -27,36 +42,50 @@ class GNNBenchmarkDataset(DGLBuiltinDataset):
Reference: https://github.com/shchur/gnn-benchmark#datasets Reference: https://github.com/shchur/gnn-benchmark#datasets
""" """
def __init__(self, name, raw_dir=None, force_reload=False, verbose=False, transform=None):
_url = _get_dgl_url('dataset/' + name + '.zip') def __init__(
super(GNNBenchmarkDataset, self).__init__(name=name, self,
name,
raw_dir=None,
force_reload=False,
verbose=False,
transform=None,
):
_url = _get_dgl_url("dataset/" + name + ".zip")
super(GNNBenchmarkDataset, self).__init__(
name=name,
url=_url, url=_url,
raw_dir=raw_dir, raw_dir=raw_dir,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose, verbose=verbose,
transform=transform) transform=transform,
)
def process(self): def process(self):
npz_path = os.path.join(self.raw_path, self.name + '.npz') npz_path = os.path.join(self.raw_path, self.name + ".npz")
g = self._load_npz(npz_path) g = self._load_npz(npz_path)
g = transforms.reorder_graph( g = transforms.reorder_graph(
g, node_permute_algo='rcmk', edge_permute_algo='dst', store_ids=False) g,
node_permute_algo="rcmk",
edge_permute_algo="dst",
store_ids=False,
)
self._graph = g self._graph = g
self._data = [g] self._data = [g]
self._print_info() self._print_info()
def has_cache(self): def has_cache(self):
graph_path = os.path.join(self.save_path, 'dgl_graph_v1.bin') graph_path = os.path.join(self.save_path, "dgl_graph_v1.bin")
if os.path.exists(graph_path): if os.path.exists(graph_path):
return True return True
return False return False
def save(self): def save(self):
graph_path = os.path.join(self.save_path, 'dgl_graph_v1.bin') graph_path = os.path.join(self.save_path, "dgl_graph_v1.bin")
save_graphs(graph_path, self._graph) save_graphs(graph_path, self._graph)
def load(self): def load(self):
graph_path = os.path.join(self.save_path, 'dgl_graph_v1.bin') graph_path = os.path.join(self.save_path, "dgl_graph_v1.bin")
graphs, _ = load_graphs(graph_path) graphs, _ = load_graphs(graph_path)
self._graph = graphs[0] self._graph = graphs[0]
self._data = [graphs[0]] self._data = [graphs[0]]
...@@ -64,41 +93,59 @@ class GNNBenchmarkDataset(DGLBuiltinDataset): ...@@ -64,41 +93,59 @@ class GNNBenchmarkDataset(DGLBuiltinDataset):
def _print_info(self): def _print_info(self):
if self.verbose: if self.verbose:
print(' NumNodes: {}'.format(self._graph.number_of_nodes())) print(" NumNodes: {}".format(self._graph.number_of_nodes()))
print(' NumEdges: {}'.format(self._graph.number_of_edges())) print(" NumEdges: {}".format(self._graph.number_of_edges()))
print(' NumFeats: {}'.format(self._graph.ndata['feat'].shape[-1])) print(" NumFeats: {}".format(self._graph.ndata["feat"].shape[-1]))
print(' NumbClasses: {}'.format(self.num_classes)) print(" NumbClasses: {}".format(self.num_classes))
def _load_npz(self, file_name): def _load_npz(self, file_name):
with np.load(file_name, allow_pickle=True) as loader: with np.load(file_name, allow_pickle=True) as loader:
loader = dict(loader) loader = dict(loader)
num_nodes = loader['adj_shape'][0] num_nodes = loader["adj_shape"][0]
adj_matrix = sp.csr_matrix((loader['adj_data'], loader['adj_indices'], loader['adj_indptr']), adj_matrix = sp.csr_matrix(
shape=loader['adj_shape']).tocoo() (
loader["adj_data"],
if 'attr_data' in loader: loader["adj_indices"],
loader["adj_indptr"],
),
shape=loader["adj_shape"],
).tocoo()
if "attr_data" in loader:
# Attributes are stored as a sparse CSR matrix # Attributes are stored as a sparse CSR matrix
attr_matrix = sp.csr_matrix((loader['attr_data'], loader['attr_indices'], loader['attr_indptr']), attr_matrix = sp.csr_matrix(
shape=loader['attr_shape']).todense() (
elif 'attr_matrix' in loader: loader["attr_data"],
loader["attr_indices"],
loader["attr_indptr"],
),
shape=loader["attr_shape"],
).todense()
elif "attr_matrix" in loader:
# Attributes are stored as a (dense) np.ndarray # Attributes are stored as a (dense) np.ndarray
attr_matrix = loader['attr_matrix'] attr_matrix = loader["attr_matrix"]
else: else:
attr_matrix = None attr_matrix = None
if 'labels_data' in loader: if "labels_data" in loader:
# Labels are stored as a CSR matrix # Labels are stored as a CSR matrix
labels = sp.csr_matrix((loader['labels_data'], loader['labels_indices'], loader['labels_indptr']), labels = sp.csr_matrix(
shape=loader['labels_shape']).todense() (
elif 'labels' in loader: loader["labels_data"],
loader["labels_indices"],
loader["labels_indptr"],
),
shape=loader["labels_shape"],
).todense()
elif "labels" in loader:
# Labels are stored as a numpy array # Labels are stored as a numpy array
labels = loader['labels'] labels = loader["labels"]
else: else:
labels = None labels = None
g = dgl_graph((adj_matrix.row, adj_matrix.col)) g = dgl_graph((adj_matrix.row, adj_matrix.col))
g = transforms.to_bidirected(g) g = transforms.to_bidirected(g)
g.ndata['feat'] = F.tensor(attr_matrix, F.data_type_dict['float32']) g.ndata["feat"] = F.tensor(attr_matrix, F.data_type_dict["float32"])
g.ndata['label'] = F.tensor(labels, F.data_type_dict['int64']) g.ndata["label"] = F.tensor(labels, F.data_type_dict["int64"])
return g return g
@property @property
...@@ -107,7 +154,7 @@ class GNNBenchmarkDataset(DGLBuiltinDataset): ...@@ -107,7 +154,7 @@ class GNNBenchmarkDataset(DGLBuiltinDataset):
raise NotImplementedError raise NotImplementedError
def __getitem__(self, idx): def __getitem__(self, idx):
r""" Get graph by index r"""Get graph by index
Parameters Parameters
---------- ----------
...@@ -176,12 +223,17 @@ class CoraFullDataset(GNNBenchmarkDataset): ...@@ -176,12 +223,17 @@ class CoraFullDataset(GNNBenchmarkDataset):
>>> feat = g.ndata['feat'] # get node feature >>> feat = g.ndata['feat'] # get node feature
>>> label = g.ndata['label'] # get node labels >>> label = g.ndata['label'] # get node labels
""" """
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
super(CoraFullDataset, self).__init__(name="cora_full", def __init__(
self, raw_dir=None, force_reload=False, verbose=False, transform=None
):
super(CoraFullDataset, self).__init__(
name="cora_full",
raw_dir=raw_dir, raw_dir=raw_dir,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose, verbose=verbose,
transform=transform) transform=transform,
)
@property @property
def num_classes(self): def num_classes(self):
...@@ -195,7 +247,7 @@ class CoraFullDataset(GNNBenchmarkDataset): ...@@ -195,7 +247,7 @@ class CoraFullDataset(GNNBenchmarkDataset):
class CoauthorCSDataset(GNNBenchmarkDataset): class CoauthorCSDataset(GNNBenchmarkDataset):
r""" 'Computer Science (CS)' part of the Coauthor dataset for node classification task. r"""'Computer Science (CS)' part of the Coauthor dataset for node classification task.
Coauthor CS and Coauthor Physics are co-authorship graphs based on the Microsoft Academic Graph Coauthor CS and Coauthor Physics are co-authorship graphs based on the Microsoft Academic Graph
from the KDD Cup 2016 challenge. Here, nodes are authors, that are connected by an edge if they from the KDD Cup 2016 challenge. Here, nodes are authors, that are connected by an edge if they
...@@ -239,12 +291,17 @@ class CoauthorCSDataset(GNNBenchmarkDataset): ...@@ -239,12 +291,17 @@ class CoauthorCSDataset(GNNBenchmarkDataset):
>>> feat = g.ndata['feat'] # get node feature >>> feat = g.ndata['feat'] # get node feature
>>> label = g.ndata['label'] # get node labels >>> label = g.ndata['label'] # get node labels
""" """
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
super(CoauthorCSDataset, self).__init__(name='coauthor_cs', def __init__(
self, raw_dir=None, force_reload=False, verbose=False, transform=None
):
super(CoauthorCSDataset, self).__init__(
name="coauthor_cs",
raw_dir=raw_dir, raw_dir=raw_dir,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose, verbose=verbose,
transform=transform) transform=transform,
)
@property @property
def num_classes(self): def num_classes(self):
...@@ -258,7 +315,7 @@ class CoauthorCSDataset(GNNBenchmarkDataset): ...@@ -258,7 +315,7 @@ class CoauthorCSDataset(GNNBenchmarkDataset):
class CoauthorPhysicsDataset(GNNBenchmarkDataset): class CoauthorPhysicsDataset(GNNBenchmarkDataset):
r""" 'Physics' part of the Coauthor dataset for node classification task. r"""'Physics' part of the Coauthor dataset for node classification task.
Coauthor CS and Coauthor Physics are co-authorship graphs based on the Microsoft Academic Graph Coauthor CS and Coauthor Physics are co-authorship graphs based on the Microsoft Academic Graph
from the KDD Cup 2016 challenge. Here, nodes are authors, that are connected by an edge if they from the KDD Cup 2016 challenge. Here, nodes are authors, that are connected by an edge if they
...@@ -302,12 +359,17 @@ class CoauthorPhysicsDataset(GNNBenchmarkDataset): ...@@ -302,12 +359,17 @@ class CoauthorPhysicsDataset(GNNBenchmarkDataset):
>>> feat = g.ndata['feat'] # get node feature >>> feat = g.ndata['feat'] # get node feature
>>> label = g.ndata['label'] # get node labels >>> label = g.ndata['label'] # get node labels
""" """
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
super(CoauthorPhysicsDataset, self).__init__(name='coauthor_physics', def __init__(
self, raw_dir=None, force_reload=False, verbose=False, transform=None
):
super(CoauthorPhysicsDataset, self).__init__(
name="coauthor_physics",
raw_dir=raw_dir, raw_dir=raw_dir,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose, verbose=verbose,
transform=transform) transform=transform,
)
@property @property
def num_classes(self): def num_classes(self):
...@@ -321,7 +383,7 @@ class CoauthorPhysicsDataset(GNNBenchmarkDataset): ...@@ -321,7 +383,7 @@ class CoauthorPhysicsDataset(GNNBenchmarkDataset):
class AmazonCoBuyComputerDataset(GNNBenchmarkDataset): class AmazonCoBuyComputerDataset(GNNBenchmarkDataset):
r""" 'Computer' part of the AmazonCoBuy dataset for node classification task. r"""'Computer' part of the AmazonCoBuy dataset for node classification task.
Amazon Computers and Amazon Photo are segments of the Amazon co-purchase graph [McAuley et al., 2015], Amazon Computers and Amazon Photo are segments of the Amazon co-purchase graph [McAuley et al., 2015],
where nodes represent goods, edges indicate that two goods are frequently bought together, node where nodes represent goods, edges indicate that two goods are frequently bought together, node
...@@ -364,12 +426,17 @@ class AmazonCoBuyComputerDataset(GNNBenchmarkDataset): ...@@ -364,12 +426,17 @@ class AmazonCoBuyComputerDataset(GNNBenchmarkDataset):
>>> feat = g.ndata['feat'] # get node feature >>> feat = g.ndata['feat'] # get node feature
>>> label = g.ndata['label'] # get node labels >>> label = g.ndata['label'] # get node labels
""" """
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
super(AmazonCoBuyComputerDataset, self).__init__(name='amazon_co_buy_computer', def __init__(
self, raw_dir=None, force_reload=False, verbose=False, transform=None
):
super(AmazonCoBuyComputerDataset, self).__init__(
name="amazon_co_buy_computer",
raw_dir=raw_dir, raw_dir=raw_dir,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose, verbose=verbose,
transform=transform) transform=transform,
)
@property @property
def num_classes(self): def num_classes(self):
...@@ -426,12 +493,17 @@ class AmazonCoBuyPhotoDataset(GNNBenchmarkDataset): ...@@ -426,12 +493,17 @@ class AmazonCoBuyPhotoDataset(GNNBenchmarkDataset):
>>> feat = g.ndata['feat'] # get node feature >>> feat = g.ndata['feat'] # get node feature
>>> label = g.ndata['label'] # get node labels >>> label = g.ndata['label'] # get node labels
""" """
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
super(AmazonCoBuyPhotoDataset, self).__init__(name='amazon_co_buy_photo', def __init__(
self, raw_dir=None, force_reload=False, verbose=False, transform=None
):
super(AmazonCoBuyPhotoDataset, self).__init__(
name="amazon_co_buy_photo",
raw_dir=raw_dir, raw_dir=raw_dir,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose, verbose=verbose,
transform=transform) transform=transform,
)
@property @property
def num_classes(self): def num_classes(self):
...@@ -446,27 +518,27 @@ class AmazonCoBuyPhotoDataset(GNNBenchmarkDataset): ...@@ -446,27 +518,27 @@ class AmazonCoBuyPhotoDataset(GNNBenchmarkDataset):
class CoraFull(CoraFullDataset): class CoraFull(CoraFullDataset):
def __init__(self, **kwargs): def __init__(self, **kwargs):
deprecate_class('CoraFull', 'CoraFullDataset') deprecate_class("CoraFull", "CoraFullDataset")
super(CoraFull, self).__init__(**kwargs) super(CoraFull, self).__init__(**kwargs)
def AmazonCoBuy(name): def AmazonCoBuy(name):
if name == 'computers': if name == "computers":
deprecate_class('AmazonCoBuy', 'AmazonCoBuyComputerDataset') deprecate_class("AmazonCoBuy", "AmazonCoBuyComputerDataset")
return AmazonCoBuyComputerDataset() return AmazonCoBuyComputerDataset()
elif name == 'photo': elif name == "photo":
deprecate_class('AmazonCoBuy', 'AmazonCoBuyPhotoDataset') deprecate_class("AmazonCoBuy", "AmazonCoBuyPhotoDataset")
return AmazonCoBuyPhotoDataset() return AmazonCoBuyPhotoDataset()
else: else:
raise ValueError('Dataset name should be "computers" or "photo".') raise ValueError('Dataset name should be "computers" or "photo".')
def Coauthor(name): def Coauthor(name):
if name == 'cs': if name == "cs":
deprecate_class('Coauthor', 'CoauthorCSDataset') deprecate_class("Coauthor", "CoauthorCSDataset")
return CoauthorCSDataset() return CoauthorCSDataset()
elif name == 'physics': elif name == "physics":
deprecate_class('Coauthor', 'CoauthorPhysicsDataset') deprecate_class("Coauthor", "CoauthorPhysicsDataset")
return CoauthorPhysicsDataset() return CoauthorPhysicsDataset()
else: else:
raise ValueError('Dataset name should be "cs" or "physics".') raise ValueError('Dataset name should be "cs" or "physics".')
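The changes in this file are purely mechanical (import sorting, line wrapping, quote style), so the public API of these datasets is unchanged. A minimal usage sketch in the doctest style of the docstrings above, assuming a working DGL install (the data is downloaded and cached on first access):

>>> from dgl.data import CoraFullDataset
>>> dataset = CoraFullDataset()  # fetches 'cora_full' on first use
>>> g = dataset[0]
>>> feat = g.ndata['feat']    # node features
>>> label = g.ndata['label']  # node labels
>>> num_classes = dataset.num_classes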
@@ -6,7 +6,7 @@ import os
 from .. import backend as F
 from .._ffi.function import _init_api
 from .._ffi.object import ObjectBase, register_object
-from ..base import DGLError, dgl_warning
+from ..base import dgl_warning, DGLError
 from ..heterograph import DGLGraph
 from .heterograph_serialize import save_heterographs
...
@@ -58,14 +58,12 @@ class GraphData(ObjectBase):
                 node_tensors[key] = F.zerocopy_to_dgl_ndarray(value)
         else:
             node_tensors = None
-
         if len(g.edata) != 0:
             edge_tensors = dict()
             for key, value in g.edata.items():
                 edge_tensors[key] = F.zerocopy_to_dgl_ndarray(value)
         else:
             edge_tensors = None
-
         return _CAPI_MakeGraphData(ghandle, node_tensors, edge_tensors)

     def get_graph(self):
...
@@ -139,11 +137,8 @@ def save_graphs(filename, g_list, labels=None, formats=None):
     f_path = os.path.dirname(filename)
     if f_path and not os.path.exists(f_path):
         os.makedirs(f_path)
     g_sample = g_list[0] if isinstance(g_list, list) else g_list
-    if (
-        type(g_sample) == DGLGraph
-    ):  # Doesn't support DGLGraph's derived class
+    if type(g_sample) == DGLGraph:  # Doesn't support DGLGraph's derived class
         save_heterographs(filename, g_list, labels, formats)
     else:
         raise DGLError(
...
@@ -221,7 +216,6 @@ def load_graph_v1(filename, idx_list=None):
     label_dict = {}
     for k, v in metadata.labels.items():
         label_dict[k] = F.zerocopy_from_dgl_ndarray(v)
-
     return [gdata.get_graph() for gdata in metadata.graph_data], label_dict
...
@@ -37,6 +37,7 @@ def save_heterographs(filename, g_list, labels, formats):
         filename, gdata_list, tensor_dict_to_ndarray_dict(labels), formats
     )

+
 @register_object("heterograph_serialize.HeteroGraphData")
 class HeteroGraphData(ObjectBase):
     """Object to hold the data to be stored for DGLGraph"""
...
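Note that save_graphs keeps its type(g_sample) == DGLGraph check, so it deliberately rejects subclasses of DGLGraph. A hedged round-trip sketch, assuming the PyTorch backend (the /tmp path is arbitrary):

>>> import torch
>>> import dgl
>>> from dgl.data.utils import save_graphs, load_graphs
>>> g = dgl.graph(([0, 1], [1, 2]))  # a plain DGLGraph, not a subclass
>>> save_graphs("/tmp/demo_graphs.bin", [g], {"glabel": torch.tensor([0])})
>>> glist, label_dict = load_graphs("/tmp/demo_graphs.bin")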
"""KarateClub Dataset """KarateClub Dataset
""" """
import numpy as np
import networkx as nx import networkx as nx
import numpy as np
from .. import backend as F from .. import backend as F
from ..convert import from_networkx
from .dgl_dataset import DGLDataset from .dgl_dataset import DGLDataset
from .utils import deprecate_property from .utils import deprecate_property
from ..convert import from_networkx
__all__ = ['KarateClubDataset', 'KarateClub'] __all__ = ["KarateClubDataset", "KarateClub"]
class KarateClubDataset(DGLDataset): class KarateClubDataset(DGLDataset):
r""" Karate Club dataset for Node Classification r"""Karate Club dataset for Node Classification
Zachary's karate club is a social network of a university Zachary's karate club is a social network of a university
karate club, described in the paper "An Information Flow karate club, described in the paper "An Information Flow
...@@ -46,16 +46,20 @@ class KarateClubDataset(DGLDataset): ...@@ -46,16 +46,20 @@ class KarateClubDataset(DGLDataset):
>>> g = dataset[0] >>> g = dataset[0]
>>> labels = g.ndata['label'] >>> labels = g.ndata['label']
""" """
def __init__(self, transform=None): def __init__(self, transform=None):
super(KarateClubDataset, self).__init__(name='karate_club', transform=transform) super(KarateClubDataset, self).__init__(
name="karate_club", transform=transform
)
def process(self): def process(self):
kc_graph = nx.karate_club_graph() kc_graph = nx.karate_club_graph()
label = np.asarray( label = np.asarray(
[kc_graph.nodes[i]['club'] != 'Mr. Hi' for i in kc_graph.nodes]).astype(np.int64) [kc_graph.nodes[i]["club"] != "Mr. Hi" for i in kc_graph.nodes]
).astype(np.int64)
label = F.tensor(label) label = F.tensor(label)
g = from_networkx(kc_graph) g = from_networkx(kc_graph)
g.ndata['label'] = label g.ndata["label"] = label
self._graph = g self._graph = g
self._data = [g] self._data = [g]
...@@ -65,7 +69,7 @@ class KarateClubDataset(DGLDataset): ...@@ -65,7 +69,7 @@ class KarateClubDataset(DGLDataset):
return 2 return 2
def __getitem__(self, idx): def __getitem__(self, idx):
r""" Get graph object r"""Get graph object
Parameters Parameters
---------- ----------
......
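Since process() above labels each member by whether they sided with 'Mr. Hi', the dataset needs no download step; a short sketch following the docstring example:

>>> from dgl.data import KarateClubDataset
>>> dataset = KarateClubDataset()
>>> g = dataset[0]
>>> labels = g.ndata['label']  # 1 where the member did not join Mr. Hi's club
>>> dataset.num_classes
2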
 from __future__ import absolute_import

-import numpy as np
+import os, sys
 import pickle as pkl
 import networkx as nx
+import numpy as np
 import scipy.sparse as sp
-import os, sys

-from .dgl_dataset import DGLBuiltinDataset
-from .utils import download, extract_archive, get_download_dir
-from .utils import save_graphs, load_graphs, save_info, load_info, makedirs, _get_dgl_url
-from .utils import generate_mask_tensor
-from .utils import deprecate_property, deprecate_function
-from ..utils import retry_method_with_fix
 from .. import backend as F
 from ..convert import graph as dgl_graph
+from ..utils import retry_method_with_fix
+from .dgl_dataset import DGLBuiltinDataset
+from .utils import (
+    _get_dgl_url,
+    deprecate_function,
+    deprecate_property,
+    download,
+    extract_archive,
+    generate_mask_tensor,
+    get_download_dir,
+    load_graphs,
+    load_info,
+    makedirs,
+    save_graphs,
+    save_info,
+)

 class KnowledgeGraphDataset(DGLBuiltinDataset):
     """KnowledgeGraph link prediction dataset
...
@@ -41,22 +55,31 @@ class KnowledgeGraphDataset(DGLBuiltinDataset):
         a transformed version. The :class:`~dgl.DGLGraph` object will be
         transformed before every access.
     """
-    def __init__(self, name, reverse=True, raw_dir=None, force_reload=False,
-                 verbose=True, transform=None):
+
+    def __init__(
+        self,
+        name,
+        reverse=True,
+        raw_dir=None,
+        force_reload=False,
+        verbose=True,
+        transform=None,
+    ):
         self._name = name
         self.reverse = reverse
-        url = _get_dgl_url('dataset/') + '{}.tgz'.format(name)
-        super(KnowledgeGraphDataset, self).__init__(name,
+        url = _get_dgl_url("dataset/") + "{}.tgz".format(name)
+        super(KnowledgeGraphDataset, self).__init__(
+            name,
             url=url,
             raw_dir=raw_dir,
             force_reload=force_reload,
             verbose=verbose,
-            transform=transform)
+            transform=transform,
+        )

     def download(self):
-        r""" Automatically download data and extract it.
-        """
-        tgz_path = os.path.join(self.raw_dir, self.name + '.tgz')
+        r"""Automatically download data and extract it."""
+        tgz_path = os.path.join(self.raw_dir, self.name + ".tgz")
         download(self.url, path=tgz_path)
         extract_archive(tgz_path, self.raw_path)
...
@@ -66,16 +89,22 @@ class KnowledgeGraphDataset(DGLBuiltinDataset):
         This function will parse these triplets and build the DGLGraph.
         """
         root_path = self.raw_path
-        entity_path = os.path.join(root_path, 'entities.dict')
-        relation_path = os.path.join(root_path, 'relations.dict')
-        train_path = os.path.join(root_path, 'train.txt')
-        valid_path = os.path.join(root_path, 'valid.txt')
-        test_path = os.path.join(root_path, 'test.txt')
+        entity_path = os.path.join(root_path, "entities.dict")
+        relation_path = os.path.join(root_path, "relations.dict")
+        train_path = os.path.join(root_path, "train.txt")
+        valid_path = os.path.join(root_path, "valid.txt")
+        test_path = os.path.join(root_path, "test.txt")
         entity_dict = _read_dictionary(entity_path)
         relation_dict = _read_dictionary(relation_path)
-        train = np.asarray(_read_triplets_as_list(train_path, entity_dict, relation_dict))
-        valid = np.asarray(_read_triplets_as_list(valid_path, entity_dict, relation_dict))
-        test = np.asarray(_read_triplets_as_list(test_path, entity_dict, relation_dict))
+        train = np.asarray(
+            _read_triplets_as_list(train_path, entity_dict, relation_dict)
+        )
+        valid = np.asarray(
+            _read_triplets_as_list(valid_path, entity_dict, relation_dict)
+        )
+        test = np.asarray(
+            _read_triplets_as_list(test_path, entity_dict, relation_dict)
+        )
         num_nodes = len(entity_dict)
         num_rels = len(relation_dict)
         if self.verbose:
...
@@ -93,25 +122,33 @@ class KnowledgeGraphDataset(DGLBuiltinDataset):
         self._num_nodes = num_nodes
         self._num_rels = num_rels
         # build graph
-        g, data = build_knowledge_graph(num_nodes, num_rels, train, valid, test, reverse=self.reverse)
-        etype, ntype, train_edge_mask, valid_edge_mask, test_edge_mask, train_mask, val_mask, test_mask = data
-        g.edata['train_edge_mask'] = train_edge_mask
-        g.edata['valid_edge_mask'] = valid_edge_mask
-        g.edata['test_edge_mask'] = test_edge_mask
-        g.edata['train_mask'] = train_mask
-        g.edata['val_mask'] = val_mask
-        g.edata['test_mask'] = test_mask
-        g.edata['etype'] = etype
-        g.ndata['ntype'] = ntype
+        g, data = build_knowledge_graph(
+            num_nodes, num_rels, train, valid, test, reverse=self.reverse
+        )
+        (
+            etype,
+            ntype,
+            train_edge_mask,
+            valid_edge_mask,
+            test_edge_mask,
+            train_mask,
+            val_mask,
+            test_mask,
+        ) = data
+        g.edata["train_edge_mask"] = train_edge_mask
+        g.edata["valid_edge_mask"] = valid_edge_mask
+        g.edata["test_edge_mask"] = test_edge_mask
+        g.edata["train_mask"] = train_mask
+        g.edata["val_mask"] = val_mask
+        g.edata["test_mask"] = test_mask
+        g.edata["etype"] = etype
+        g.ndata["ntype"] = ntype
         self._g = g

     def has_cache(self):
-        graph_path = os.path.join(self.save_path,
-                                  self.save_name + '.bin')
-        info_path = os.path.join(self.save_path,
-                                 self.save_name + '.pkl')
-        if os.path.exists(graph_path) and \
-                os.path.exists(info_path):
+        graph_path = os.path.join(self.save_path, self.save_name + ".bin")
+        info_path = os.path.join(self.save_path, self.save_name + ".pkl")
+        if os.path.exists(graph_path) and os.path.exists(info_path):
             return True
         return False
...
@@ -128,49 +165,65 @@ class KnowledgeGraphDataset(DGLBuiltinDataset):
     def save(self):
         """save the graph list and the labels"""
-        graph_path = os.path.join(self.save_path,
-                                  self.save_name + '.bin')
-        info_path = os.path.join(self.save_path,
-                                 self.save_name + '.pkl')
+        graph_path = os.path.join(self.save_path, self.save_name + ".bin")
+        info_path = os.path.join(self.save_path, self.save_name + ".pkl")
         save_graphs(str(graph_path), self._g)
-        save_info(str(info_path), {'num_nodes': self.num_nodes,
-                                   'num_rels': self.num_rels})
+        save_info(
+            str(info_path),
+            {"num_nodes": self.num_nodes, "num_rels": self.num_rels},
+        )

     def load(self):
-        graph_path = os.path.join(self.save_path,
-                                  self.save_name + '.bin')
-        info_path = os.path.join(self.save_path,
-                                 self.save_name + '.pkl')
+        graph_path = os.path.join(self.save_path, self.save_name + ".bin")
+        info_path = os.path.join(self.save_path, self.save_name + ".pkl")
         graphs, _ = load_graphs(str(graph_path))
         info = load_info(str(info_path))
-        self._num_nodes = info['num_nodes']
-        self._num_rels = info['num_rels']
+        self._num_nodes = info["num_nodes"]
+        self._num_rels = info["num_rels"]
         self._g = graphs[0]
-        train_mask = self._g.edata['train_edge_mask'].numpy()
-        val_mask = self._g.edata['valid_edge_mask'].numpy()
-        test_mask = self._g.edata['test_edge_mask'].numpy()
+        train_mask = self._g.edata["train_edge_mask"].numpy()
+        val_mask = self._g.edata["valid_edge_mask"].numpy()
+        test_mask = self._g.edata["test_edge_mask"].numpy()
         # convert mask tensor into bool tensor if possible
-        self._g.edata['train_edge_mask'] = generate_mask_tensor(self._g.edata['train_edge_mask'].numpy())
-        self._g.edata['valid_edge_mask'] = generate_mask_tensor(self._g.edata['valid_edge_mask'].numpy())
-        self._g.edata['test_edge_mask'] = generate_mask_tensor(self._g.edata['test_edge_mask'].numpy())
-        self._g.edata['train_mask'] = generate_mask_tensor(self._g.edata['train_mask'].numpy())
-        self._g.edata['val_mask'] = generate_mask_tensor(self._g.edata['val_mask'].numpy())
-        self._g.edata['test_mask'] = generate_mask_tensor(self._g.edata['test_mask'].numpy())
+        self._g.edata["train_edge_mask"] = generate_mask_tensor(
+            self._g.edata["train_edge_mask"].numpy()
+        )
+        self._g.edata["valid_edge_mask"] = generate_mask_tensor(
+            self._g.edata["valid_edge_mask"].numpy()
+        )
+        self._g.edata["test_edge_mask"] = generate_mask_tensor(
+            self._g.edata["test_edge_mask"].numpy()
+        )
+        self._g.edata["train_mask"] = generate_mask_tensor(
+            self._g.edata["train_mask"].numpy()
+        )
+        self._g.edata["val_mask"] = generate_mask_tensor(
+            self._g.edata["val_mask"].numpy()
+        )
+        self._g.edata["test_mask"] = generate_mask_tensor(
+            self._g.edata["test_mask"].numpy()
+        )
         # for compatability (with 0.4.x) generate train_idx, valid_idx and test_idx
-        etype = self._g.edata['etype'].numpy()
+        etype = self._g.edata["etype"].numpy()
         self._etype = etype
-        u, v = self._g.all_edges(form='uv')
+        u, v = self._g.all_edges(form="uv")
         u = u.numpy()
         v = v.numpy()
-        train_idx = np.nonzero(train_mask==1)
-        self._train = np.column_stack((u[train_idx], etype[train_idx], v[train_idx]))
-        valid_idx = np.nonzero(val_mask==1)
-        self._valid = np.column_stack((u[valid_idx], etype[valid_idx], v[valid_idx]))
-        test_idx = np.nonzero(test_mask==1)
-        self._test = np.column_stack((u[test_idx], etype[test_idx], v[test_idx]))
+        train_idx = np.nonzero(train_mask == 1)
+        self._train = np.column_stack(
+            (u[train_idx], etype[train_idx], v[train_idx])
+        )
+        valid_idx = np.nonzero(val_mask == 1)
+        self._valid = np.column_stack(
+            (u[valid_idx], etype[valid_idx], v[valid_idx])
+        )
+        test_idx = np.nonzero(test_mask == 1)
+        self._test = np.column_stack(
+            (u[test_idx], etype[test_idx], v[test_idx])
+        )

         if self.verbose:
             print("# entities: {}".format(self.num_nodes))
...
@@ -189,22 +242,25 @@ class KnowledgeGraphDataset(DGLBuiltinDataset):
     @property
     def save_name(self):
-        return self.name + '_dgl_graph'
+        return self.name + "_dgl_graph"

 def _read_dictionary(filename):
     d = {}
-    with open(filename, 'r+') as f:
+    with open(filename, "r+") as f:
         for line in f:
-            line = line.strip().split('\t')
+            line = line.strip().split("\t")
             d[line[1]] = int(line[0])
     return d

 def _read_triplets(filename):
-    with open(filename, 'r+') as f:
+    with open(filename, "r+") as f:
         for line in f:
-            processed_line = line.strip().split('\t')
+            processed_line = line.strip().split("\t")
             yield processed_line

 def _read_triplets_as_list(filename, entity_dict, relation_dict):
     l = []
     for triplet in _read_triplets(filename):
...
@@ -214,9 +270,11 @@ def _read_triplets_as_list(filename, entity_dict, relation_dict):
         l.append([s, r, o])
     return l

-def build_knowledge_graph(num_nodes, num_rels, train, valid, test, reverse=True):
-    """ Create a DGL Homogeneous graph with heterograph info stored as node or edge features.
-    """
+
+def build_knowledge_graph(
+    num_nodes, num_rels, train, valid, test, reverse=True
+):
+    """Create a DGL Homogeneous graph with heterograph info stored as node or edge features."""
     src = []
     rel = []
     dst = []
...
@@ -315,16 +373,40 @@ def build_knowledge_graph(num_nodes, num_rels, train, valid, test, reverse=True)
     g = dgl_graph((s, d), num_nodes=num_nodes)
     etype = np.concatenate(fg_etype)
     settype = np.concatenate(fg_settype)
-    etype = F.tensor(etype, dtype=F.data_type_dict['int64'])
+    etype = F.tensor(etype, dtype=F.data_type_dict["int64"])
     train_edge_mask = train_edge_mask
     valid_edge_mask = valid_edge_mask
     test_edge_mask = test_edge_mask
-    train_mask = generate_mask_tensor(settype == 1) if reverse is True else train_edge_mask
-    valid_mask = generate_mask_tensor(settype == 2) if reverse is True else valid_edge_mask
-    test_mask = generate_mask_tensor(settype == 3) if reverse is True else test_edge_mask
-    ntype = F.full_1d(num_nodes, 0, dtype=F.data_type_dict['int64'], ctx=F.cpu())
-    return g, (etype, ntype, train_edge_mask, valid_edge_mask, test_edge_mask, train_mask, valid_mask, test_mask)
+    train_mask = (
+        generate_mask_tensor(settype == 1)
+        if reverse is True
+        else train_edge_mask
+    )
+    valid_mask = (
+        generate_mask_tensor(settype == 2)
+        if reverse is True
+        else valid_edge_mask
+    )
+    test_mask = (
+        generate_mask_tensor(settype == 3)
+        if reverse is True
+        else test_edge_mask
+    )
+    ntype = F.full_1d(
+        num_nodes, 0, dtype=F.data_type_dict["int64"], ctx=F.cpu()
+    )
+    return g, (
+        etype,
+        ntype,
+        train_edge_mask,
+        valid_edge_mask,
+        test_edge_mask,
+        train_mask,
+        valid_mask,
+        test_mask,
+    )

 class FB15k237Dataset(KnowledgeGraphDataset):
     r"""FB15k237 link prediction dataset.
...
@@ -396,11 +478,19 @@ class FB15k237Dataset(KnowledgeGraphDataset):
     >>>
     >>> # Train, Validation and Test
     """
-    def __init__(self, reverse=True, raw_dir=None, force_reload=False,
-                 verbose=True, transform=None):
-        name = 'FB15k-237'
-        super(FB15k237Dataset, self).__init__(name, reverse, raw_dir,
-                                              force_reload, verbose, transform)
+
+    def __init__(
+        self,
+        reverse=True,
+        raw_dir=None,
+        force_reload=False,
+        verbose=True,
+        transform=None,
+    ):
+        name = "FB15k-237"
+        super(FB15k237Dataset, self).__init__(
+            name, reverse, raw_dir, force_reload, verbose, transform
+        )

     def __getitem__(self, idx):
         r"""Gets the graph object
...
@@ -431,6 +521,7 @@ class FB15k237Dataset(KnowledgeGraphDataset):
         r"""The number of graphs in the dataset."""
         return super(FB15k237Dataset, self).__len__()

+
 class FB15kDataset(KnowledgeGraphDataset):
     r"""FB15k link prediction dataset.
...
@@ -504,11 +595,19 @@ class FB15kDataset(KnowledgeGraphDataset):
     >>> # Train, Validation and Test
     >>>
     """
-    def __init__(self, reverse=True, raw_dir=None, force_reload=False,
-                 verbose=True, transform=None):
-        name = 'FB15k'
-        super(FB15kDataset, self).__init__(name, reverse, raw_dir,
-                                           force_reload, verbose, transform)
+
+    def __init__(
+        self,
+        reverse=True,
+        raw_dir=None,
+        force_reload=False,
+        verbose=True,
+        transform=None,
+    ):
+        name = "FB15k"
+        super(FB15kDataset, self).__init__(
+            name, reverse, raw_dir, force_reload, verbose, transform
+        )

     def __getitem__(self, idx):
         r"""Gets the graph object
...
@@ -539,8 +638,9 @@ class FB15kDataset(KnowledgeGraphDataset):
         r"""The number of graphs in the dataset."""
         return super(FB15kDataset, self).__len__()

+
 class WN18Dataset(KnowledgeGraphDataset):
-    r""" WN18 link prediction dataset.
+    r"""WN18 link prediction dataset.

     The WN18 dataset was introduced in `Translating Embeddings for Modeling
     Multi-relational Data <http://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data.pdf>`_.
...
@@ -611,11 +711,19 @@ class WN18Dataset(KnowledgeGraphDataset):
     >>> # Train, Validation and Test
     >>>
     """
-    def __init__(self, reverse=True, raw_dir=None, force_reload=False,
-                 verbose=True, transform=None):
-        name = 'wn18'
-        super(WN18Dataset, self).__init__(name, reverse, raw_dir,
-                                          force_reload, verbose, transform)
+
+    def __init__(
+        self,
+        reverse=True,
+        raw_dir=None,
+        force_reload=False,
+        verbose=True,
+        transform=None,
+    ):
+        name = "wn18"
+        super(WN18Dataset, self).__init__(
+            name, reverse, raw_dir, force_reload, verbose, transform
+        )

     def __getitem__(self, idx):
         r"""Gets the graph object
...
@@ -646,6 +754,7 @@ class WN18Dataset(KnowledgeGraphDataset):
         r"""The number of graphs in the dataset."""
         return super(WN18Dataset, self).__len__()

+
 def load_data(dataset):
     r"""Load knowledge graph dataset for RGCN link prediction tasks
...
@@ -660,9 +769,9 @@ def load_data(dataset):
     ------
     The dataset object.
     """
-    if dataset == 'wn18':
+    if dataset == "wn18":
         return WN18Dataset()
-    elif dataset == 'FB15k':
+    elif dataset == "FB15k":
         return FB15kDataset()
-    elif dataset == 'FB15k-237':
+    elif dataset == "FB15k-237":
         return FB15k237Dataset()
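process() stores the split and relation information as edge/node features ('train_edge_mask', 'etype', 'ntype', ...), so downstream RGCN code can read everything off the single returned graph. A usage sketch, assuming network access for the first download:

>>> from dgl.data import FB15k237Dataset
>>> dataset = FB15k237Dataset(reverse=True)
>>> g = dataset[0]
>>> e_type = g.edata['etype']
>>> train_mask = g.edata['train_edge_mask']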
"""QM7b dataset for graph property prediction (regression).""" """QM7b dataset for graph property prediction (regression)."""
from scipy import io
import numpy as np
import os import os
from .dgl_dataset import DGLDataset import numpy as np
from .utils import download, save_graphs, load_graphs, \ from scipy import io
check_sha1, deprecate_property
from .. import backend as F from .. import backend as F
from ..convert import graph as dgl_graph from ..convert import graph as dgl_graph
from .dgl_dataset import DGLDataset
from .utils import (
check_sha1,
deprecate_property,
download,
load_graphs,
save_graphs,
)
class QM7bDataset(DGLDataset): class QM7bDataset(DGLDataset):
r"""QM7b dataset for graph property prediction (regression) r"""QM7b dataset for graph property prediction (regression)
...@@ -67,57 +74,69 @@ class QM7bDataset(DGLDataset): ...@@ -67,57 +74,69 @@ class QM7bDataset(DGLDataset):
>>> >>>
""" """
_url = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/' \ _url = (
'datasets/qm7b.mat' "http://deepchem.io.s3-website-us-west-1.amazonaws.com/"
_sha1_str = '4102c744bb9d6fd7b40ac67a300e49cd87e28392' "datasets/qm7b.mat"
)
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None): _sha1_str = "4102c744bb9d6fd7b40ac67a300e49cd87e28392"
super(QM7bDataset, self).__init__(name='qm7b',
def __init__(
self, raw_dir=None, force_reload=False, verbose=False, transform=None
):
super(QM7bDataset, self).__init__(
name="qm7b",
url=self._url, url=self._url,
raw_dir=raw_dir, raw_dir=raw_dir,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose, verbose=verbose,
transform=transform) transform=transform,
)
def process(self): def process(self):
mat_path = self.raw_path + '.mat' mat_path = self.raw_path + ".mat"
self.graphs, self.label = self._load_graph(mat_path) self.graphs, self.label = self._load_graph(mat_path)
def _load_graph(self, filename): def _load_graph(self, filename):
data = io.loadmat(filename) data = io.loadmat(filename)
labels = F.tensor(data['T'], dtype=F.data_type_dict['float32']) labels = F.tensor(data["T"], dtype=F.data_type_dict["float32"])
feats = data['X'] feats = data["X"]
num_graphs = labels.shape[0] num_graphs = labels.shape[0]
graphs = [] graphs = []
for i in range(num_graphs): for i in range(num_graphs):
edge_list = feats[i].nonzero() edge_list = feats[i].nonzero()
g = dgl_graph(edge_list) g = dgl_graph(edge_list)
g.edata['h'] = F.tensor(feats[i][edge_list[0], edge_list[1]].reshape(-1, 1), g.edata["h"] = F.tensor(
dtype=F.data_type_dict['float32']) feats[i][edge_list[0], edge_list[1]].reshape(-1, 1),
dtype=F.data_type_dict["float32"],
)
graphs.append(g) graphs.append(g)
return graphs, labels return graphs, labels
def save(self): def save(self):
"""save the graph list and the labels""" """save the graph list and the labels"""
graph_path = os.path.join(self.save_path, 'dgl_graph.bin') graph_path = os.path.join(self.save_path, "dgl_graph.bin")
save_graphs(str(graph_path), self.graphs, {'labels': self.label}) save_graphs(str(graph_path), self.graphs, {"labels": self.label})
def has_cache(self): def has_cache(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin') graph_path = os.path.join(self.save_path, "dgl_graph.bin")
return os.path.exists(graph_path) return os.path.exists(graph_path)
def load(self): def load(self):
graphs, label_dict = load_graphs(os.path.join(self.save_path, 'dgl_graph.bin')) graphs, label_dict = load_graphs(
os.path.join(self.save_path, "dgl_graph.bin")
)
self.graphs = graphs self.graphs = graphs
self.label = label_dict['labels'] self.label = label_dict["labels"]
def download(self): def download(self):
file_path = os.path.join(self.raw_dir, self.name + '.mat') file_path = os.path.join(self.raw_dir, self.name + ".mat")
download(self.url, path=file_path) download(self.url, path=file_path)
if not check_sha1(file_path, self._sha1_str): if not check_sha1(file_path, self._sha1_str):
raise UserWarning('File {} is downloaded but the content hash does not match.' raise UserWarning(
'The repo may be outdated or download may be incomplete. ' "File {} is downloaded but the content hash does not match."
'Otherwise you can create an issue for it.'.format(self.name)) "The repo may be outdated or download may be incomplete. "
"Otherwise you can create an issue for it.".format(self.name)
)
@property @property
def num_tasks(self): def num_tasks(self):
...@@ -130,7 +149,7 @@ class QM7bDataset(DGLDataset): ...@@ -130,7 +149,7 @@ class QM7bDataset(DGLDataset):
return 14 return 14
def __getitem__(self, idx): def __getitem__(self, idx):
r""" Get graph and label by index r"""Get graph and label by index
Parameters Parameters
---------- ----------
......
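Per _load_graph above, __getitem__ returns a (graph, label) pair and the Coulomb-matrix entries land in g.edata['h']. A sketch:

>>> from dgl.data import QM7bDataset
>>> dataset = QM7bDataset()  # downloads qm7b.mat and verifies its SHA-1
>>> dataset.num_tasks
14
>>> g, label = dataset[0]
>>> edge_feat = g.edata['h']  # Coulomb-matrix entries, one scalar per edge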
"""QM9 dataset for graph property prediction (regression).""" """QM9 dataset for graph property prediction (regression)."""
import os import os
import numpy as np import numpy as np
import scipy.sparse as sp import scipy.sparse as sp
from .dgl_dataset import DGLDataset from .. import backend as F
from .utils import download, _get_dgl_url
from ..convert import graph as dgl_graph from ..convert import graph as dgl_graph
from ..transforms import to_bidirected from ..transforms import to_bidirected
from .. import backend as F
from .dgl_dataset import DGLDataset
from .utils import _get_dgl_url, download
class QM9Dataset(DGLDataset): class QM9Dataset(DGLDataset):
r"""QM9 dataset for graph property prediction (regression) r"""QM9 dataset for graph property prediction (regression)
...@@ -103,39 +106,44 @@ class QM9Dataset(DGLDataset): ...@@ -103,39 +106,44 @@ class QM9Dataset(DGLDataset):
>>> >>>
""" """
def __init__(self, def __init__(
self,
label_keys, label_keys,
cutoff=5.0, cutoff=5.0,
raw_dir=None, raw_dir=None,
force_reload=False, force_reload=False,
verbose=False, verbose=False,
transform=None): transform=None,
):
self.cutoff = cutoff self.cutoff = cutoff
self.label_keys = label_keys self.label_keys = label_keys
self._url = _get_dgl_url('dataset/qm9_eV.npz') self._url = _get_dgl_url("dataset/qm9_eV.npz")
super(QM9Dataset, self).__init__(name='qm9', super(QM9Dataset, self).__init__(
name="qm9",
url=self._url, url=self._url,
raw_dir=raw_dir, raw_dir=raw_dir,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose, verbose=verbose,
transform=transform) transform=transform,
)
def process(self): def process(self):
npz_path = f'{self.raw_dir}/qm9_eV.npz' npz_path = f"{self.raw_dir}/qm9_eV.npz"
data_dict = np.load(npz_path, allow_pickle=True) data_dict = np.load(npz_path, allow_pickle=True)
# data_dict['N'] contains the number of atoms in each molecule. # data_dict['N'] contains the number of atoms in each molecule.
# Atomic properties (Z and R) of all molecules are concatenated as single tensors, # Atomic properties (Z and R) of all molecules are concatenated as single tensors,
# so you need this value to select the correct atoms for each molecule. # so you need this value to select the correct atoms for each molecule.
self.N = data_dict['N'] self.N = data_dict["N"]
self.R = data_dict['R'] self.R = data_dict["R"]
self.Z = data_dict['Z'] self.Z = data_dict["Z"]
self.label = np.stack([data_dict[key] for key in self.label_keys], axis=1) self.label = np.stack(
[data_dict[key] for key in self.label_keys], axis=1
)
self.N_cumsum = np.concatenate([[0], np.cumsum(self.N)]) self.N_cumsum = np.concatenate([[0], np.cumsum(self.N)])
def download(self): def download(self):
file_path = f'{self.raw_dir}/qm9_eV.npz' file_path = f"{self.raw_dir}/qm9_eV.npz"
if not os.path.exists(file_path): if not os.path.exists(file_path):
download(self._url, path=file_path) download(self._url, path=file_path)
...@@ -160,7 +168,7 @@ class QM9Dataset(DGLDataset): ...@@ -160,7 +168,7 @@ class QM9Dataset(DGLDataset):
return self.label.shape[1] return self.label.shape[1]
def __getitem__(self, idx): def __getitem__(self, idx):
r""" Get graph and label by index r"""Get graph and label by index
Parameters Parameters
---------- ----------
...@@ -178,18 +186,22 @@ class QM9Dataset(DGLDataset): ...@@ -178,18 +186,22 @@ class QM9Dataset(DGLDataset):
Tensor Tensor
Property values of molecular graphs Property values of molecular graphs
""" """
label = F.tensor(self.label[idx], dtype=F.data_type_dict['float32']) label = F.tensor(self.label[idx], dtype=F.data_type_dict["float32"])
n_atoms = self.N[idx] n_atoms = self.N[idx]
R = self.R[self.N_cumsum[idx]:self.N_cumsum[idx + 1]] R = self.R[self.N_cumsum[idx] : self.N_cumsum[idx + 1]]
dist = np.linalg.norm(R[:, None, :] - R[None, :, :], axis=-1) dist = np.linalg.norm(R[:, None, :] - R[None, :, :], axis=-1)
adj = sp.csr_matrix(dist <= self.cutoff) - sp.eye(n_atoms, dtype=np.bool_) adj = sp.csr_matrix(dist <= self.cutoff) - sp.eye(
n_atoms, dtype=np.bool_
)
adj = adj.tocoo() adj = adj.tocoo()
u, v = F.tensor(adj.row), F.tensor(adj.col) u, v = F.tensor(adj.row), F.tensor(adj.col)
g = dgl_graph((u, v)) g = dgl_graph((u, v))
g = to_bidirected(g) g = to_bidirected(g)
g.ndata['R'] = F.tensor(R, dtype=F.data_type_dict['float32']) g.ndata["R"] = F.tensor(R, dtype=F.data_type_dict["float32"])
g.ndata['Z'] = F.tensor(self.Z[self.N_cumsum[idx]:self.N_cumsum[idx + 1]], g.ndata["Z"] = F.tensor(
dtype=F.data_type_dict['int64']) self.Z[self.N_cumsum[idx] : self.N_cumsum[idx + 1]],
dtype=F.data_type_dict["int64"],
)
if self._transform is not None: if self._transform is not None:
g = self._transform(g) g = self._transform(g)
...@@ -205,4 +217,5 @@ class QM9Dataset(DGLDataset): ...@@ -205,4 +217,5 @@ class QM9Dataset(DGLDataset):
""" """
return self.label.shape[0] return self.label.shape[0]
QM9 = QM9Dataset QM9 = QM9Dataset
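__getitem__ builds each molecular graph on the fly by connecting atoms within self.cutoff of each other, so construction is controlled entirely by the constructor arguments. A sketch ('mu' and 'gap' are two of the regression targets stored in qm9_eV.npz):

>>> from dgl.data import QM9Dataset
>>> dataset = QM9Dataset(label_keys=['mu', 'gap'], cutoff=5.0)
>>> g, label = dataset[0]
>>> R = g.ndata['R']  # atom coordinates
>>> Z = g.ndata['Z']  # atomic numbers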
""" QM9 dataset for graph property prediction (regression) """ """ QM9 dataset for graph property prediction (regression) """
import os import os
import numpy as np import numpy as np
from .dgl_dataset import DGLDataset
from .utils import download, extract_archive, _get_dgl_url
from ..convert import graph as dgl_graph
from .. import backend as F from .. import backend as F
from ..convert import graph as dgl_graph
from .dgl_dataset import DGLDataset
from .utils import _get_dgl_url, download, extract_archive
class QM9EdgeDataset(DGLDataset): class QM9EdgeDataset(DGLDataset):
...@@ -129,19 +131,40 @@ class QM9EdgeDataset(DGLDataset): ...@@ -129,19 +131,40 @@ class QM9EdgeDataset(DGLDataset):
>>> >>>
""" """
keys = ['mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv', 'U0_atom', keys = [
'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C'] "mu",
"alpha",
"homo",
"lumo",
"gap",
"r2",
"zpve",
"U0",
"U",
"H",
"G",
"Cv",
"U0_atom",
"U_atom",
"H_atom",
"G_atom",
"A",
"B",
"C",
]
map_dict = {} map_dict = {}
for i, key in enumerate(keys): for i, key in enumerate(keys):
map_dict[key] = i map_dict[key] = i
def __init__(self, def __init__(
self,
label_keys=None, label_keys=None,
raw_dir=None, raw_dir=None,
force_reload=False, force_reload=False,
verbose=True, verbose=True,
transform=None): transform=None,
):
if label_keys is None: if label_keys is None:
self.label_keys = None self.label_keys = None
self.num_labels = 19 self.num_labels = 19
...@@ -149,18 +172,19 @@ class QM9EdgeDataset(DGLDataset): ...@@ -149,18 +172,19 @@ class QM9EdgeDataset(DGLDataset):
self.label_keys = [self.map_dict[i] for i in label_keys] self.label_keys = [self.map_dict[i] for i in label_keys]
self.num_labels = len(label_keys) self.num_labels = len(label_keys)
self._url = _get_dgl_url("dataset/qm9_edge.npz")
self._url = _get_dgl_url('dataset/qm9_edge.npz') super(QM9EdgeDataset, self).__init__(
name="qm9Edge",
super(QM9EdgeDataset, self).__init__(name='qm9Edge',
raw_dir=raw_dir, raw_dir=raw_dir,
url=self._url, url=self._url,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose, verbose=verbose,
transform=transform) transform=transform,
)
def download(self): def download(self):
file_path = f'{self.raw_dir}/qm9_edge.npz' file_path = f"{self.raw_dir}/qm9_edge.npz"
if not os.path.exists(file_path): if not os.path.exists(file_path):
download(self._url, path=file_path) download(self._url, path=file_path)
...@@ -168,11 +192,12 @@ class QM9EdgeDataset(DGLDataset): ...@@ -168,11 +192,12 @@ class QM9EdgeDataset(DGLDataset):
self.load() self.load()
def has_cache(self): def has_cache(self):
npz_path = f'{self.raw_dir}/qm9_edge.npz' npz_path = f"{self.raw_dir}/qm9_edge.npz"
return os.path.exists(npz_path) return os.path.exists(npz_path)
def save(self): def save(self):
np.savez_compressed(f'{self.raw_dir}/qm9_edge.npz', np.savez_compressed(
f"{self.raw_dir}/qm9_edge.npz",
n_node=self.n_node, n_node=self.n_node,
n_edge=self.n_edge, n_edge=self.n_edge,
node_attr=self.node_attr, node_attr=self.node_attr,
...@@ -180,27 +205,28 @@ class QM9EdgeDataset(DGLDataset): ...@@ -180,27 +205,28 @@ class QM9EdgeDataset(DGLDataset):
edge_attr=self.edge_attr, edge_attr=self.edge_attr,
src=self.src, src=self.src,
dst=self.dst, dst=self.dst,
targets=self.targets) targets=self.targets,
)
def load(self): def load(self):
npz_path = f'{self.raw_dir}/qm9_edge.npz' npz_path = f"{self.raw_dir}/qm9_edge.npz"
data_dict = np.load(npz_path, allow_pickle=True) data_dict = np.load(npz_path, allow_pickle=True)
self.n_node = data_dict['n_node'] self.n_node = data_dict["n_node"]
self.n_edge = data_dict['n_edge'] self.n_edge = data_dict["n_edge"]
self.node_attr = data_dict['node_attr'] self.node_attr = data_dict["node_attr"]
self.node_pos = data_dict['node_pos'] self.node_pos = data_dict["node_pos"]
self.edge_attr = data_dict['edge_attr'] self.edge_attr = data_dict["edge_attr"]
self.targets = data_dict['targets'] self.targets = data_dict["targets"]
self.src = data_dict['src'] self.src = data_dict["src"]
self.dst = data_dict['dst'] self.dst = data_dict["dst"]
self.n_cumsum = np.concatenate([[0], np.cumsum(self.n_node)]) self.n_cumsum = np.concatenate([[0], np.cumsum(self.n_node)])
self.ne_cumsum = np.concatenate([[0], np.cumsum(self.n_edge)]) self.ne_cumsum = np.concatenate([[0], np.cumsum(self.n_edge)])
def __getitem__(self, idx): def __getitem__(self, idx):
r""" Get graph and label by index r"""Get graph and label by index
Parameters Parameters
---------- ----------
...@@ -220,18 +246,26 @@ class QM9EdgeDataset(DGLDataset): ...@@ -220,18 +246,26 @@ class QM9EdgeDataset(DGLDataset):
Property values of molecular graphs Property values of molecular graphs
""" """
pos = self.node_pos[self.n_cumsum[idx]:self.n_cumsum[idx+1]] pos = self.node_pos[self.n_cumsum[idx] : self.n_cumsum[idx + 1]]
src = self.src[self.ne_cumsum[idx]:self.ne_cumsum[idx+1]] src = self.src[self.ne_cumsum[idx] : self.ne_cumsum[idx + 1]]
dst = self.dst[self.ne_cumsum[idx]:self.ne_cumsum[idx+1]] dst = self.dst[self.ne_cumsum[idx] : self.ne_cumsum[idx + 1]]
g = dgl_graph((src, dst)) g = dgl_graph((src, dst))
g.ndata['pos'] = F.tensor(pos, dtype=F.data_type_dict['float32']) g.ndata["pos"] = F.tensor(pos, dtype=F.data_type_dict["float32"])
g.ndata['attr'] = F.tensor(self.node_attr[self.n_cumsum[idx]:self.n_cumsum[idx+1]], dtype=F.data_type_dict['float32']) g.ndata["attr"] = F.tensor(
g.edata['edge_attr'] = F.tensor(self.edge_attr[self.ne_cumsum[idx]:self.ne_cumsum[idx+1]], dtype=F.data_type_dict['float32']) self.node_attr[self.n_cumsum[idx] : self.n_cumsum[idx + 1]],
dtype=F.data_type_dict["float32"],
)
label = F.tensor(self.targets[idx][self.label_keys], dtype=F.data_type_dict['float32']) g.edata["edge_attr"] = F.tensor(
self.edge_attr[self.ne_cumsum[idx] : self.ne_cumsum[idx + 1]],
dtype=F.data_type_dict["float32"],
)
label = F.tensor(
self.targets[idx][self.label_keys],
dtype=F.data_type_dict["float32"],
)
if self._transform is not None: if self._transform is not None:
g = self._transform(g) g = self._transform(g)
...@@ -239,7 +273,7 @@ class QM9EdgeDataset(DGLDataset): ...@@ -239,7 +273,7 @@ class QM9EdgeDataset(DGLDataset):
return g, label return g, label
def __len__(self): def __len__(self):
r""" Number of graphs in the dataset. r"""Number of graphs in the dataset.
Returns Returns
------- -------
......
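Unlike QM9Dataset, the edge structure here is precomputed and sliced out of flat arrays via n_cumsum/ne_cumsum. A sketch using one of the keys listed above:

>>> from dgl.data import QM9EdgeDataset
>>> dataset = QM9EdgeDataset(label_keys=['mu'])
>>> g, label = dataset[0]
>>> node_feat = g.ndata['attr']
>>> edge_feat = g.edata['edge_attr']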
""" Reddit dataset for community detection """ """ Reddit dataset for community detection """
from __future__ import absolute_import from __future__ import absolute_import
import scipy.sparse as sp
import numpy as np
import os import os
from .dgl_dataset import DGLBuiltinDataset import numpy as np
from .utils import _get_dgl_url, generate_mask_tensor, load_graphs, save_graphs, deprecate_property
import scipy.sparse as sp
from .. import backend as F from .. import backend as F
from ..convert import from_scipy from ..convert import from_scipy
from ..transforms import reorder_graph from ..transforms import reorder_graph
from .dgl_dataset import DGLBuiltinDataset
from .utils import (
_get_dgl_url,
deprecate_property,
generate_mask_tensor,
load_graphs,
save_graphs,
)
class RedditDataset(DGLBuiltinDataset): class RedditDataset(DGLBuiltinDataset):
r""" Reddit dataset for community detection (node classification) r"""Reddit dataset for community detection (node classification)
This is a graph dataset built from Reddit posts made in September 2014. This is a graph dataset built from Reddit posts made in September 2014.
The node label in this case is the community, or “subreddit”, that a post belongs to. The node label in this case is the community, or “subreddit”, that a post belongs to.
...@@ -73,24 +82,36 @@ class RedditDataset(DGLBuiltinDataset): ...@@ -73,24 +82,36 @@ class RedditDataset(DGLBuiltinDataset):
>>> >>>
>>> # Train, Validation and Test >>> # Train, Validation and Test
""" """
def __init__(self, self_loop=False, raw_dir=None, force_reload=False,
verbose=False, transform=None): def __init__(
self,
self_loop=False,
raw_dir=None,
force_reload=False,
verbose=False,
transform=None,
):
self_loop_str = "" self_loop_str = ""
if self_loop: if self_loop:
self_loop_str = "_self_loop" self_loop_str = "_self_loop"
_url = _get_dgl_url("dataset/reddit{}.zip".format(self_loop_str)) _url = _get_dgl_url("dataset/reddit{}.zip".format(self_loop_str))
self._self_loop_str = self_loop_str self._self_loop_str = self_loop_str
super(RedditDataset, self).__init__(name='reddit{}'.format(self_loop_str), super(RedditDataset, self).__init__(
name="reddit{}".format(self_loop_str),
url=_url, url=_url,
raw_dir=raw_dir, raw_dir=raw_dir,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose, verbose=verbose,
transform=transform) transform=transform,
)
def process(self): def process(self):
# graph # graph
coo_adj = sp.load_npz(os.path.join( coo_adj = sp.load_npz(
self.raw_path, "reddit{}_graph.npz".format(self._self_loop_str))) os.path.join(
self.raw_path, "reddit{}_graph.npz".format(self._self_loop_str)
)
)
self._graph = from_scipy(coo_adj) self._graph = from_scipy(coo_adj)
# features and labels # features and labels
reddit_data = np.load(os.path.join(self.raw_path, "reddit_data.npz")) reddit_data = np.load(os.path.join(self.raw_path, "reddit_data.npz"))
...@@ -98,48 +119,74 @@ class RedditDataset(DGLBuiltinDataset): ...@@ -98,48 +119,74 @@ class RedditDataset(DGLBuiltinDataset):
labels = reddit_data["label"] labels = reddit_data["label"]
# train/val/test indices # train/val/test indices
node_types = reddit_data["node_types"] node_types = reddit_data["node_types"]
train_mask = (node_types == 1) train_mask = node_types == 1
val_mask = (node_types == 2) val_mask = node_types == 2
test_mask = (node_types == 3) test_mask = node_types == 3
self._graph.ndata['train_mask'] = generate_mask_tensor(train_mask) self._graph.ndata["train_mask"] = generate_mask_tensor(train_mask)
self._graph.ndata['val_mask'] = generate_mask_tensor(val_mask) self._graph.ndata["val_mask"] = generate_mask_tensor(val_mask)
self._graph.ndata['test_mask'] = generate_mask_tensor(test_mask) self._graph.ndata["test_mask"] = generate_mask_tensor(test_mask)
self._graph.ndata['feat'] = F.tensor(features, dtype=F.data_type_dict['float32']) self._graph.ndata["feat"] = F.tensor(
self._graph.ndata['label'] = F.tensor(labels, dtype=F.data_type_dict['int64']) features, dtype=F.data_type_dict["float32"]
)
self._graph.ndata["label"] = F.tensor(
labels, dtype=F.data_type_dict["int64"]
)
self._graph = reorder_graph( self._graph = reorder_graph(
self._graph, node_permute_algo='rcmk', edge_permute_algo='dst', store_ids=False) self._graph,
node_permute_algo="rcmk",
edge_permute_algo="dst",
store_ids=False,
)
self._print_info() self._print_info()
def has_cache(self): def has_cache(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin') graph_path = os.path.join(self.save_path, "dgl_graph.bin")
if os.path.exists(graph_path): if os.path.exists(graph_path):
return True return True
return False return False
def save(self): def save(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin') graph_path = os.path.join(self.save_path, "dgl_graph.bin")
save_graphs(graph_path, self._graph) save_graphs(graph_path, self._graph)
def load(self): def load(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin') graph_path = os.path.join(self.save_path, "dgl_graph.bin")
graphs, _ = load_graphs(graph_path) graphs, _ = load_graphs(graph_path)
self._graph = graphs[0] self._graph = graphs[0]
self._graph.ndata['train_mask'] = generate_mask_tensor(self._graph.ndata['train_mask'].numpy()) self._graph.ndata["train_mask"] = generate_mask_tensor(
self._graph.ndata['val_mask'] = generate_mask_tensor(self._graph.ndata['val_mask'].numpy()) self._graph.ndata["train_mask"].numpy()
self._graph.ndata['test_mask'] = generate_mask_tensor(self._graph.ndata['test_mask'].numpy()) )
self._graph.ndata["val_mask"] = generate_mask_tensor(
self._graph.ndata["val_mask"].numpy()
)
self._graph.ndata["test_mask"] = generate_mask_tensor(
self._graph.ndata["test_mask"].numpy()
)
self._print_info() self._print_info()
def _print_info(self): def _print_info(self):
if self.verbose: if self.verbose:
print('Finished data loading.') print("Finished data loading.")
print(' NumNodes: {}'.format(self._graph.number_of_nodes())) print(" NumNodes: {}".format(self._graph.number_of_nodes()))
print(' NumEdges: {}'.format(self._graph.number_of_edges())) print(" NumEdges: {}".format(self._graph.number_of_edges()))
print(' NumFeats: {}'.format(self._graph.ndata['feat'].shape[1])) print(" NumFeats: {}".format(self._graph.ndata["feat"].shape[1]))
print(' NumClasses: {}'.format(self.num_classes)) print(" NumClasses: {}".format(self.num_classes))
print(' NumTrainingSamples: {}'.format(F.nonzero_1d(self._graph.ndata['train_mask']).shape[0])) print(
print(' NumValidationSamples: {}'.format(F.nonzero_1d(self._graph.ndata['val_mask']).shape[0])) " NumTrainingSamples: {}".format(
print(' NumTestSamples: {}'.format(F.nonzero_1d(self._graph.ndata['test_mask']).shape[0])) F.nonzero_1d(self._graph.ndata["train_mask"]).shape[0]
)
)
print(
" NumValidationSamples: {}".format(
F.nonzero_1d(self._graph.ndata["val_mask"]).shape[0]
)
)
print(
" NumTestSamples: {}".format(
F.nonzero_1d(self._graph.ndata["test_mask"]).shape[0]
)
)
@property @property
def num_classes(self): def num_classes(self):
...@@ -147,7 +194,7 @@ class RedditDataset(DGLBuiltinDataset): ...@@ -147,7 +194,7 @@ class RedditDataset(DGLBuiltinDataset):
return 41 return 41
def __getitem__(self, idx): def __getitem__(self, idx):
r""" Get graph by index r"""Get graph by index
Parameters Parameters
---------- ----------
......
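A minimal usage sketch for RedditDataset (assumes a working DGL install; the archive is large and is downloaded on first use). The boolean masks written in process() drive the usual single-graph node-classification loop:

from dgl.data import RedditDataset

dataset = RedditDataset(self_loop=True)
g = dataset[0]                            # the one big post-to-post graph
feat, label = g.ndata["feat"], g.ndata["label"]
train_mask = g.ndata["train_mask"]        # likewise val_mask / test_mask
print(dataset.num_classes, feat.shape, int(train_mask.sum()))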
...@@ -4,19 +4,28 @@ Including: ...@@ -4,19 +4,28 @@ Including:
""" """
from __future__ import absolute_import from __future__ import absolute_import
import os
from collections import OrderedDict from collections import OrderedDict
import networkx as nx import networkx as nx
import numpy as np import numpy as np
import os
from .dgl_dataset import DGLBuiltinDataset
from .. import backend as F from .. import backend as F
from .utils import _get_dgl_url, save_graphs, save_info, load_graphs, \
load_info, deprecate_property
from ..convert import from_networkx from ..convert import from_networkx
__all__ = ['SST', 'SSTDataset'] from .dgl_dataset import DGLBuiltinDataset
from .utils import (
_get_dgl_url,
deprecate_property,
load_graphs,
load_info,
save_graphs,
save_info,
)
__all__ = ["SST", "SSTDataset"]
class SSTDataset(DGLBuiltinDataset): class SSTDataset(DGLBuiltinDataset):
...@@ -104,45 +113,58 @@ class SSTDataset(DGLBuiltinDataset): ...@@ -104,45 +113,58 @@ class SSTDataset(DGLBuiltinDataset):
PAD_WORD = -1 # special pad word id PAD_WORD = -1 # special pad word id
UNK_WORD = -1 # out-of-vocabulary word id UNK_WORD = -1 # out-of-vocabulary word id
def __init__(self, def __init__(
mode='train', self,
mode="train",
glove_embed_file=None, glove_embed_file=None,
vocab_file=None, vocab_file=None,
raw_dir=None, raw_dir=None,
force_reload=False, force_reload=False,
verbose=False, verbose=False,
transform=None): transform=None,
assert mode in ['train', 'dev', 'test', 'tiny'] ):
_url = _get_dgl_url('dataset/sst.zip') assert mode in ["train", "dev", "test", "tiny"]
self._glove_embed_file = glove_embed_file if mode == 'train' else None _url = _get_dgl_url("dataset/sst.zip")
self._glove_embed_file = glove_embed_file if mode == "train" else None
self.mode = mode self.mode = mode
self._vocab_file = vocab_file self._vocab_file = vocab_file
super(SSTDataset, self).__init__(name='sst', super(SSTDataset, self).__init__(
name="sst",
url=_url, url=_url,
raw_dir=raw_dir, raw_dir=raw_dir,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose, verbose=verbose,
transform=transform) transform=transform,
)
def process(self): def process(self):
from nltk.corpus.reader import BracketParseCorpusReader from nltk.corpus.reader import BracketParseCorpusReader
# load vocab file # load vocab file
self._vocab = OrderedDict() self._vocab = OrderedDict()
vocab_file = self._vocab_file if self._vocab_file is not None else os.path.join(self.raw_path, 'vocab.txt') vocab_file = (
with open(vocab_file, encoding='utf-8') as vf: self._vocab_file
if self._vocab_file is not None
else os.path.join(self.raw_path, "vocab.txt")
)
with open(vocab_file, encoding="utf-8") as vf:
for line in vf.readlines(): for line in vf.readlines():
line = line.strip() line = line.strip()
self._vocab[line] = len(self._vocab) self._vocab[line] = len(self._vocab)
# filter glove # filter glove
if self._glove_embed_file is not None and os.path.exists(self._glove_embed_file): if self._glove_embed_file is not None and os.path.exists(
self._glove_embed_file
):
glove_emb = {} glove_emb = {}
with open(self._glove_embed_file, 'r', encoding='utf-8') as pf: with open(self._glove_embed_file, "r", encoding="utf-8") as pf:
for line in pf.readlines(): for line in pf.readlines():
sp = line.split(' ') sp = line.split(" ")
if sp[0].lower() in self._vocab: if sp[0].lower() in self._vocab:
glove_emb[sp[0].lower()] = np.asarray([float(x) for x in sp[1:]]) glove_emb[sp[0].lower()] = np.asarray(
files = ['{}.txt'.format(self.mode)] [float(x) for x in sp[1:]]
)
files = ["{}.txt".format(self.mode)]
corpus = BracketParseCorpusReader(self.raw_path, files) corpus = BracketParseCorpusReader(self.raw_path, files)
sents = corpus.parsed_sents(files[0]) sents = corpus.parsed_sents(files[0])
...@@ -150,15 +172,27 @@ class SSTDataset(DGLBuiltinDataset): ...@@ -150,15 +172,27 @@ class SSTDataset(DGLBuiltinDataset):
pretrained_emb = [] pretrained_emb = []
fail_cnt = 0 fail_cnt = 0
for line in self._vocab.keys(): for line in self._vocab.keys():
if self._glove_embed_file is not None and os.path.exists(self._glove_embed_file): if self._glove_embed_file is not None and os.path.exists(
self._glove_embed_file
):
if not line.lower() in glove_emb: if not line.lower() in glove_emb:
fail_cnt += 1 fail_cnt += 1
pretrained_emb.append(glove_emb.get(line.lower(), np.random.uniform(-0.05, 0.05, 300))) pretrained_emb.append(
glove_emb.get(
line.lower(), np.random.uniform(-0.05, 0.05, 300)
)
)
self._pretrained_emb = None self._pretrained_emb = None
if self._glove_embed_file is not None and os.path.exists(self._glove_embed_file): if self._glove_embed_file is not None and os.path.exists(
self._glove_embed_file
):
self._pretrained_emb = F.tensor(np.stack(pretrained_emb, 0)) self._pretrained_emb = F.tensor(np.stack(pretrained_emb, 0))
print('Missing-word rate in GloVe: {0:.4f}'.format(1.0 * fail_cnt / len(self._pretrained_emb))) print(
"Missing-word rate in GloVe: {0:.4f}".format(
1.0 * fail_cnt / len(self._pretrained_emb)
)
)
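The filtering above keeps only GloVe vectors for in-vocabulary words and falls back to a small uniform initialization for misses, which is what the reported rate counts. A self-contained sketch of that fallback (toy vocabulary; real GloVe rows are 300-dimensional "word v1 v2 ..." lines):

import numpy as np

vocab = {"good": 0, "zxqv": 1}          # hypothetical word -> id mapping
glove_emb = {"good": np.ones(300)}      # only "good" has a pretrained vector

pretrained = [
    glove_emb.get(w, np.random.uniform(-0.05, 0.05, 300)) for w in vocab
]
emb_table = np.stack(pretrained, 0)     # row i is the vector for word id i
print(emb_table.shape)                  # (2, 300)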
# build trees # build trees
self._trees = [] self._trees = []
for sent in sents: for sent in sents:
...@@ -175,44 +209,46 @@ class SSTDataset(DGLBuiltinDataset): ...@@ -175,44 +209,46 @@ class SSTDataset(DGLBuiltinDataset):
word = self.vocab.get(child[0].lower(), self.UNK_WORD) word = self.vocab.get(child[0].lower(), self.UNK_WORD)
g.add_node(cid, x=word, y=int(child.label()), mask=1) g.add_node(cid, x=word, y=int(child.label()), mask=1)
else: else:
g.add_node(cid, x=SSTDataset.PAD_WORD, y=int(child.label()), mask=0) g.add_node(
cid, x=SSTDataset.PAD_WORD, y=int(child.label()), mask=0
)
_rec_build(cid, child) _rec_build(cid, child)
g.add_edge(cid, nid) g.add_edge(cid, nid)
# add root # add root
g.add_node(0, x=SSTDataset.PAD_WORD, y=int(root.label()), mask=0) g.add_node(0, x=SSTDataset.PAD_WORD, y=int(root.label()), mask=0)
_rec_build(0, root) _rec_build(0, root)
ret = from_networkx(g, node_attrs=['x', 'y', 'mask']) ret = from_networkx(g, node_attrs=["x", "y", "mask"])
return ret return ret
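_rec_build above adds one node per constituent, wires child-to-parent edges (so information can flow leaf-to-root), and from_networkx then lifts the 'x', 'y', 'mask' attributes onto the DGLGraph. A toy reconstruction of that conversion:

import networkx as nx
from dgl import from_networkx

g = nx.DiGraph()
g.add_node(0, x=-1, y=3, mask=0)   # root: PAD word id, sentiment label 3
g.add_node(1, x=5, y=2, mask=1)    # leaf: word id 5
g.add_node(2, x=7, y=4, mask=1)    # leaf: word id 7
g.add_edge(1, 0)                   # child -> parent, as in _rec_build
g.add_edge(2, 0)

tree = from_networkx(g, node_attrs=["x", "y", "mask"])
print(tree.num_nodes(), tree.num_edges())  # 3 2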
def has_cache(self): def has_cache(self):
graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin') graph_path = os.path.join(self.save_path, self.mode + "_dgl_graph.bin")
vocab_path = os.path.join(self.save_path, 'vocab.pkl') vocab_path = os.path.join(self.save_path, "vocab.pkl")
return os.path.exists(graph_path) and os.path.exists(vocab_path) return os.path.exists(graph_path) and os.path.exists(vocab_path)
def save(self): def save(self):
graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin') graph_path = os.path.join(self.save_path, self.mode + "_dgl_graph.bin")
save_graphs(graph_path, self._trees) save_graphs(graph_path, self._trees)
vocab_path = os.path.join(self.save_path, 'vocab.pkl') vocab_path = os.path.join(self.save_path, "vocab.pkl")
save_info(vocab_path, {'vocab': self.vocab}) save_info(vocab_path, {"vocab": self.vocab})
if self.pretrained_emb: if self.pretrained_emb:
emb_path = os.path.join(self.save_path, 'emb.pkl') emb_path = os.path.join(self.save_path, "emb.pkl")
save_info(emb_path, {'embed': self.pretrained_emb}) save_info(emb_path, {"embed": self.pretrained_emb})
def load(self): def load(self):
graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin') graph_path = os.path.join(self.save_path, self.mode + "_dgl_graph.bin")
vocab_path = os.path.join(self.save_path, 'vocab.pkl') vocab_path = os.path.join(self.save_path, "vocab.pkl")
emb_path = os.path.join(self.save_path, 'emb.pkl') emb_path = os.path.join(self.save_path, "emb.pkl")
self._trees = load_graphs(graph_path)[0] self._trees = load_graphs(graph_path)[0]
self._vocab = load_info(vocab_path)['vocab'] self._vocab = load_info(vocab_path)["vocab"]
self._pretrained_emb = None self._pretrained_emb = None
if os.path.exists(emb_path): if os.path.exists(emb_path):
self._pretrained_emb = load_info(emb_path)['embed'] self._pretrained_emb = load_info(emb_path)["embed"]
@property @property
def vocab(self): def vocab(self):
r""" Vocabulary r"""Vocabulary
Returns Returns
------- -------
...@@ -226,7 +262,7 @@ class SSTDataset(DGLBuiltinDataset): ...@@ -226,7 +262,7 @@ class SSTDataset(DGLBuiltinDataset):
return self._pretrained_emb return self._pretrained_emb
def __getitem__(self, idx): def __getitem__(self, idx):
r""" Get graph by index r"""Get graph by index
Parameters Parameters
---------- ----------
......
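A minimal usage sketch for SSTDataset (assumes a working DGL install plus nltk, which process() imports; pass glove_embed_file only with mode="train" if you want pretrained_emb populated):

from dgl.data import SSTDataset

trainset = SSTDataset(mode="train")
tree = trainset[0]                     # one constituency tree as a DGLGraph
print(len(trainset.vocab))             # word -> id mapping from vocab.txt
print(tree.ndata["x"][:5], tree.ndata["mask"][:5])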
from __future__ import absolute_import from __future__ import absolute_import
import numpy as np
import os import os
from .dgl_dataset import DGLBuiltinDataset import numpy as np
from .utils import loadtxt, save_graphs, load_graphs, save_info, load_info
from .. import backend as F from .. import backend as F
from ..convert import graph as dgl_graph from ..convert import graph as dgl_graph
from .dgl_dataset import DGLBuiltinDataset
from .utils import load_graphs, load_info, loadtxt, save_graphs, save_info
class LegacyTUDataset(DGLBuiltinDataset): class LegacyTUDataset(DGLBuiltinDataset):
r"""LegacyTUDataset contains lots of graph kernel datasets for graph classification. r"""LegacyTUDataset contains lots of graph kernel datasets for graph classification.
...@@ -77,38 +81,60 @@ class LegacyTUDataset(DGLBuiltinDataset): ...@@ -77,38 +81,60 @@ class LegacyTUDataset(DGLBuiltinDataset):
_url = r"https://www.chrsmrrs.com/graphkerneldatasets/{}.zip" _url = r"https://www.chrsmrrs.com/graphkerneldatasets/{}.zip"
def __init__(self, name, use_pandas=False, def __init__(
hidden_size=10, max_allow_node=None, self,
raw_dir=None, force_reload=False, verbose=False, transform=None): name,
use_pandas=False,
hidden_size=10,
max_allow_node=None,
raw_dir=None,
force_reload=False,
verbose=False,
transform=None,
):
url = self._url.format(name) url = self._url.format(name)
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.max_allow_node = max_allow_node self.max_allow_node = max_allow_node
self.use_pandas = use_pandas self.use_pandas = use_pandas
super(LegacyTUDataset, self).__init__(name=name, url=url, raw_dir=raw_dir, super(LegacyTUDataset, self).__init__(
name=name,
url=url,
raw_dir=raw_dir,
hash_key=(name, use_pandas, hidden_size, max_allow_node), hash_key=(name, use_pandas, hidden_size, max_allow_node),
force_reload=force_reload, verbose=verbose, transform=transform) force_reload=force_reload,
verbose=verbose,
transform=transform,
)
def process(self): def process(self):
self.data_mode = None self.data_mode = None
if self.use_pandas: if self.use_pandas:
import pandas as pd import pandas as pd
DS_edge_list = self._idx_from_zero( DS_edge_list = self._idx_from_zero(
pd.read_csv(self._file_path("A"), delimiter=",", dtype=int, header=None).values) pd.read_csv(
self._file_path("A"), delimiter=",", dtype=int, header=None
).values
)
else: else:
DS_edge_list = self._idx_from_zero( DS_edge_list = self._idx_from_zero(
np.genfromtxt(self._file_path("A"), delimiter=",", dtype=int)) np.genfromtxt(self._file_path("A"), delimiter=",", dtype=int)
)
DS_indicator = self._idx_from_zero( DS_indicator = self._idx_from_zero(
np.genfromtxt(self._file_path("graph_indicator"), dtype=int)) np.genfromtxt(self._file_path("graph_indicator"), dtype=int)
)
if os.path.exists(self._file_path("graph_labels")): if os.path.exists(self._file_path("graph_labels")):
DS_graph_labels = self._idx_from_zero( DS_graph_labels = self._idx_from_zero(
np.genfromtxt(self._file_path("graph_labels"), dtype=int)) np.genfromtxt(self._file_path("graph_labels"), dtype=int)
)
self.num_labels = max(DS_graph_labels) + 1 self.num_labels = max(DS_graph_labels) + 1
self.graph_labels = DS_graph_labels self.graph_labels = DS_graph_labels
elif os.path.exists(self._file_path("graph_attributes")): elif os.path.exists(self._file_path("graph_attributes")):
DS_graph_labels = np.genfromtxt(self._file_path("graph_attributes"), dtype=float) DS_graph_labels = np.genfromtxt(
self._file_path("graph_attributes"), dtype=float
)
self.num_labels = None self.num_labels = None
self.graph_labels = DS_graph_labels self.graph_labels = DS_graph_labels
else: else:
...@@ -130,33 +156,42 @@ class LegacyTUDataset(DGLBuiltinDataset): ...@@ -130,33 +156,42 @@ class LegacyTUDataset(DGLBuiltinDataset):
try: try:
DS_node_labels = self._idx_from_zero( DS_node_labels = self._idx_from_zero(
np.loadtxt(self._file_path("node_labels"), dtype=int)) np.loadtxt(self._file_path("node_labels"), dtype=int)
g.ndata['node_label'] = F.tensor(DS_node_labels) )
g.ndata["node_label"] = F.tensor(DS_node_labels)
one_hot_node_labels = self._to_onehot(DS_node_labels) one_hot_node_labels = self._to_onehot(DS_node_labels)
for idxs, g in zip(node_idx_list, self.graph_lists): for idxs, g in zip(node_idx_list, self.graph_lists):
g.ndata['feat'] = F.tensor(one_hot_node_labels[idxs, :], F.float32) g.ndata["feat"] = F.tensor(
one_hot_node_labels[idxs, :], F.float32
)
self.data_mode = "node_label" self.data_mode = "node_label"
except IOError: except IOError:
print("No Node Label Data") print("No Node Label Data")
try: try:
DS_node_attr = np.loadtxt( DS_node_attr = np.loadtxt(
self._file_path("node_attributes"), delimiter=",") self._file_path("node_attributes"), delimiter=","
)
if DS_node_attr.ndim == 1: if DS_node_attr.ndim == 1:
DS_node_attr = np.expand_dims(DS_node_attr, -1) DS_node_attr = np.expand_dims(DS_node_attr, -1)
for idxs, g in zip(node_idx_list, self.graph_lists): for idxs, g in zip(node_idx_list, self.graph_lists):
g.ndata['feat'] = F.tensor(DS_node_attr[idxs, :], F.float32) g.ndata["feat"] = F.tensor(DS_node_attr[idxs, :], F.float32)
self.data_mode = "node_attr" self.data_mode = "node_attr"
except IOError: except IOError:
print("No Node Attribute Data") print("No Node Attribute Data")
if 'feat' not in g.ndata.keys(): if "feat" not in g.ndata.keys():
for idxs, g in zip(node_idx_list, self.graph_lists): for idxs, g in zip(node_idx_list, self.graph_lists):
g.ndata['feat'] = F.ones((g.number_of_nodes(), self.hidden_size), g.ndata["feat"] = F.ones(
F.float32, F.cpu()) (g.number_of_nodes(), self.hidden_size), F.float32, F.cpu()
)
self.data_mode = "constant" self.data_mode = "constant"
if self.verbose: if self.verbose:
print("Use Constant one as Feature with hidden size {}".format(self.hidden_size)) print(
"Use Constant one as Feature with hidden size {}".format(
self.hidden_size
)
)
# remove graphs that are too large by a user-given standard # remove graphs that are too large by a user-given standard
# optional pre-processing step in conformity with Rex Ying's original # optional pre-processing step in conformity with Rex Ying's original
...@@ -165,39 +200,56 @@ class LegacyTUDataset(DGLBuiltinDataset): ...@@ -165,39 +200,56 @@ class LegacyTUDataset(DGLBuiltinDataset):
preserve_idx = [] preserve_idx = []
if self.verbose: if self.verbose:
print("original dataset length : ", len(self.graph_lists)) print("original dataset length : ", len(self.graph_lists))
for (i, g) in enumerate(self.graph_lists): for i, g in enumerate(self.graph_lists):
if g.number_of_nodes() <= self.max_allow_node: if g.number_of_nodes() <= self.max_allow_node:
preserve_idx.append(i) preserve_idx.append(i)
self.graph_lists = [self.graph_lists[i] for i in preserve_idx] self.graph_lists = [self.graph_lists[i] for i in preserve_idx]
if self.verbose: if self.verbose:
print("after pruning graphs that are too big : ", len(self.graph_lists)) print(
"after pruning graphs that are too big : ",
len(self.graph_lists),
)
self.graph_labels = [self.graph_labels[i] for i in preserve_idx] self.graph_labels = [self.graph_labels[i] for i in preserve_idx]
self.max_num_node = self.max_allow_node self.max_num_node = self.max_allow_node
self.graph_labels = F.tensor(self.graph_labels) self.graph_labels = F.tensor(self.graph_labels)
def save(self): def save(self):
graph_path = os.path.join(self.save_path, 'legacy_tu_{}_{}.bin'.format(self.name, self.hash)) graph_path = os.path.join(
info_path = os.path.join(self.save_path, 'legacy_tu_{}_{}.pkl'.format(self.name, self.hash)) self.save_path, "legacy_tu_{}_{}.bin".format(self.name, self.hash)
label_dict = {'labels': self.graph_labels} )
info_dict = {'max_num_node': self.max_num_node, info_path = os.path.join(
'num_labels': self.num_labels} self.save_path, "legacy_tu_{}_{}.pkl".format(self.name, self.hash)
)
label_dict = {"labels": self.graph_labels}
info_dict = {
"max_num_node": self.max_num_node,
"num_labels": self.num_labels,
}
save_graphs(str(graph_path), self.graph_lists, label_dict) save_graphs(str(graph_path), self.graph_lists, label_dict)
save_info(str(info_path), info_dict) save_info(str(info_path), info_dict)
def load(self): def load(self):
graph_path = os.path.join(self.save_path, 'legacy_tu_{}_{}.bin'.format(self.name, self.hash)) graph_path = os.path.join(
info_path = os.path.join(self.save_path, 'legacy_tu_{}_{}.pkl'.format(self.name, self.hash)) self.save_path, "legacy_tu_{}_{}.bin".format(self.name, self.hash)
)
info_path = os.path.join(
self.save_path, "legacy_tu_{}_{}.pkl".format(self.name, self.hash)
)
graphs, label_dict = load_graphs(str(graph_path)) graphs, label_dict = load_graphs(str(graph_path))
info_dict = load_info(str(info_path)) info_dict = load_info(str(info_path))
self.graph_lists = graphs self.graph_lists = graphs
self.graph_labels = label_dict['labels'] self.graph_labels = label_dict["labels"]
self.max_num_node = info_dict['max_num_node'] self.max_num_node = info_dict["max_num_node"]
self.num_labels = info_dict['num_labels'] self.num_labels = info_dict["num_labels"]
def has_cache(self): def has_cache(self):
graph_path = os.path.join(self.save_path, 'legacy_tu_{}_{}.bin'.format(self.name, self.hash)) graph_path = os.path.join(
info_path = os.path.join(self.save_path, 'legacy_tu_{}_{}.pkl'.format(self.name, self.hash)) self.save_path, "legacy_tu_{}_{}.bin".format(self.name, self.hash)
)
info_path = os.path.join(
self.save_path, "legacy_tu_{}_{}.pkl".format(self.name, self.hash)
)
if os.path.exists(graph_path) and os.path.exists(info_path): if os.path.exists(graph_path) and os.path.exists(info_path):
return True return True
return False return False
...@@ -226,8 +278,9 @@ class LegacyTUDataset(DGLBuiltinDataset): ...@@ -226,8 +278,9 @@ class LegacyTUDataset(DGLBuiltinDataset):
return len(self.graph_lists) return len(self.graph_lists)
def _file_path(self, category): def _file_path(self, category):
return os.path.join(self.raw_path, self.name, return os.path.join(
"{}_{}.txt".format(self.name, category)) self.raw_path, self.name, "{}_{}.txt".format(self.name, category)
)
@staticmethod @staticmethod
def _idx_from_zero(idx_tensor): def _idx_from_zero(idx_tensor):
...@@ -242,14 +295,17 @@ class LegacyTUDataset(DGLBuiltinDataset): ...@@ -242,14 +295,17 @@ class LegacyTUDataset(DGLBuiltinDataset):
return one_hot_tensor return one_hot_tensor
def statistics(self): def statistics(self):
return self.graph_lists[0].ndata['feat'].shape[1],\ return (
self.num_labels,\ self.graph_lists[0].ndata["feat"].shape[1],
self.max_num_node self.num_labels,
self.max_num_node,
)
@property @property
def num_classes(self): def num_classes(self):
return int(self.num_labels) return int(self.num_labels)
class TUDataset(DGLBuiltinDataset): class TUDataset(DGLBuiltinDataset):
r""" r"""
TUDataset contains many graph kernel datasets for graph classification. TUDataset contains many graph kernel datasets for graph classification.
...@@ -322,25 +378,46 @@ class TUDataset(DGLBuiltinDataset): ...@@ -322,25 +378,46 @@ class TUDataset(DGLBuiltinDataset):
_url = r"https://www.chrsmrrs.com/graphkerneldatasets/{}.zip" _url = r"https://www.chrsmrrs.com/graphkerneldatasets/{}.zip"
def __init__(self, name, raw_dir=None, force_reload=False, verbose=False, transform=None): def __init__(
self,
name,
raw_dir=None,
force_reload=False,
verbose=False,
transform=None,
):
url = self._url.format(name) url = self._url.format(name)
super(TUDataset, self).__init__(name=name, url=url, super(TUDataset, self).__init__(
raw_dir=raw_dir, force_reload=force_reload, name=name,
verbose=verbose, transform=transform) url=url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
def process(self): def process(self):
DS_edge_list = self._idx_from_zero( DS_edge_list = self._idx_from_zero(
loadtxt(self._file_path("A"), delimiter=",").astype(int)) loadtxt(self._file_path("A"), delimiter=",").astype(int)
)
DS_indicator = self._idx_from_zero( DS_indicator = self._idx_from_zero(
loadtxt(self._file_path("graph_indicator"), delimiter=",").astype(int)) loadtxt(self._file_path("graph_indicator"), delimiter=",").astype(
int
)
)
if os.path.exists(self._file_path("graph_labels")): if os.path.exists(self._file_path("graph_labels")):
DS_graph_labels = self._idx_reset( DS_graph_labels = self._idx_reset(
loadtxt(self._file_path("graph_labels"), delimiter=",").astype(int)) loadtxt(self._file_path("graph_labels"), delimiter=",").astype(
int
)
)
self.num_labels = int(max(DS_graph_labels) + 1) self.num_labels = int(max(DS_graph_labels) + 1)
self.graph_labels = F.tensor(DS_graph_labels) self.graph_labels = F.tensor(DS_graph_labels)
elif os.path.exists(self._file_path("graph_attributes")): elif os.path.exists(self._file_path("graph_attributes")):
DS_graph_labels = loadtxt(self._file_path("graph_attributes"), delimiter=",").astype(float) DS_graph_labels = loadtxt(
self._file_path("graph_attributes"), delimiter=","
).astype(float)
self.num_labels = None self.num_labels = None
self.graph_labels = F.tensor(DS_graph_labels) self.graph_labels = F.tensor(DS_graph_labels)
else: else:
...@@ -358,19 +435,17 @@ class TUDataset(DGLBuiltinDataset): ...@@ -358,19 +435,17 @@ class TUDataset(DGLBuiltinDataset):
if len(node_idx[0]) > self.max_num_node: if len(node_idx[0]) > self.max_num_node:
self.max_num_node = len(node_idx[0]) self.max_num_node = len(node_idx[0])
self.attr_dict = { self.attr_dict = {
'node_labels': ('ndata', 'node_labels'), "node_labels": ("ndata", "node_labels"),
'node_attributes': ('ndata', 'node_attr'), "node_attributes": ("ndata", "node_attr"),
'edge_labels': ('edata', 'edge_labels'), "edge_labels": ("edata", "edge_labels"),
'edge_attributes': ('edata', 'edge_attr'), "edge_attributes": ("edata", "edge_attr"),
} }
for filename, field_name in self.attr_dict.items(): for filename, field_name in self.attr_dict.items():
try: try:
data = loadtxt(self._file_path(filename), data = loadtxt(self._file_path(filename), delimiter=",")
delimiter=',') if "label" in filename:
if 'label' in filename:
data = F.tensor(self._idx_from_zero(data)) data = F.tensor(self._idx_from_zero(data))
else: else:
data = F.tensor(data) data = F.tensor(data)
...@@ -381,28 +456,30 @@ class TUDataset(DGLBuiltinDataset): ...@@ -381,28 +456,30 @@ class TUDataset(DGLBuiltinDataset):
self.graph_lists = [g.subgraph(node_idx) for node_idx in node_idx_list] self.graph_lists = [g.subgraph(node_idx) for node_idx in node_idx_list]
def save(self): def save(self):
graph_path = os.path.join(self.save_path, 'tu_{}.bin'.format(self.name)) graph_path = os.path.join(self.save_path, "tu_{}.bin".format(self.name))
info_path = os.path.join(self.save_path, 'tu_{}.pkl'.format(self.name)) info_path = os.path.join(self.save_path, "tu_{}.pkl".format(self.name))
label_dict = {'labels': self.graph_labels} label_dict = {"labels": self.graph_labels}
info_dict = {'max_num_node': self.max_num_node, info_dict = {
'num_labels': self.num_labels} "max_num_node": self.max_num_node,
"num_labels": self.num_labels,
}
save_graphs(str(graph_path), self.graph_lists, label_dict) save_graphs(str(graph_path), self.graph_lists, label_dict)
save_info(str(info_path), info_dict) save_info(str(info_path), info_dict)
def load(self): def load(self):
graph_path = os.path.join(self.save_path, 'tu_{}.bin'.format(self.name)) graph_path = os.path.join(self.save_path, "tu_{}.bin".format(self.name))
info_path = os.path.join(self.save_path, 'tu_{}.pkl'.format(self.name)) info_path = os.path.join(self.save_path, "tu_{}.pkl".format(self.name))
graphs, label_dict = load_graphs(str(graph_path)) graphs, label_dict = load_graphs(str(graph_path))
info_dict = load_info(str(info_path)) info_dict = load_info(str(info_path))
self.graph_lists = graphs self.graph_lists = graphs
self.graph_labels = label_dict['labels'] self.graph_labels = label_dict["labels"]
self.max_num_node = info_dict['max_num_node'] self.max_num_node = info_dict["max_num_node"]
self.num_labels = info_dict['num_labels'] self.num_labels = info_dict["num_labels"]
def has_cache(self): def has_cache(self):
graph_path = os.path.join(self.save_path, 'tu_{}.bin'.format(self.name)) graph_path = os.path.join(self.save_path, "tu_{}.bin".format(self.name))
info_path = os.path.join(self.save_path, 'tu_{}.pkl'.format(self.name)) info_path = os.path.join(self.save_path, "tu_{}.pkl".format(self.name))
if os.path.exists(graph_path) and os.path.exists(info_path): if os.path.exists(graph_path) and os.path.exists(info_path):
return True return True
return False return False
...@@ -431,8 +508,9 @@ class TUDataset(DGLBuiltinDataset): ...@@ -431,8 +508,9 @@ class TUDataset(DGLBuiltinDataset):
return len(self.graph_lists) return len(self.graph_lists)
def _file_path(self, category): def _file_path(self, category):
return os.path.join(self.raw_path, self.name, return os.path.join(
"{}_{}.txt".format(self.name, category)) self.raw_path, self.name, "{}_{}.txt".format(self.name, category)
)
@staticmethod @staticmethod
def _idx_from_zero(idx_tensor): def _idx_from_zero(idx_tensor):
...@@ -447,9 +525,11 @@ class TUDataset(DGLBuiltinDataset): ...@@ -447,9 +525,11 @@ class TUDataset(DGLBuiltinDataset):
return new_idx_tensor return new_idx_tensor
def statistics(self): def statistics(self):
return self.graph_lists[0].ndata['feat'].shape[1], \ return (
self.num_labels, \ self.graph_lists[0].ndata["feat"].shape[1],
self.max_num_node self.num_labels,
self.max_num_node,
)
@property @property
def num_classes(self): def num_classes(self):
......
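Finally, a minimal usage sketch for the two TU classes (assumes a working DGL install; "MUTAG" is one of the graph-kernel datasets the URL template serves). Note that only LegacyTUDataset guarantees a 'feat' node feature, so its statistics() helper is the safe one to call:

from dgl.data import LegacyTUDataset, TUDataset

dataset = TUDataset(name="MUTAG")
g, label = dataset[0]                  # one graph and its class label
print(len(dataset), dataset.num_classes)

legacy = LegacyTUDataset(name="MUTAG")
feat_dim, num_classes, max_nodes = legacy.statistics()
print(feat_dim, num_classes, max_nodes)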