Unverified Commit a208e886 authored by Hongzhi (Steve), Chen and committed by GitHub

[Misc] Black auto fix. (#4680)



* [Misc] Black auto fix.

* fix pylint disable
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 29434e65
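
The commit message does not record the exact tool invocation. The changes below are consistent with running Black (double-quoted strings, wrapped call arguments) together with isort-style alphabetical import grouping over python/dgl/data. A minimal sketch of applying both tools programmatically; the target path, the 79-character line length, and the isort step are assumptions for illustration, not taken from this commit:

# Hypothetical re-run of the kind of auto-formatting shown in this diff.
import pathlib

import black
import isort

path = pathlib.Path("python/dgl/data/gindt.py")  # illustrative target file
src = path.read_text()
src = isort.code(src)  # group and alphabetize imports
src = black.format_str(src, mode=black.Mode(line_length=79))  # Black layout and quote rules
path.write_text(src)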
...@@ -6,14 +6,22 @@ https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip ...@@ -6,14 +6,22 @@ https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip
""" """
import os import os
import numpy as np import numpy as np
from .. import backend as F from .. import backend as F
from .dgl_dataset import DGLBuiltinDataset
from .utils import loadtxt, save_graphs, load_graphs, save_info, load_info, download, extract_archive
from ..utils import retry_method_with_fix
from ..convert import graph as dgl_graph from ..convert import graph as dgl_graph
from ..utils import retry_method_with_fix
from .dgl_dataset import DGLBuiltinDataset
from .utils import (
download,
extract_archive,
load_graphs,
load_info,
loadtxt,
save_graphs,
save_info,
)
class GINDataset(DGLBuiltinDataset): class GINDataset(DGLBuiltinDataset):
...@@ -81,12 +89,20 @@ class GINDataset(DGLBuiltinDataset): ...@@ -81,12 +89,20 @@ class GINDataset(DGLBuiltinDataset):
edata_schemes={}) edata_schemes={})
""" """
def __init__(self, name, self_loop, degree_as_nlabel=False, def __init__(
raw_dir=None, force_reload=False, verbose=False, transform=None): self,
name,
self_loop,
degree_as_nlabel=False,
raw_dir=None,
force_reload=False,
verbose=False,
transform=None,
):
self._name = name # MUTAG self._name = name # MUTAG
gin_url = 'https://raw.githubusercontent.com/weihua916/powerful-gnns/master/dataset.zip' gin_url = "https://raw.githubusercontent.com/weihua916/powerful-gnns/master/dataset.zip"
self.ds_name = 'nig' self.ds_name = "nig"
self.self_loop = self_loop self.self_loop = self_loop
self.graphs = [] self.graphs = []
...@@ -114,18 +130,23 @@ class GINDataset(DGLBuiltinDataset): ...@@ -114,18 +130,23 @@ class GINDataset(DGLBuiltinDataset):
self.nattrs_flag = False self.nattrs_flag = False
self.nlabels_flag = False self.nlabels_flag = False
super(GINDataset, self).__init__(name=name, url=gin_url, hash_key=(name, self_loop, degree_as_nlabel), super(GINDataset, self).__init__(
raw_dir=raw_dir, force_reload=force_reload, name=name,
verbose=verbose, transform=transform) url=gin_url,
hash_key=(name, self_loop, degree_as_nlabel),
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
@property @property
def raw_path(self): def raw_path(self):
return os.path.join(self.raw_dir, 'GINDataset') return os.path.join(self.raw_dir, "GINDataset")
def download(self): def download(self):
r""" Automatically download data and extract it. r"""Automatically download data and extract it."""
""" zip_file_path = os.path.join(self.raw_dir, "GINDataset.zip")
zip_file_path = os.path.join(self.raw_dir, 'GINDataset.zip')
download(self.url, path=zip_file_path) download(self.url, path=zip_file_path)
extract_archive(zip_file_path, self.raw_path) extract_archive(zip_file_path, self.raw_path)
...@@ -153,21 +174,26 @@ class GINDataset(DGLBuiltinDataset): ...@@ -153,21 +174,26 @@ class GINDataset(DGLBuiltinDataset):
return g, self.labels[idx] return g, self.labels[idx]
def _file_path(self): def _file_path(self):
return os.path.join(self.raw_dir, "GINDataset", 'dataset', self.name, "{}.txt".format(self.name)) return os.path.join(
self.raw_dir,
"GINDataset",
"dataset",
self.name,
"{}.txt".format(self.name),
)
def process(self): def process(self):
""" Loads input dataset from dataset/NAME/NAME.txt file """Loads input dataset from dataset/NAME/NAME.txt file"""
"""
if self.verbose: if self.verbose:
print('loading data...') print("loading data...")
self.file = self._file_path() self.file = self._file_path()
with open(self.file, 'r') as f: with open(self.file, "r") as f:
# line_1 == N, total number of graphs # line_1 == N, total number of graphs
self.N = int(f.readline().strip()) self.N = int(f.readline().strip())
for i in range(self.N): for i in range(self.N):
if (i + 1) % 10 == 0 and self.verbose is True: if (i + 1) % 10 == 0 and self.verbose is True:
print('processing graph {}...'.format(i + 1)) print("processing graph {}...".format(i + 1))
grow = f.readline().strip().split() grow = f.readline().strip().split()
# line_2 == [n_nodes, l] is equal to # line_2 == [n_nodes, l] is equal to
...@@ -201,7 +227,7 @@ class GINDataset(DGLBuiltinDataset): ...@@ -201,7 +227,7 @@ class GINDataset(DGLBuiltinDataset):
nattr = [float(w) for w in nrow[tmp:]] nattr = [float(w) for w in nrow[tmp:]]
nattrs.append(nattr) nattrs.append(nattr)
else: else:
raise Exception('edge number is incorrect!') raise Exception("edge number is incorrect!")
# relabel nodes if it has labels # relabel nodes if it has labels
# if it doesn't have node labels, then every nrow[0]==0 # if it doesn't have node labels, then every nrow[0]==0
...@@ -221,17 +247,18 @@ class GINDataset(DGLBuiltinDataset): ...@@ -221,17 +247,18 @@ class GINDataset(DGLBuiltinDataset):
if (j + 1) % 10 == 0 and self.verbose is True: if (j + 1) % 10 == 0 and self.verbose is True:
print( print(
'processing node {} of graph {}...'.format( "processing node {} of graph {}...".format(
j + 1, i + 1)) j + 1, i + 1
print('this node has {} edgs.'.format( )
nrow[1])) )
print("this node has {} edgs.".format(nrow[1]))
if nattrs != []: if nattrs != []:
nattrs = np.stack(nattrs) nattrs = np.stack(nattrs)
g.ndata['attr'] = F.tensor(nattrs, F.float32) g.ndata["attr"] = F.tensor(nattrs, F.float32)
self.nattrs_flag = True self.nattrs_flag = True
g.ndata['label'] = F.tensor(nlabels) g.ndata["label"] = F.tensor(nlabels)
if len(self.nlabel_dict) > 1: if len(self.nlabel_dict) > 1:
self.nlabels_flag = True self.nlabels_flag = True
...@@ -247,51 +274,59 @@ class GINDataset(DGLBuiltinDataset): ...@@ -247,51 +274,59 @@ class GINDataset(DGLBuiltinDataset):
# if no attr # if no attr
if not self.nattrs_flag: if not self.nattrs_flag:
if self.verbose: if self.verbose:
print('there are no node features in this dataset!') print("there are no node features in this dataset!")
# generate node attr by node degree # generate node attr by node degree
if self.degree_as_nlabel: if self.degree_as_nlabel:
if self.verbose: if self.verbose:
print('generate node features by node degree...') print("generate node features by node degree...")
for g in self.graphs: for g in self.graphs:
# actually this label shouldn't be updated # actually this label shouldn't be updated
# in case users want to keep it # in case users want to keep it
# but usually no features means no labels, fine. # but usually no features means no labels, fine.
g.ndata['label'] = g.in_degrees() g.ndata["label"] = g.in_degrees()
# extracting unique node labels # extracting unique node labels
# in case the labels/degrees are not continuous number # in case the labels/degrees are not continuous number
nlabel_set = set([]) nlabel_set = set([])
for g in self.graphs: for g in self.graphs:
nlabel_set = nlabel_set.union( nlabel_set = nlabel_set.union(
set([F.as_scalar(nl) for nl in g.ndata['label']])) set([F.as_scalar(nl) for nl in g.ndata["label"]])
)
nlabel_set = list(nlabel_set) nlabel_set = list(nlabel_set)
is_label_valid = all([label in self.nlabel_dict for label in nlabel_set]) is_label_valid = all(
if is_label_valid and len(nlabel_set) == np.max(nlabel_set) + 1 and np.min(nlabel_set) == 0: [label in self.nlabel_dict for label in nlabel_set]
)
if (
is_label_valid
and len(nlabel_set) == np.max(nlabel_set) + 1
and np.min(nlabel_set) == 0
):
# Note this is different from the author's implementation. In weihua916's implementation, # Note this is different from the author's implementation. In weihua916's implementation,
# the labels are relabeled anyway. But here we didn't relabel it if the labels are contiguous # the labels are relabeled anyway. But here we didn't relabel it if the labels are contiguous
# to make it consistent with the original dataset # to make it consistent with the original dataset
label2idx = self.nlabel_dict label2idx = self.nlabel_dict
else: else:
label2idx = { label2idx = {nlabel_set[i]: i for i in range(len(nlabel_set))}
nlabel_set[i]: i
for i in range(len(nlabel_set))
}
# generate node attr by node label # generate node attr by node label
for g in self.graphs: for g in self.graphs:
attr = np.zeros(( attr = np.zeros((g.number_of_nodes(), len(label2idx)))
g.number_of_nodes(), len(label2idx))) attr[
attr[range(g.number_of_nodes()), [label2idx[nl] range(g.number_of_nodes()),
for nl in F.asnumpy(g.ndata['label']).tolist()]] = 1 [
g.ndata['attr'] = F.tensor(attr, F.float32) label2idx[nl]
for nl in F.asnumpy(g.ndata["label"]).tolist()
],
] = 1
g.ndata["attr"] = F.tensor(attr, F.float32)
# after load, get the #classes and #dim # after load, get the #classes and #dim
self.gclasses = len(self.glabel_dict) self.gclasses = len(self.glabel_dict)
self.nclasses = len(self.nlabel_dict) self.nclasses = len(self.nlabel_dict)
self.eclasses = len(self.elabel_dict) self.eclasses = len(self.elabel_dict)
self.dim_nfeats = len(self.graphs[0].ndata['attr'][0]) self.dim_nfeats = len(self.graphs[0].ndata["attr"][0])
if self.verbose: if self.verbose:
print('Done.') print("Done.")
print( print(
""" """
-------- Data Statistics --------' -------- Data Statistics --------'
...@@ -306,64 +341,83 @@ class GINDataset(DGLBuiltinDataset): ...@@ -306,64 +341,83 @@ class GINDataset(DGLBuiltinDataset):
Avg. of #Edges: %.2f Avg. of #Edges: %.2f
Graph Relabeled: %s Graph Relabeled: %s
Node Relabeled: %s Node Relabeled: %s
Degree Relabeled(If degree_as_nlabel=True): %s \n """ % ( Degree Relabeled(If degree_as_nlabel=True): %s \n """
self.N, self.gclasses, self.n, self.nclasses, % (
self.dim_nfeats, self.m, self.eclasses, self.N,
self.n / self.N, self.m / self.N, self.glabel_dict, self.gclasses,
self.nlabel_dict, self.ndegree_dict)) self.n,
self.nclasses,
self.dim_nfeats,
self.m,
self.eclasses,
self.n / self.N,
self.m / self.N,
self.glabel_dict,
self.nlabel_dict,
self.ndegree_dict,
)
)
def save(self): def save(self):
graph_path = os.path.join( graph_path = os.path.join(
self.save_path, 'gin_{}_{}.bin'.format(self.name, self.hash)) self.save_path, "gin_{}_{}.bin".format(self.name, self.hash)
)
info_path = os.path.join( info_path = os.path.join(
self.save_path, 'gin_{}_{}.pkl'.format(self.name, self.hash)) self.save_path, "gin_{}_{}.pkl".format(self.name, self.hash)
label_dict = {'labels': self.labels} )
info_dict = {'N': self.N, label_dict = {"labels": self.labels}
'n': self.n, info_dict = {
'm': self.m, "N": self.N,
'self_loop': self.self_loop, "n": self.n,
'gclasses': self.gclasses, "m": self.m,
'nclasses': self.nclasses, "self_loop": self.self_loop,
'eclasses': self.eclasses, "gclasses": self.gclasses,
'dim_nfeats': self.dim_nfeats, "nclasses": self.nclasses,
'degree_as_nlabel': self.degree_as_nlabel, "eclasses": self.eclasses,
'glabel_dict': self.glabel_dict, "dim_nfeats": self.dim_nfeats,
'nlabel_dict': self.nlabel_dict, "degree_as_nlabel": self.degree_as_nlabel,
'elabel_dict': self.elabel_dict, "glabel_dict": self.glabel_dict,
'ndegree_dict': self.ndegree_dict} "nlabel_dict": self.nlabel_dict,
"elabel_dict": self.elabel_dict,
"ndegree_dict": self.ndegree_dict,
}
save_graphs(str(graph_path), self.graphs, label_dict) save_graphs(str(graph_path), self.graphs, label_dict)
save_info(str(info_path), info_dict) save_info(str(info_path), info_dict)
def load(self): def load(self):
graph_path = os.path.join( graph_path = os.path.join(
self.save_path, 'gin_{}_{}.bin'.format(self.name, self.hash)) self.save_path, "gin_{}_{}.bin".format(self.name, self.hash)
)
info_path = os.path.join( info_path = os.path.join(
self.save_path, 'gin_{}_{}.pkl'.format(self.name, self.hash)) self.save_path, "gin_{}_{}.pkl".format(self.name, self.hash)
)
graphs, label_dict = load_graphs(str(graph_path)) graphs, label_dict = load_graphs(str(graph_path))
info_dict = load_info(str(info_path)) info_dict = load_info(str(info_path))
self.graphs = graphs self.graphs = graphs
self.labels = label_dict['labels'] self.labels = label_dict["labels"]
self.N = info_dict['N'] self.N = info_dict["N"]
self.n = info_dict['n'] self.n = info_dict["n"]
self.m = info_dict['m'] self.m = info_dict["m"]
self.self_loop = info_dict['self_loop'] self.self_loop = info_dict["self_loop"]
self.gclasses = info_dict['gclasses'] self.gclasses = info_dict["gclasses"]
self.nclasses = info_dict['nclasses'] self.nclasses = info_dict["nclasses"]
self.eclasses = info_dict['eclasses'] self.eclasses = info_dict["eclasses"]
self.dim_nfeats = info_dict['dim_nfeats'] self.dim_nfeats = info_dict["dim_nfeats"]
self.glabel_dict = info_dict['glabel_dict'] self.glabel_dict = info_dict["glabel_dict"]
self.nlabel_dict = info_dict['nlabel_dict'] self.nlabel_dict = info_dict["nlabel_dict"]
self.elabel_dict = info_dict['elabel_dict'] self.elabel_dict = info_dict["elabel_dict"]
self.ndegree_dict = info_dict['ndegree_dict'] self.ndegree_dict = info_dict["ndegree_dict"]
self.degree_as_nlabel = info_dict['degree_as_nlabel'] self.degree_as_nlabel = info_dict["degree_as_nlabel"]
def has_cache(self): def has_cache(self):
graph_path = os.path.join( graph_path = os.path.join(
self.save_path, 'gin_{}_{}.bin'.format(self.name, self.hash)) self.save_path, "gin_{}_{}.bin".format(self.name, self.hash)
)
info_path = os.path.join( info_path = os.path.join(
self.save_path, 'gin_{}_{}.pkl'.format(self.name, self.hash)) self.save_path, "gin_{}_{}.pkl".format(self.name, self.hash)
)
if os.path.exists(graph_path) and os.path.exists(info_path): if os.path.exists(graph_path) and os.path.exists(info_path):
return True return True
return False return False
......
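
For context, the reformatting above does not change how the dataset is used. A minimal usage sketch of the class as it appears in this file; "MUTAG" is one supported name (it appears in the __init__ comment) and the arguments are illustrative:

from dgl.data import GINDataset

# Downloads and processes the graphs on first use, then reloads from the
# cache files written by save() above.
dataset = GINDataset(name="MUTAG", self_loop=True)
g, label = dataset[0]  # __getitem__ returns (graph, graph label)
print(dataset.gclasses, dataset.dim_nfeats, g.ndata["attr"].shape)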
"""For Graph Serialization""" """For Graph Serialization"""
from __future__ import absolute_import from __future__ import absolute_import
import os import os
from ..base import dgl_warning, DGLError
from ..heterograph import DGLHeteroGraph
from .._ffi.object import ObjectBase, register_object
from .._ffi.function import _init_api
from .. import backend as F from .. import backend as F
from .._ffi.function import _init_api
from .._ffi.object import ObjectBase, register_object
from ..base import DGLError, dgl_warning
from ..heterograph import DGLHeteroGraph
from .heterograph_serialize import save_heterographs from .heterograph_serialize import save_heterographs
_init_api("dgl.data.graph_serialize") _init_api("dgl.data.graph_serialize")
__all__ = ['save_graphs', "load_graphs", "load_labels"] __all__ = ["save_graphs", "load_graphs", "load_labels"]
@register_object("graph_serialize.StorageMetaData") @register_object("graph_serialize.StorageMetaData")
...@@ -26,15 +28,18 @@ class StorageMetaData(ObjectBase): ...@@ -26,15 +28,18 @@ class StorageMetaData(ObjectBase):
def is_local_path(filepath): def is_local_path(filepath):
return not (filepath.startswith("hdfs://") or return not (
filepath.startswith("viewfs://") or filepath.startswith("hdfs://")
filepath.startswith("s3://")) or filepath.startswith("viewfs://")
or filepath.startswith("s3://")
)
def check_local_file_exists(filename): def check_local_file_exists(filename):
if is_local_path(filename) and not os.path.exists(filename): if is_local_path(filename) and not os.path.exists(filename):
raise DGLError("File {} does not exist.".format(filename)) raise DGLError("File {} does not exist.".format(filename))
@register_object("graph_serialize.GraphData") @register_object("graph_serialize.GraphData")
class GraphData(ObjectBase): class GraphData(ObjectBase):
"""GraphData Object""" """GraphData Object"""
...@@ -43,7 +48,9 @@ class GraphData(ObjectBase): ...@@ -43,7 +48,9 @@ class GraphData(ObjectBase):
def create(g): def create(g):
"""Create GraphData""" """Create GraphData"""
# TODO(zihao): support serialize batched graph in the future. # TODO(zihao): support serialize batched graph in the future.
assert g.batch_size == 1, "Batched DGLGraph is not supported for serialization" assert (
g.batch_size == 1
), "Batched DGLGraph is not supported for serialization"
ghandle = g._graph ghandle = g._graph
if len(g.ndata) != 0: if len(g.ndata) != 0:
node_tensors = dict() node_tensors = dict()
...@@ -64,8 +71,8 @@ class GraphData(ObjectBase): ...@@ -64,8 +71,8 @@ class GraphData(ObjectBase):
def get_graph(self): def get_graph(self):
"""Get DGLHeteroGraph from GraphData""" """Get DGLHeteroGraph from GraphData"""
ghandle = _CAPI_GDataGraphHandle(self) ghandle = _CAPI_GDataGraphHandle(self)
hgi =_CAPI_DGLAsHeteroGraph(ghandle) hgi = _CAPI_DGLAsHeteroGraph(ghandle)
g = DGLHeteroGraph(hgi, ['_U'], ['_E']) g = DGLHeteroGraph(hgi, ["_U"], ["_E"])
node_tensors_items = _CAPI_GDataNodeTensors(self).items() node_tensors_items = _CAPI_GDataNodeTensors(self).items()
edge_tensors_items = _CAPI_GDataEdgeTensors(self).items() edge_tensors_items = _CAPI_GDataEdgeTensors(self).items()
for k, v in node_tensors_items: for k, v in node_tensors_items:
...@@ -120,18 +127,22 @@ def save_graphs(filename, g_list, labels=None): ...@@ -120,18 +127,22 @@ def save_graphs(filename, g_list, labels=None):
# if it is local file, do some sanity check # if it is local file, do some sanity check
if is_local_path(filename): if is_local_path(filename):
if os.path.isdir(filename): if os.path.isdir(filename):
raise DGLError("Filename {} is an existing directory.".format(filename)) raise DGLError(
"Filename {} is an existing directory.".format(filename)
)
f_path = os.path.dirname(filename) f_path = os.path.dirname(filename)
if f_path and not os.path.exists(f_path): if f_path and not os.path.exists(f_path):
os.makedirs(f_path) os.makedirs(f_path)
g_sample = g_list[0] if isinstance(g_list, list) else g_list g_sample = g_list[0] if isinstance(g_list, list) else g_list
if type(g_sample) == DGLHeteroGraph: # Doesn't support DGLHeteroGraph's derived class if (
type(g_sample) == DGLHeteroGraph
): # Doesn't support DGLHeteroGraph's derived class
save_heterographs(filename, g_list, labels) save_heterographs(filename, g_list, labels)
else: else:
raise DGLError( raise DGLError(
"Invalid argument g_list. Must be a DGLGraph or a list of DGLGraphs.") "Invalid argument g_list. Must be a DGLGraph or a list of DGLGraphs."
)
def load_graphs(filename, idx_list=None): def load_graphs(filename, idx_list=None):
...@@ -176,7 +187,8 @@ def load_graphs(filename, idx_list=None): ...@@ -176,7 +187,8 @@ def load_graphs(filename, idx_list=None):
if version == 1: if version == 1:
dgl_warning( dgl_warning(
"You are loading a graph file saved by old version of dgl. \ "You are loading a graph file saved by old version of dgl. \
Please consider saving it again with the current format.") Please consider saving it again with the current format."
)
return load_graph_v1(filename, idx_list) return load_graph_v1(filename, idx_list)
elif version == 2: elif version == 2:
return load_graph_v2(filename, idx_list) return load_graph_v2(filename, idx_list)
...@@ -195,7 +207,7 @@ def load_graph_v2(filename, idx_list=None): ...@@ -195,7 +207,7 @@ def load_graph_v2(filename, idx_list=None):
def load_graph_v1(filename, idx_list=None): def load_graph_v1(filename, idx_list=None):
""""Internal functions for loading DGLGraphs (V0).""" """ "Internal functions for loading DGLGraphs (V0)."""
if idx_list is None: if idx_list is None:
idx_list = [] idx_list = []
assert isinstance(idx_list, list) assert isinstance(idx_list, list)
...@@ -206,6 +218,7 @@ def load_graph_v1(filename, idx_list=None): ...@@ -206,6 +218,7 @@ def load_graph_v1(filename, idx_list=None):
return [gdata.get_graph() for gdata in metadata.graph_data], label_dict return [gdata.get_graph() for gdata in metadata.graph_data], label_dict
def load_labels(filename): def load_labels(filename):
""" """
Load label dict from file Load label dict from file
......
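
The save_graphs/load_graphs pair reformatted above is the public entry point used by the dataset classes in this commit. A minimal round-trip sketch, assuming the PyTorch backend; the temporary path is illustrative:

import torch

import dgl
from dgl.data.utils import load_graphs, save_graphs

g = dgl.graph(([0, 1, 2], [1, 2, 3]))
g.ndata["feat"] = torch.randn(4, 3)

# labels is an optional dict of name -> tensor, as in GINDataset.save() above.
save_graphs("/tmp/demo_graphs.bin", [g], {"glabel": torch.tensor([1])})
graphs, label_dict = load_graphs("/tmp/demo_graphs.bin")
print(len(graphs), label_dict["glabel"])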
"""For HeteroGraph Serialization""" """For HeteroGraph Serialization"""
from __future__ import absolute_import from __future__ import absolute_import
from ..heterograph import DGLHeteroGraph
from ..frame import Frame
from .._ffi.object import ObjectBase, register_object
from .._ffi.function import _init_api
from .. import backend as F from .. import backend as F
from .._ffi.function import _init_api
from .._ffi.object import ObjectBase, register_object
from ..container import convert_to_strmap from ..container import convert_to_strmap
from ..frame import Frame
from ..heterograph import DGLHeteroGraph
_init_api("dgl.data.heterograph_serialize") _init_api("dgl.data.heterograph_serialize")
...@@ -24,9 +25,14 @@ def save_heterographs(filename, g_list, labels): ...@@ -24,9 +25,14 @@ def save_heterographs(filename, g_list, labels):
labels = {} labels = {}
if isinstance(g_list, DGLHeteroGraph): if isinstance(g_list, DGLHeteroGraph):
g_list = [g_list] g_list = [g_list]
assert all([type(g) == DGLHeteroGraph for g in g_list]), "Invalid DGLHeteroGraph in g_list argument" assert all(
[type(g) == DGLHeteroGraph for g in g_list]
), "Invalid DGLHeteroGraph in g_list argument"
gdata_list = [HeteroGraphData.create(g) for g in g_list] gdata_list = [HeteroGraphData.create(g) for g in g_list]
_CAPI_SaveHeteroGraphData(filename, gdata_list, tensor_dict_to_ndarray_dict(labels)) _CAPI_SaveHeteroGraphData(
filename, gdata_list, tensor_dict_to_ndarray_dict(labels)
)
@register_object("heterograph_serialize.HeteroGraphData") @register_object("heterograph_serialize.HeteroGraphData")
class HeteroGraphData(ObjectBase): class HeteroGraphData(ObjectBase):
...@@ -40,7 +46,9 @@ class HeteroGraphData(ObjectBase): ...@@ -40,7 +46,9 @@ class HeteroGraphData(ObjectBase):
edata_list.append(tensor_dict_to_ndarray_dict(g.edges[etype].data)) edata_list.append(tensor_dict_to_ndarray_dict(g.edges[etype].data))
for ntype in g.ntypes: for ntype in g.ntypes:
ndata_list.append(tensor_dict_to_ndarray_dict(g.nodes[ntype].data)) ndata_list.append(tensor_dict_to_ndarray_dict(g.nodes[ntype].data))
return _CAPI_MakeHeteroGraphData(g._graph, ndata_list, edata_list, g.ntypes, g.etypes) return _CAPI_MakeHeteroGraphData(
g._graph, ndata_list, edata_list, g.ntypes, g.etypes
)
def get_graph(self): def get_graph(self):
ntensor_list = list(_CAPI_GetNDataFromHeteroGraphData(self)) ntensor_list = list(_CAPI_GetNDataFromHeteroGraphData(self))
...@@ -51,11 +59,17 @@ class HeteroGraphData(ObjectBase): ...@@ -51,11 +59,17 @@ class HeteroGraphData(ObjectBase):
nframes = [] nframes = []
eframes = [] eframes = []
for ntid, ntensor in enumerate(ntensor_list): for ntid, ntensor in enumerate(ntensor_list):
ndict = {ntensor[i]: F.zerocopy_from_dgl_ndarray(ntensor[i+1]) for i in range(0, len(ntensor), 2)} ndict = {
ntensor[i]: F.zerocopy_from_dgl_ndarray(ntensor[i + 1])
for i in range(0, len(ntensor), 2)
}
nframes.append(Frame(ndict, num_rows=gidx.number_of_nodes(ntid))) nframes.append(Frame(ndict, num_rows=gidx.number_of_nodes(ntid)))
for etid, etensor in enumerate(etensor_list): for etid, etensor in enumerate(etensor_list):
edict = {etensor[i]: F.zerocopy_from_dgl_ndarray(etensor[i+1]) for i in range(0, len(etensor), 2)} edict = {
etensor[i]: F.zerocopy_from_dgl_ndarray(etensor[i + 1])
for i in range(0, len(etensor), 2)
}
eframes.append(Frame(edict, num_rows=gidx.number_of_edges(etid))) eframes.append(Frame(edict, num_rows=gidx.number_of_edges(etid)))
return DGLHeteroGraph(gidx, ntype_names, etype_names, nframes, eframes) return DGLHeteroGraph(gidx, ntype_names, etype_names, nframes, eframes)
"""ICEWS18 dataset for temporal graph""" """ICEWS18 dataset for temporal graph"""
import numpy as np
import os import os
from .dgl_dataset import DGLBuiltinDataset import numpy as np
from .utils import loadtxt, _get_dgl_url, save_graphs, load_graphs
from ..convert import graph as dgl_graph
from .. import backend as F from .. import backend as F
from ..convert import graph as dgl_graph
from .dgl_dataset import DGLBuiltinDataset
from .utils import _get_dgl_url, load_graphs, loadtxt, save_graphs
class ICEWS18Dataset(DGLBuiltinDataset): class ICEWS18Dataset(DGLBuiltinDataset):
r""" ICEWS18 dataset for temporal graph r"""ICEWS18 dataset for temporal graph
Integrated Crisis Early Warning System (ICEWS18) Integrated Crisis Early Warning System (ICEWS18)
...@@ -65,21 +66,33 @@ class ICEWS18Dataset(DGLBuiltinDataset): ...@@ -65,21 +66,33 @@ class ICEWS18Dataset(DGLBuiltinDataset):
.... ....
>>> >>>
""" """
def __init__(self, mode='train', raw_dir=None, force_reload=False, verbose=False, transform=None):
def __init__(
self,
mode="train",
raw_dir=None,
force_reload=False,
verbose=False,
transform=None,
):
mode = mode.lower() mode = mode.lower()
assert mode in ['train', 'valid', 'test'], "Mode not valid" assert mode in ["train", "valid", "test"], "Mode not valid"
self.mode = mode self.mode = mode
_url = _get_dgl_url('dataset/icews18.zip') _url = _get_dgl_url("dataset/icews18.zip")
super(ICEWS18Dataset, self).__init__(name='ICEWS18', super(ICEWS18Dataset, self).__init__(
url=_url, name="ICEWS18",
raw_dir=raw_dir, url=_url,
force_reload=force_reload, raw_dir=raw_dir,
verbose=verbose, force_reload=force_reload,
transform=transform) verbose=verbose,
transform=transform,
)
def process(self): def process(self):
data = loadtxt(os.path.join(self.save_path, '{}.txt'.format(self.mode)), data = loadtxt(
delimiter='\t').astype(np.int64) os.path.join(self.save_path, "{}.txt".format(self.mode)),
delimiter="\t",
).astype(np.int64)
num_nodes = 23033 num_nodes = 23033
# The source code is not released, but the paper indicates there're # The source code is not released, but the paper indicates there're
# totally 137 samples. The cutoff below has exactly 137 samples. # totally 137 samples. The cutoff below has exactly 137 samples.
...@@ -92,23 +105,31 @@ class ICEWS18Dataset(DGLBuiltinDataset): ...@@ -92,23 +105,31 @@ class ICEWS18Dataset(DGLBuiltinDataset):
edges = data[row_mask][:, [0, 2]] edges = data[row_mask][:, [0, 2]]
rate = data[row_mask][:, 1] rate = data[row_mask][:, 1]
g = dgl_graph((edges[:, 0], edges[:, 1])) g = dgl_graph((edges[:, 0], edges[:, 1]))
g.edata['rel_type'] = F.tensor(rate.reshape(-1, 1), dtype=F.data_type_dict['int64']) g.edata["rel_type"] = F.tensor(
rate.reshape(-1, 1), dtype=F.data_type_dict["int64"]
)
self._graphs.append(g) self._graphs.append(g)
def has_cache(self): def has_cache(self):
graph_path = os.path.join(self.save_path, '{}_dgl_graph.bin'.format(self.mode)) graph_path = os.path.join(
self.save_path, "{}_dgl_graph.bin".format(self.mode)
)
return os.path.exists(graph_path) return os.path.exists(graph_path)
def save(self): def save(self):
graph_path = os.path.join(self.save_path, '{}_dgl_graph.bin'.format(self.mode)) graph_path = os.path.join(
self.save_path, "{}_dgl_graph.bin".format(self.mode)
)
save_graphs(graph_path, self._graphs) save_graphs(graph_path, self._graphs)
def load(self): def load(self):
graph_path = os.path.join(self.save_path, '{}_dgl_graph.bin'.format(self.mode)) graph_path = os.path.join(
self.save_path, "{}_dgl_graph.bin".format(self.mode)
)
self._graphs = load_graphs(graph_path)[0] self._graphs = load_graphs(graph_path)[0]
def __getitem__(self, idx): def __getitem__(self, idx):
r""" Get graph by index r"""Get graph by index
Parameters Parameters
---------- ----------
......
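
As with the other datasets in this commit, the ICEWS18 loader's behavior is unchanged by the reformatting. A minimal usage sketch; mode is one of "train", "valid", "test", as asserted in __init__:

from dgl.data import ICEWS18Dataset

valid_set = ICEWS18Dataset(mode="valid")
g = valid_set[0]  # one temporal snapshot graph
print(g.num_edges(), g.edata["rel_type"].shape)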
"""A mini synthetic dataset for graph classification benchmark.""" """A mini synthetic dataset for graph classification benchmark."""
import math, os import math
import os
import networkx as nx import networkx as nx
import numpy as np import numpy as np
from .dgl_dataset import DGLDataset
from .utils import save_graphs, load_graphs, makedirs
from .. import backend as F from .. import backend as F
from ..convert import from_networkx from ..convert import from_networkx
from ..transforms import add_self_loop from ..transforms import add_self_loop
from .dgl_dataset import DGLDataset
from .utils import load_graphs, makedirs, save_graphs
__all__ = ["MiniGCDataset"]
__all__ = ['MiniGCDataset']
class MiniGCDataset(DGLDataset): class MiniGCDataset(DGLDataset):
"""The synthetic graph classification dataset class. """The synthetic graph classification dataset class.
...@@ -78,17 +81,30 @@ class MiniGCDataset(DGLDataset): ...@@ -78,17 +81,30 @@ class MiniGCDataset(DGLDataset):
edata_schemes={}) edata_schemes={})
""" """
def __init__(self, num_graphs, min_num_v, max_num_v, seed=0, def __init__(
save_graph=True, force_reload=False, verbose=False, transform=None): self,
num_graphs,
min_num_v,
max_num_v,
seed=0,
save_graph=True,
force_reload=False,
verbose=False,
transform=None,
):
self.num_graphs = num_graphs self.num_graphs = num_graphs
self.min_num_v = min_num_v self.min_num_v = min_num_v
self.max_num_v = max_num_v self.max_num_v = max_num_v
self.seed = seed self.seed = seed
self.save_graph = save_graph self.save_graph = save_graph
super(MiniGCDataset, self).__init__(name="minigc", hash_key=(num_graphs, min_num_v, max_num_v, seed), super(MiniGCDataset, self).__init__(
force_reload=force_reload, name="minigc",
verbose=verbose, transform=transform) hash_key=(num_graphs, min_num_v, max_num_v, seed),
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
def process(self): def process(self):
self.graphs = [] self.graphs = []
...@@ -119,7 +135,9 @@ class MiniGCDataset(DGLDataset): ...@@ -119,7 +135,9 @@ class MiniGCDataset(DGLDataset):
return g, self.labels[idx] return g, self.labels[idx]
def has_cache(self): def has_cache(self):
graph_path = os.path.join(self.save_path, 'dgl_graph_{}.bin'.format(self.hash)) graph_path = os.path.join(
self.save_path, "dgl_graph_{}.bin".format(self.hash)
)
if os.path.exists(graph_path): if os.path.exists(graph_path):
return True return True
...@@ -128,13 +146,17 @@ class MiniGCDataset(DGLDataset): ...@@ -128,13 +146,17 @@ class MiniGCDataset(DGLDataset):
def save(self): def save(self):
"""save the graph list and the labels""" """save the graph list and the labels"""
if self.save_graph: if self.save_graph:
graph_path = os.path.join(self.save_path, 'dgl_graph_{}.bin'.format(self.hash)) graph_path = os.path.join(
save_graphs(str(graph_path), self.graphs, {'labels': self.labels}) self.save_path, "dgl_graph_{}.bin".format(self.hash)
)
save_graphs(str(graph_path), self.graphs, {"labels": self.labels})
def load(self): def load(self):
graphs, label_dict = load_graphs(os.path.join(self.save_path, 'dgl_graph_{}.bin'.format(self.hash))) graphs, label_dict = load_graphs(
os.path.join(self.save_path, "dgl_graph_{}.bin".format(self.hash))
)
self.graphs = graphs self.graphs = graphs
self.labels = label_dict['labels'] self.labels = label_dict["labels"]
@property @property
def num_classes(self): def num_classes(self):
...@@ -199,9 +221,11 @@ class MiniGCDataset(DGLDataset): ...@@ -199,9 +221,11 @@ class MiniGCDataset(DGLDataset):
def _gen_grid(self, n): def _gen_grid(self, n):
for _ in range(n): for _ in range(n):
num_v = np.random.randint(self.min_num_v, self.max_num_v) num_v = np.random.randint(self.min_num_v, self.max_num_v)
assert num_v >= 4, 'We require a grid graph to contain at least two ' \ assert num_v >= 4, (
'rows and two columns, thus 4 nodes, got {:d} ' \ "We require a grid graph to contain at least two "
'nodes'.format(num_v) "rows and two columns, thus 4 nodes, got {:d} "
"nodes".format(num_v)
)
n_rows = np.random.randint(2, num_v // 2) n_rows = np.random.randint(2, num_v // 2)
n_cols = num_v // n_rows n_cols = num_v // n_rows
g = nx.grid_graph([n_rows, n_cols]) g = nx.grid_graph([n_rows, n_cols])
......
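
A minimal usage sketch of the synthetic dataset above; the sizes are illustrative:

from dgl.data import MiniGCDataset

# 80 graphs with between 10 and 20 nodes each.
dataset = MiniGCDataset(80, 10, 20)
g, label = dataset[0]
print(len(dataset), dataset.num_classes)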
""" PPIDataset for inductive learning. """ """ PPIDataset for inductive learning. """
import json import json
import numpy as np import os
import networkx as nx import networkx as nx
import numpy as np
from networkx.readwrite import json_graph from networkx.readwrite import json_graph
import os
from .dgl_dataset import DGLBuiltinDataset
from .utils import _get_dgl_url, save_graphs, save_info, load_info, load_graphs
from .. import backend as F from .. import backend as F
from ..convert import from_networkx from ..convert import from_networkx
from .dgl_dataset import DGLBuiltinDataset
from .utils import _get_dgl_url, load_graphs, load_info, save_graphs, save_info
class PPIDataset(DGLBuiltinDataset): class PPIDataset(DGLBuiltinDataset):
r""" Protein-Protein Interaction dataset for inductive node classification r"""Protein-Protein Interaction dataset for inductive node classification
A toy Protein-Protein Interaction network dataset. The dataset contains A toy Protein-Protein Interaction network dataset. The dataset contains
24 graphs. The average number of nodes per graph is 2372. Each node has 24 graphs. The average number of nodes per graph is 2372. Each node has
...@@ -65,36 +66,55 @@ class PPIDataset(DGLBuiltinDataset): ...@@ -65,36 +66,55 @@ class PPIDataset(DGLBuiltinDataset):
.... # your code here .... # your code here
>>> >>>
""" """
def __init__(self, mode='train', raw_dir=None, force_reload=False,
verbose=False, transform=None): def __init__(
assert mode in ['train', 'valid', 'test'] self,
mode="train",
raw_dir=None,
force_reload=False,
verbose=False,
transform=None,
):
assert mode in ["train", "valid", "test"]
self.mode = mode self.mode = mode
_url = _get_dgl_url('dataset/ppi.zip') _url = _get_dgl_url("dataset/ppi.zip")
super(PPIDataset, self).__init__(name='ppi', super(PPIDataset, self).__init__(
url=_url, name="ppi",
raw_dir=raw_dir, url=_url,
force_reload=force_reload, raw_dir=raw_dir,
verbose=verbose, force_reload=force_reload,
transform=transform) verbose=verbose,
transform=transform,
)
def process(self): def process(self):
graph_file = os.path.join(self.save_path, '{}_graph.json'.format(self.mode)) graph_file = os.path.join(
label_file = os.path.join(self.save_path, '{}_labels.npy'.format(self.mode)) self.save_path, "{}_graph.json".format(self.mode)
feat_file = os.path.join(self.save_path, '{}_feats.npy'.format(self.mode)) )
graph_id_file = os.path.join(self.save_path, '{}_graph_id.npy'.format(self.mode)) label_file = os.path.join(
self.save_path, "{}_labels.npy".format(self.mode)
)
feat_file = os.path.join(
self.save_path, "{}_feats.npy".format(self.mode)
)
graph_id_file = os.path.join(
self.save_path, "{}_graph_id.npy".format(self.mode)
)
g_data = json.load(open(graph_file)) g_data = json.load(open(graph_file))
self._labels = np.load(label_file) self._labels = np.load(label_file)
self._feats = np.load(feat_file) self._feats = np.load(feat_file)
self.graph = from_networkx(nx.DiGraph(json_graph.node_link_graph(g_data))) self.graph = from_networkx(
nx.DiGraph(json_graph.node_link_graph(g_data))
)
graph_id = np.load(graph_id_file) graph_id = np.load(graph_id_file)
# lo, hi means the range of graph ids for different portion of the dataset, # lo, hi means the range of graph ids for different portion of the dataset,
# 20 graphs for training, 2 for validation and 2 for testing. # 20 graphs for training, 2 for validation and 2 for testing.
lo, hi = 1, 21 lo, hi = 1, 21
if self.mode == 'valid': if self.mode == "valid":
lo, hi = 21, 23 lo, hi = 21, 23
elif self.mode == 'test': elif self.mode == "test":
lo, hi = 23, 25 lo, hi = 23, 25
graph_masks = [] graph_masks = []
...@@ -103,34 +123,60 @@ class PPIDataset(DGLBuiltinDataset): ...@@ -103,34 +123,60 @@ class PPIDataset(DGLBuiltinDataset):
g_mask = np.where(graph_id == g_id)[0] g_mask = np.where(graph_id == g_id)[0]
graph_masks.append(g_mask) graph_masks.append(g_mask)
g = self.graph.subgraph(g_mask) g = self.graph.subgraph(g_mask)
g.ndata['feat'] = F.tensor(self._feats[g_mask], dtype=F.data_type_dict['float32']) g.ndata["feat"] = F.tensor(
g.ndata['label'] = F.tensor(self._labels[g_mask], dtype=F.data_type_dict['float32']) self._feats[g_mask], dtype=F.data_type_dict["float32"]
)
g.ndata["label"] = F.tensor(
self._labels[g_mask], dtype=F.data_type_dict["float32"]
)
self.graphs.append(g) self.graphs.append(g)
def has_cache(self): def has_cache(self):
graph_list_path = os.path.join(self.save_path, '{}_dgl_graph_list.bin'.format(self.mode)) graph_list_path = os.path.join(
g_path = os.path.join(self.save_path, '{}_dgl_graph.bin'.format(self.mode)) self.save_path, "{}_dgl_graph_list.bin".format(self.mode)
info_path = os.path.join(self.save_path, '{}_info.pkl'.format(self.mode)) )
return os.path.exists(graph_list_path) and os.path.exists(g_path) and os.path.exists(info_path) g_path = os.path.join(
self.save_path, "{}_dgl_graph.bin".format(self.mode)
)
info_path = os.path.join(
self.save_path, "{}_info.pkl".format(self.mode)
)
return (
os.path.exists(graph_list_path)
and os.path.exists(g_path)
and os.path.exists(info_path)
)
def save(self): def save(self):
graph_list_path = os.path.join(self.save_path, '{}_dgl_graph_list.bin'.format(self.mode)) graph_list_path = os.path.join(
g_path = os.path.join(self.save_path, '{}_dgl_graph.bin'.format(self.mode)) self.save_path, "{}_dgl_graph_list.bin".format(self.mode)
info_path = os.path.join(self.save_path, '{}_info.pkl'.format(self.mode)) )
g_path = os.path.join(
self.save_path, "{}_dgl_graph.bin".format(self.mode)
)
info_path = os.path.join(
self.save_path, "{}_info.pkl".format(self.mode)
)
save_graphs(graph_list_path, self.graphs) save_graphs(graph_list_path, self.graphs)
save_graphs(g_path, self.graph) save_graphs(g_path, self.graph)
save_info(info_path, {'labels': self._labels, 'feats': self._feats}) save_info(info_path, {"labels": self._labels, "feats": self._feats})
def load(self): def load(self):
graph_list_path = os.path.join(self.save_path, '{}_dgl_graph_list.bin'.format(self.mode)) graph_list_path = os.path.join(
g_path = os.path.join(self.save_path, '{}_dgl_graph.bin'.format(self.mode)) self.save_path, "{}_dgl_graph_list.bin".format(self.mode)
info_path = os.path.join(self.save_path, '{}_info.pkl'.format(self.mode)) )
g_path = os.path.join(
self.save_path, "{}_dgl_graph.bin".format(self.mode)
)
info_path = os.path.join(
self.save_path, "{}_info.pkl".format(self.mode)
)
self.graphs = load_graphs(graph_list_path)[0] self.graphs = load_graphs(graph_list_path)[0]
g, _ = load_graphs(g_path) g, _ = load_graphs(g_path)
self.graph = g[0] self.graph = g[0]
info = load_info(info_path) info = load_info(info_path)
self._labels = info['labels'] self._labels = info["labels"]
self._feats = info['feats'] self._feats = info["feats"]
@property @property
def num_labels(self): def num_labels(self):
...@@ -163,8 +209,7 @@ class PPIDataset(DGLBuiltinDataset): ...@@ -163,8 +209,7 @@ class PPIDataset(DGLBuiltinDataset):
class LegacyPPIDataset(PPIDataset): class LegacyPPIDataset(PPIDataset):
"""Legacy version of PPI Dataset """Legacy version of PPI Dataset"""
"""
def __getitem__(self, item): def __getitem__(self, item):
"""Get the item^th sample. """Get the item^th sample.
...@@ -183,4 +228,4 @@ class LegacyPPIDataset(PPIDataset): ...@@ -183,4 +228,4 @@ class LegacyPPIDataset(PPIDataset):
g = self.graphs[item] g = self.graphs[item]
else: else:
g = self._transform(self.graphs[item]) g = self._transform(self.graphs[item])
return g, g.ndata['feat'], g.ndata['label'] return g, g.ndata["feat"], g.ndata["label"]
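
A minimal usage sketch of the PPI loader as reformatted above; indexing PPIDataset returns a graph whose "feat" and "label" node features were filled in process(), while LegacyPPIDataset returns the (graph, features, labels) triple shown in its __getitem__:

from dgl.data import PPIDataset

train_set = PPIDataset(mode="train")
g = train_set[0]
print(g.num_nodes(), g.ndata["feat"].shape, g.ndata["label"].shape)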
...@@ -3,30 +3,39 @@ Datasets from "A Collection of Benchmark Datasets for ...@@ -3,30 +3,39 @@ Datasets from "A Collection of Benchmark Datasets for
Systematic Evaluations of Machine Learning on Systematic Evaluations of Machine Learning on
the Semantic Web" the Semantic Web"
""" """
import os
from collections import OrderedDict
import itertools
import abc import abc
import itertools
import os
import re import re
from collections import OrderedDict
import networkx as nx import networkx as nx
import numpy as np import numpy as np
import dgl import dgl
import dgl.backend as F import dgl.backend as F
from .dgl_dataset import DGLBuiltinDataset
from .utils import save_graphs, load_graphs, save_info, load_info, _get_dgl_url
from .utils import generate_mask_tensor, idx2mask
__all__ = ['AIFBDataset', 'MUTAGDataset', 'BGSDataset', 'AMDataset'] from .dgl_dataset import DGLBuiltinDataset
from .utils import (
_get_dgl_url,
generate_mask_tensor,
idx2mask,
load_graphs,
load_info,
save_graphs,
save_info,
)
__all__ = ["AIFBDataset", "MUTAGDataset", "BGSDataset", "AMDataset"]
# Dictionary for renaming reserved node/edge type names to the ones # Dictionary for renaming reserved node/edge type names to the ones
# that are allowed by nn.Module. # that are allowed by nn.Module.
RENAME_DICT = { RENAME_DICT = {
'type' : 'rdftype', "type": "rdftype",
'rev-type' : 'rev-rdftype', "rev-type": "rev-rdftype",
} }
class Entity: class Entity:
"""Class for entities """Class for entities
Parameters Parameters
...@@ -36,12 +45,14 @@ class Entity: ...@@ -36,12 +45,14 @@ class Entity:
cls : str cls : str
Type of this entity Type of this entity
""" """
def __init__(self, e_id, cls): def __init__(self, e_id, cls):
self.id = e_id self.id = e_id
self.cls = cls self.cls = cls
def __str__(self): def __str__(self):
return '{}/{}'.format(self.cls, self.id) return "{}/{}".format(self.cls, self.id)
class Relation: class Relation:
"""Class for relations """Class for relations
...@@ -50,12 +61,14 @@ class Relation: ...@@ -50,12 +61,14 @@ class Relation:
cls : str cls : str
Type of this relation Type of this relation
""" """
def __init__(self, cls): def __init__(self, cls):
self.cls = cls self.cls = cls
def __str__(self): def __str__(self):
return str(self.cls) return str(self.cls)
class RDFGraphDataset(DGLBuiltinDataset): class RDFGraphDataset(DGLBuiltinDataset):
"""Base graph dataset class from RDF tuples. """Base graph dataset class from RDF tuples.
...@@ -101,22 +114,31 @@ class RDFGraphDataset(DGLBuiltinDataset): ...@@ -101,22 +114,31 @@ class RDFGraphDataset(DGLBuiltinDataset):
a transformed version. The :class:`~dgl.DGLGraph` object will be a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access. transformed before every access.
""" """
def __init__(self, name, url, predict_category,
print_every=10000, def __init__(
insert_reverse=True, self,
raw_dir=None, name,
force_reload=False, url,
verbose=True, predict_category,
transform=None): print_every=10000,
insert_reverse=True,
raw_dir=None,
force_reload=False,
verbose=True,
transform=None,
):
self._insert_reverse = insert_reverse self._insert_reverse = insert_reverse
self._print_every = print_every self._print_every = print_every
self._predict_category = predict_category self._predict_category = predict_category
super(RDFGraphDataset, self).__init__(name, url, super(RDFGraphDataset, self).__init__(
raw_dir=raw_dir, name,
force_reload=force_reload, url,
verbose=verbose, raw_dir=raw_dir,
transform=transform) force_reload=force_reload,
verbose=verbose,
transform=transform,
)
def process(self): def process(self):
raw_tuples = self.load_raw_tuples(self.raw_path) raw_tuples = self.load_raw_tuples(self.raw_path)
...@@ -135,17 +157,18 @@ class RDFGraphDataset(DGLBuiltinDataset): ...@@ -135,17 +157,18 @@ class RDFGraphDataset(DGLBuiltinDataset):
Loaded rdf data Loaded rdf data
""" """
import rdflib as rdf import rdflib as rdf
raw_rdf_graphs = [] raw_rdf_graphs = []
for _, filename in enumerate(os.listdir(root_path)): for _, filename in enumerate(os.listdir(root_path)):
fmt = None fmt = None
if filename.endswith('nt'): if filename.endswith("nt"):
fmt = 'nt' fmt = "nt"
elif filename.endswith('n3'): elif filename.endswith("n3"):
fmt = 'n3' fmt = "n3"
if fmt is None: if fmt is None:
continue continue
g = rdf.Graph() g = rdf.Graph()
print('Parsing file %s ...' % filename) print("Parsing file %s ..." % filename)
g.parse(os.path.join(root_path, filename), format=fmt) g.parse(os.path.join(root_path, filename), format=fmt)
raw_rdf_graphs.append(g) raw_rdf_graphs.append(g)
return itertools.chain(*raw_rdf_graphs) return itertools.chain(*raw_rdf_graphs)
...@@ -175,11 +198,16 @@ class RDFGraphDataset(DGLBuiltinDataset): ...@@ -175,11 +198,16 @@ class RDFGraphDataset(DGLBuiltinDataset):
for i, (sbj, pred, obj) in enumerate(sorted_tuples): for i, (sbj, pred, obj) in enumerate(sorted_tuples):
if self.verbose and i % self._print_every == 0: if self.verbose and i % self._print_every == 0:
print('Processed %d tuples, found %d valid tuples.' % (i, len(src))) print(
"Processed %d tuples, found %d valid tuples."
% (i, len(src))
)
sbjent = self.parse_entity(sbj) sbjent = self.parse_entity(sbj)
rel = self.parse_relation(pred) rel = self.parse_relation(pred)
objent = self.parse_entity(obj) objent = self.parse_entity(obj)
processed = self.process_tuple((sbj, pred, obj), sbjent, rel, objent) processed = self.process_tuple(
(sbj, pred, obj), sbjent, rel, objent
)
if processed is None: if processed is None:
# ignored # ignored
continue continue
...@@ -189,7 +217,7 @@ class RDFGraphDataset(DGLBuiltinDataset): ...@@ -189,7 +217,7 @@ class RDFGraphDataset(DGLBuiltinDataset):
relclsid = _get_id(rel_classes, rel.cls) relclsid = _get_id(rel_classes, rel.cls)
mg.add_edge(sbjent.cls, objent.cls, key=rel.cls) mg.add_edge(sbjent.cls, objent.cls, key=rel.cls)
if self._insert_reverse: if self._insert_reverse:
mg.add_edge(objent.cls, sbjent.cls, key='rev-%s' % rel.cls) mg.add_edge(objent.cls, sbjent.cls, key="rev-%s" % rel.cls)
# instance graph # instance graph
src_id = _get_id(entities, str(sbjent)) src_id = _get_id(entities, str(sbjent))
if len(entities) > len(ntid): # found new entity if len(entities) > len(ntid): # found new entity
...@@ -211,39 +239,47 @@ class RDFGraphDataset(DGLBuiltinDataset): ...@@ -211,39 +239,47 @@ class RDFGraphDataset(DGLBuiltinDataset):
# add reverse edge with reverse relation # add reverse edge with reverse relation
if self._insert_reverse: if self._insert_reverse:
if self.verbose: if self.verbose:
print('Adding reverse edges ...') print("Adding reverse edges ...")
newsrc = np.hstack([src, dst]) newsrc = np.hstack([src, dst])
newdst = np.hstack([dst, src]) newdst = np.hstack([dst, src])
src = newsrc src = newsrc
dst = newdst dst = newdst
etid = np.hstack([etid, etid + len(etypes)]) etid = np.hstack([etid, etid + len(etypes)])
etypes.extend(['rev-%s' % t for t in etypes]) etypes.extend(["rev-%s" % t for t in etypes])
hg = self.build_graph(mg, src, dst, ntid, etid, ntypes, etypes) hg = self.build_graph(mg, src, dst, ntid, etid, ntypes, etypes)
if self.verbose: if self.verbose:
print('Load training/validation/testing split ...') print("Load training/validation/testing split ...")
idmap = F.asnumpy(hg.nodes[self.predict_category].data[dgl.NID]) idmap = F.asnumpy(hg.nodes[self.predict_category].data[dgl.NID])
glb2lcl = {glbid : lclid for lclid, glbid in enumerate(idmap)} glb2lcl = {glbid: lclid for lclid, glbid in enumerate(idmap)}
def findidfn(ent): def findidfn(ent):
if ent not in entities: if ent not in entities:
return None return None
else: else:
return glb2lcl[entities[ent]] return glb2lcl[entities[ent]]
self._hg = hg
train_idx, test_idx, labels, num_classes = self.load_data_split(findidfn, root_path)
train_mask = idx2mask(train_idx, self._hg.number_of_nodes(self.predict_category)) self._hg = hg
test_mask = idx2mask(test_idx, self._hg.number_of_nodes(self.predict_category)) train_idx, test_idx, labels, num_classes = self.load_data_split(
labels = F.tensor(labels, F.data_type_dict['int64']) findidfn, root_path
)
train_mask = idx2mask(
train_idx, self._hg.number_of_nodes(self.predict_category)
)
test_mask = idx2mask(
test_idx, self._hg.number_of_nodes(self.predict_category)
)
labels = F.tensor(labels, F.data_type_dict["int64"])
train_mask = generate_mask_tensor(train_mask) train_mask = generate_mask_tensor(train_mask)
test_mask = generate_mask_tensor(test_mask) test_mask = generate_mask_tensor(test_mask)
self._hg.nodes[self.predict_category].data['train_mask'] = train_mask self._hg.nodes[self.predict_category].data["train_mask"] = train_mask
self._hg.nodes[self.predict_category].data['test_mask'] = test_mask self._hg.nodes[self.predict_category].data["test_mask"] = test_mask
# TODO(minjie): Deprecate 'labels', use 'label' for consistency. # TODO(minjie): Deprecate 'labels', use 'label' for consistency.
self._hg.nodes[self.predict_category].data['labels'] = labels self._hg.nodes[self.predict_category].data["labels"] = labels
self._hg.nodes[self.predict_category].data['label'] = labels self._hg.nodes[self.predict_category].data["label"] = labels
self._num_classes = num_classes self._num_classes = num_classes
def build_graph(self, mg, src, dst, ntid, etid, ntypes, etypes): def build_graph(self, mg, src, dst, ntid, etid, ntypes, etypes):
...@@ -272,13 +308,13 @@ class RDFGraphDataset(DGLBuiltinDataset): ...@@ -272,13 +308,13 @@ class RDFGraphDataset(DGLBuiltinDataset):
""" """
# create homo graph # create homo graph
if self.verbose: if self.verbose:
print('Creating one whole graph ...') print("Creating one whole graph ...")
g = dgl.graph((src, dst)) g = dgl.graph((src, dst))
g.ndata[dgl.NTYPE] = F.tensor(ntid) g.ndata[dgl.NTYPE] = F.tensor(ntid)
g.edata[dgl.ETYPE] = F.tensor(etid) g.edata[dgl.ETYPE] = F.tensor(etid)
if self.verbose: if self.verbose:
print('Total #nodes:', g.number_of_nodes()) print("Total #nodes:", g.number_of_nodes())
print('Total #edges:', g.number_of_edges()) print("Total #edges:", g.number_of_edges())
# rename names such as 'type' so that they an be used as keys # rename names such as 'type' so that they an be used as keys
# to nn.ModuleDict # to nn.ModuleDict
...@@ -290,15 +326,12 @@ class RDFGraphDataset(DGLBuiltinDataset): ...@@ -290,15 +326,12 @@ class RDFGraphDataset(DGLBuiltinDataset):
# convert to heterograph # convert to heterograph
if self.verbose: if self.verbose:
print('Convert to heterograph ...') print("Convert to heterograph ...")
hg = dgl.to_heterogeneous(g, hg = dgl.to_heterogeneous(g, ntypes, etypes, metagraph=mg)
ntypes,
etypes,
metagraph=mg)
if self.verbose: if self.verbose:
print('#Node types:', len(hg.ntypes)) print("#Node types:", len(hg.ntypes))
print('#Canonical edge types:', len(hg.etypes)) print("#Canonical edge types:", len(hg.etypes))
print('#Unique edge type names:', len(set(hg.etypes))) print("#Unique edge type names:", len(set(hg.etypes)))
return hg return hg
def load_data_split(self, ent2id, root_path): def load_data_split(self, ent2id, root_path):
...@@ -323,13 +356,18 @@ class RDFGraphDataset(DGLBuiltinDataset): ...@@ -323,13 +356,18 @@ class RDFGraphDataset(DGLBuiltinDataset):
Number of classes Number of classes
""" """
label_dict = {} label_dict = {}
labels = np.zeros((self._hg.number_of_nodes(self.predict_category),)) - 1 labels = (
np.zeros((self._hg.number_of_nodes(self.predict_category),)) - 1
)
train_idx = self.parse_idx_file( train_idx = self.parse_idx_file(
os.path.join(root_path, 'trainingSet.tsv'), os.path.join(root_path, "trainingSet.tsv"),
ent2id, label_dict, labels) ent2id,
label_dict,
labels,
)
test_idx = self.parse_idx_file( test_idx = self.parse_idx_file(
os.path.join(root_path, 'testSet.tsv'), os.path.join(root_path, "testSet.tsv"), ent2id, label_dict, labels
ent2id, label_dict, labels) )
train_idx = np.array(train_idx) train_idx = np.array(train_idx)
test_idx = np.array(test_idx) test_idx = np.array(test_idx)
labels = np.array(labels) labels = np.array(labels)
...@@ -356,16 +394,19 @@ class RDFGraphDataset(DGLBuiltinDataset): ...@@ -356,16 +394,19 @@ class RDFGraphDataset(DGLBuiltinDataset):
Entity idss Entity idss
""" """
idx = [] idx = []
with open(filename, 'r') as f: with open(filename, "r") as f:
for i, line in enumerate(f): for i, line in enumerate(f):
if i == 0: if i == 0:
continue # first line is the header continue # first line is the header
sample, label = self.process_idx_file_line(line) sample, label = self.process_idx_file_line(line)
#person, _, label = line.strip().split('\t') # person, _, label = line.strip().split('\t')
ent = self.parse_entity(sample) ent = self.parse_entity(sample)
entid = ent2id(str(ent)) entid = ent2id(str(ent))
if entid is None: if entid is None:
print('Warning: entity "%s" does not have any valid links associated. Ignored.' % str(ent)) print(
'Warning: entity "%s" does not have any valid links associated. Ignored.'
% str(ent)
)
else: else:
idx.append(entid) idx.append(entid)
lblid = _get_id(label_dict, label) lblid = _get_id(label_dict, label)
...@@ -374,46 +415,44 @@ class RDFGraphDataset(DGLBuiltinDataset): ...@@ -374,46 +415,44 @@ class RDFGraphDataset(DGLBuiltinDataset):
def has_cache(self): def has_cache(self):
"""check if there is a processed data""" """check if there is a processed data"""
graph_path = os.path.join(self.save_path, graph_path = os.path.join(self.save_path, self.save_name + ".bin")
self.save_name + '.bin') info_path = os.path.join(self.save_path, self.save_name + ".pkl")
info_path = os.path.join(self.save_path, if os.path.exists(graph_path) and os.path.exists(info_path):
self.save_name + '.pkl')
if os.path.exists(graph_path) and \
os.path.exists(info_path):
return True return True
return False return False
def save(self): def save(self):
"""save the graph list and the labels""" """save the graph list and the labels"""
graph_path = os.path.join(self.save_path, graph_path = os.path.join(self.save_path, self.save_name + ".bin")
self.save_name + '.bin') info_path = os.path.join(self.save_path, self.save_name + ".pkl")
info_path = os.path.join(self.save_path,
self.save_name + '.pkl')
save_graphs(str(graph_path), self._hg) save_graphs(str(graph_path), self._hg)
save_info(str(info_path), {'num_classes': self.num_classes, save_info(
'predict_category': self.predict_category}) str(info_path),
{
"num_classes": self.num_classes,
"predict_category": self.predict_category,
},
)
def load(self): def load(self):
"""load the graph list and the labels from disk""" """load the graph list and the labels from disk"""
graph_path = os.path.join(self.save_path, graph_path = os.path.join(self.save_path, self.save_name + ".bin")
self.save_name + '.bin') info_path = os.path.join(self.save_path, self.save_name + ".pkl")
info_path = os.path.join(self.save_path,
self.save_name + '.pkl')
graphs, _ = load_graphs(str(graph_path)) graphs, _ = load_graphs(str(graph_path))
info = load_info(str(info_path)) info = load_info(str(info_path))
self._num_classes = info['num_classes'] self._num_classes = info["num_classes"]
self._predict_category = info['predict_category'] self._predict_category = info["predict_category"]
self._hg = graphs[0] self._hg = graphs[0]
# For backward compatibility # For backward compatibility
if 'label' not in self._hg.nodes[self.predict_category].data: if "label" not in self._hg.nodes[self.predict_category].data:
self._hg.nodes[self.predict_category].data['label'] = \ self._hg.nodes[self.predict_category].data[
self._hg.nodes[self.predict_category].data['labels'] "label"
] = self._hg.nodes[self.predict_category].data["labels"]
def __getitem__(self, idx): def __getitem__(self, idx):
r"""Gets the graph object r"""Gets the graph object"""
"""
g = self._hg g = self._hg
if self._transform is not None: if self._transform is not None:
g = self._transform(g) g = self._transform(g)
...@@ -425,7 +464,7 @@ class RDFGraphDataset(DGLBuiltinDataset): ...@@ -425,7 +464,7 @@ class RDFGraphDataset(DGLBuiltinDataset):
@property @property
def save_name(self): def save_name(self):
return self.name + '_dgl_graph' return self.name + "_dgl_graph"
@property @property
def predict_category(self): def predict_category(self):
...@@ -504,6 +543,7 @@ class RDFGraphDataset(DGLBuiltinDataset): ...@@ -504,6 +543,7 @@ class RDFGraphDataset(DGLBuiltinDataset):
""" """
pass pass
def _get_id(dict, key): def _get_id(dict, key):
id = dict.get(key, None) id = dict.get(key, None)
if id is None: if id is None:
...@@ -511,6 +551,7 @@ def _get_id(dict, key): ...@@ -511,6 +551,7 @@ def _get_id(dict, key):
dict[key] = id dict[key] = id
return id return id
class AIFBDataset(RDFGraphDataset): class AIFBDataset(RDFGraphDataset):
r"""AIFB dataset for node classification task r"""AIFB dataset for node classification task
...@@ -566,29 +607,40 @@ class AIFBDataset(RDFGraphDataset): ...@@ -566,29 +607,40 @@ class AIFBDataset(RDFGraphDataset):
>>> label = g.nodes[category].data['label'] >>> label = g.nodes[category].data['label']
""" """
entity_prefix = 'http://www.aifb.uni-karlsruhe.de/' entity_prefix = "http://www.aifb.uni-karlsruhe.de/"
relation_prefix = 'http://swrc.ontoware.org/' relation_prefix = "http://swrc.ontoware.org/"
def __init__(self, def __init__(
print_every=10000, self,
insert_reverse=True, print_every=10000,
raw_dir=None, insert_reverse=True,
force_reload=False, raw_dir=None,
verbose=True, force_reload=False,
transform=None): verbose=True,
transform=None,
):
import rdflib as rdf import rdflib as rdf
self.employs = rdf.term.URIRef("http://swrc.ontoware.org/ontology#employs")
self.affiliation = rdf.term.URIRef("http://swrc.ontoware.org/ontology#affiliation") self.employs = rdf.term.URIRef(
url = _get_dgl_url('dataset/rdf/aifb-hetero.zip') "http://swrc.ontoware.org/ontology#employs"
name = 'aifb-hetero' )
predict_category = 'Personen' self.affiliation = rdf.term.URIRef(
super(AIFBDataset, self).__init__(name, url, predict_category, "http://swrc.ontoware.org/ontology#affiliation"
print_every=print_every, )
insert_reverse=insert_reverse, url = _get_dgl_url("dataset/rdf/aifb-hetero.zip")
raw_dir=raw_dir, name = "aifb-hetero"
force_reload=force_reload, predict_category = "Personen"
verbose=verbose, super(AIFBDataset, self).__init__(
transform=transform) name,
url,
predict_category,
print_every=print_every,
insert_reverse=insert_reverse,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
def __getitem__(self, idx): def __getitem__(self, idx):
r"""Gets the graph object r"""Gets the graph object
...@@ -621,13 +673,14 @@ class AIFBDataset(RDFGraphDataset): ...@@ -621,13 +673,14 @@ class AIFBDataset(RDFGraphDataset):
def parse_entity(self, term): def parse_entity(self, term):
import rdflib as rdf import rdflib as rdf
if isinstance(term, rdf.Literal): if isinstance(term, rdf.Literal):
return Entity(e_id=str(term), cls="_Literal") return Entity(e_id=str(term), cls="_Literal")
if isinstance(term, rdf.BNode): if isinstance(term, rdf.BNode):
return None return None
entstr = str(term) entstr = str(term)
if entstr.startswith(self.entity_prefix): if entstr.startswith(self.entity_prefix):
sp = entstr.split('/') sp = entstr.split("/")
return Entity(e_id=sp[5], cls=sp[3]) return Entity(e_id=sp[5], cls=sp[3])
else: else:
return None return None
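To make the index arithmetic in parse_entity concrete, here is a tiny standalone sketch with a made-up AIFB-style URI (the URI itself is hypothetical; only the split indices mirror the code above):

entstr = "http://www.aifb.uni-karlsruhe.de/Personen/viewPersonOWL/id2instance"  # hypothetical URI
sp = entstr.split("/")
# sp[3] is the class segment, sp[5] the entity id, matching Entity(e_id=sp[5], cls=sp[3])
print(sp[3], sp[5])  # -> Personen id2instance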
...@@ -637,9 +690,9 @@ class AIFBDataset(RDFGraphDataset): ...@@ -637,9 +690,9 @@ class AIFBDataset(RDFGraphDataset):
return None return None
relstr = str(term) relstr = str(term)
if relstr.startswith(self.relation_prefix): if relstr.startswith(self.relation_prefix):
return Relation(cls=relstr.split('/')[3]) return Relation(cls=relstr.split("/")[3])
else: else:
relstr = relstr.split('/')[-1] relstr = relstr.split("/")[-1]
return Relation(cls=relstr) return Relation(cls=relstr)
def process_tuple(self, raw_tuple, sbj, rel, obj): def process_tuple(self, raw_tuple, sbj, rel, obj):
...@@ -648,9 +701,10 @@ class AIFBDataset(RDFGraphDataset): ...@@ -648,9 +701,10 @@ class AIFBDataset(RDFGraphDataset):
return (sbj, rel, obj) return (sbj, rel, obj)
def process_idx_file_line(self, line): def process_idx_file_line(self, line):
person, _, label = line.strip().split('\t') person, _, label = line.strip().split("\t")
return person, label return person, label
class MUTAGDataset(RDFGraphDataset): class MUTAGDataset(RDFGraphDataset):
r"""MUTAG dataset for node classification task r"""MUTAG dataset for node classification task
...@@ -707,32 +761,47 @@ class MUTAGDataset(RDFGraphDataset): ...@@ -707,32 +761,47 @@ class MUTAGDataset(RDFGraphDataset):
d_entity = re.compile("d[0-9]") d_entity = re.compile("d[0-9]")
bond_entity = re.compile("bond[0-9]") bond_entity = re.compile("bond[0-9]")
entity_prefix = 'http://dl-learner.org/carcinogenesis#' entity_prefix = "http://dl-learner.org/carcinogenesis#"
relation_prefix = entity_prefix relation_prefix = entity_prefix
def __init__(self, def __init__(
print_every=10000, self,
insert_reverse=True, print_every=10000,
raw_dir=None, insert_reverse=True,
force_reload=False, raw_dir=None,
verbose=True, force_reload=False,
transform=None): verbose=True,
transform=None,
):
import rdflib as rdf import rdflib as rdf
self.is_mutagenic = rdf.term.URIRef("http://dl-learner.org/carcinogenesis#isMutagenic")
self.rdf_type = rdf.term.URIRef("http://www.w3.org/1999/02/22-rdf-syntax-ns#type") self.is_mutagenic = rdf.term.URIRef(
self.rdf_subclassof = rdf.term.URIRef("http://www.w3.org/2000/01/rdf-schema#subClassOf") "http://dl-learner.org/carcinogenesis#isMutagenic"
self.rdf_domain = rdf.term.URIRef("http://www.w3.org/2000/01/rdf-schema#domain") )
self.rdf_type = rdf.term.URIRef(
url = _get_dgl_url('dataset/rdf/mutag-hetero.zip') "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
name = 'mutag-hetero' )
predict_category = 'd' self.rdf_subclassof = rdf.term.URIRef(
super(MUTAGDataset, self).__init__(name, url, predict_category, "http://www.w3.org/2000/01/rdf-schema#subClassOf"
print_every=print_every, )
insert_reverse=insert_reverse, self.rdf_domain = rdf.term.URIRef(
raw_dir=raw_dir, "http://www.w3.org/2000/01/rdf-schema#domain"
force_reload=force_reload, )
verbose=verbose,
transform=transform) url = _get_dgl_url("dataset/rdf/mutag-hetero.zip")
name = "mutag-hetero"
predict_category = "d"
super(MUTAGDataset, self).__init__(
name,
url,
predict_category,
print_every=print_every,
insert_reverse=insert_reverse,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
def __getitem__(self, idx): def __getitem__(self, idx):
r"""Gets the graph object r"""Gets the graph object
...@@ -765,17 +834,18 @@ class MUTAGDataset(RDFGraphDataset): ...@@ -765,17 +834,18 @@ class MUTAGDataset(RDFGraphDataset):
def parse_entity(self, term): def parse_entity(self, term):
import rdflib as rdf import rdflib as rdf
if isinstance(term, rdf.Literal): if isinstance(term, rdf.Literal):
return Entity(e_id=str(term), cls="_Literal") return Entity(e_id=str(term), cls="_Literal")
elif isinstance(term, rdf.BNode): elif isinstance(term, rdf.BNode):
return None return None
entstr = str(term) entstr = str(term)
if entstr.startswith(self.entity_prefix): if entstr.startswith(self.entity_prefix):
inst = entstr[len(self.entity_prefix):] inst = entstr[len(self.entity_prefix) :]
if self.d_entity.match(inst): if self.d_entity.match(inst):
cls = 'd' cls = "d"
elif self.bond_entity.match(inst): elif self.bond_entity.match(inst):
cls = 'bond' cls = "bond"
else: else:
cls = None cls = None
return Entity(e_id=inst, cls=cls) return Entity(e_id=inst, cls=cls)
...@@ -787,20 +857,20 @@ class MUTAGDataset(RDFGraphDataset): ...@@ -787,20 +857,20 @@ class MUTAGDataset(RDFGraphDataset):
return None return None
relstr = str(term) relstr = str(term)
if relstr.startswith(self.relation_prefix): if relstr.startswith(self.relation_prefix):
cls = relstr[len(self.relation_prefix):] cls = relstr[len(self.relation_prefix) :]
return Relation(cls=cls) return Relation(cls=cls)
else: else:
relstr = relstr.split('/')[-1] relstr = relstr.split("/")[-1]
return Relation(cls=relstr) return Relation(cls=relstr)
def process_tuple(self, raw_tuple, sbj, rel, obj): def process_tuple(self, raw_tuple, sbj, rel, obj):
if sbj is None or rel is None or obj is None: if sbj is None or rel is None or obj is None:
return None return None
if not raw_tuple[1].startswith('http://dl-learner.org/carcinogenesis#'): if not raw_tuple[1].startswith("http://dl-learner.org/carcinogenesis#"):
obj.cls = 'SCHEMA' obj.cls = "SCHEMA"
if sbj.cls is None: if sbj.cls is None:
sbj.cls = 'SCHEMA' sbj.cls = "SCHEMA"
if obj.cls is None: if obj.cls is None:
obj.cls = rel.cls obj.cls = rel.cls
...@@ -809,9 +879,10 @@ class MUTAGDataset(RDFGraphDataset): ...@@ -809,9 +879,10 @@ class MUTAGDataset(RDFGraphDataset):
return (sbj, rel, obj) return (sbj, rel, obj)
def process_idx_file_line(self, line): def process_idx_file_line(self, line):
bond, _, label = line.strip().split('\t') bond, _, label = line.strip().split("\t")
return bond, label return bond, label
class BGSDataset(RDFGraphDataset): class BGSDataset(RDFGraphDataset):
r"""BGS dataset for node classification task r"""BGS dataset for node classification task
...@@ -869,29 +940,38 @@ class BGSDataset(RDFGraphDataset): ...@@ -869,29 +940,38 @@ class BGSDataset(RDFGraphDataset):
>>> label = g.nodes[category].data['label'] >>> label = g.nodes[category].data['label']
""" """
entity_prefix = 'http://data.bgs.ac.uk/' entity_prefix = "http://data.bgs.ac.uk/"
status_prefix = 'http://data.bgs.ac.uk/ref/CurrentStatus' status_prefix = "http://data.bgs.ac.uk/ref/CurrentStatus"
relation_prefix = 'http://data.bgs.ac.uk/ref' relation_prefix = "http://data.bgs.ac.uk/ref"
def __init__(self, def __init__(
print_every=10000, self,
insert_reverse=True, print_every=10000,
raw_dir=None, insert_reverse=True,
force_reload=False, raw_dir=None,
verbose=True, force_reload=False,
transform=None): verbose=True,
transform=None,
):
import rdflib as rdf import rdflib as rdf
url = _get_dgl_url('dataset/rdf/bgs-hetero.zip')
name = 'bgs-hetero' url = _get_dgl_url("dataset/rdf/bgs-hetero.zip")
predict_category = 'Lexicon/NamedRockUnit' name = "bgs-hetero"
self.lith = rdf.term.URIRef("http://data.bgs.ac.uk/ref/Lexicon/hasLithogenesis") predict_category = "Lexicon/NamedRockUnit"
super(BGSDataset, self).__init__(name, url, predict_category, self.lith = rdf.term.URIRef(
print_every=print_every, "http://data.bgs.ac.uk/ref/Lexicon/hasLithogenesis"
insert_reverse=insert_reverse, )
raw_dir=raw_dir, super(BGSDataset, self).__init__(
force_reload=force_reload, name,
verbose=verbose, url,
transform=transform) predict_category,
print_every=print_every,
insert_reverse=insert_reverse,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
def __getitem__(self, idx): def __getitem__(self, idx):
r"""Gets the graph object r"""Gets the graph object
...@@ -924,6 +1004,7 @@ class BGSDataset(RDFGraphDataset): ...@@ -924,6 +1004,7 @@ class BGSDataset(RDFGraphDataset):
def parse_entity(self, term): def parse_entity(self, term):
import rdflib as rdf import rdflib as rdf
if isinstance(term, rdf.Literal): if isinstance(term, rdf.Literal):
return None return None
elif isinstance(term, rdf.BNode): elif isinstance(term, rdf.BNode):
...@@ -932,11 +1013,11 @@ class BGSDataset(RDFGraphDataset): ...@@ -932,11 +1013,11 @@ class BGSDataset(RDFGraphDataset):
if entstr.startswith(self.status_prefix): if entstr.startswith(self.status_prefix):
return None return None
if entstr.startswith(self.entity_prefix): if entstr.startswith(self.entity_prefix):
sp = entstr.split('/') sp = entstr.split("/")
if len(sp) != 7: if len(sp) != 7:
return None return None
# instance # instance
cls = '%s/%s' % (sp[4], sp[5]) cls = "%s/%s" % (sp[4], sp[5])
inst = sp[6] inst = sp[6]
return Entity(e_id=inst, cls=cls) return Entity(e_id=inst, cls=cls)
else: else:
...@@ -947,14 +1028,14 @@ class BGSDataset(RDFGraphDataset): ...@@ -947,14 +1028,14 @@ class BGSDataset(RDFGraphDataset):
return None return None
relstr = str(term) relstr = str(term)
if relstr.startswith(self.relation_prefix): if relstr.startswith(self.relation_prefix):
sp = relstr.split('/') sp = relstr.split("/")
if len(sp) < 6: if len(sp) < 6:
return None return None
assert len(sp) == 6, relstr assert len(sp) == 6, relstr
cls = '%s/%s' % (sp[4], sp[5]) cls = "%s/%s" % (sp[4], sp[5])
return Relation(cls=cls) return Relation(cls=cls)
else: else:
relstr = relstr.replace('.', '_') relstr = relstr.replace(".", "_")
return Relation(cls=relstr) return Relation(cls=relstr)
def process_tuple(self, raw_tuple, sbj, rel, obj): def process_tuple(self, raw_tuple, sbj, rel, obj):
...@@ -963,9 +1044,10 @@ class BGSDataset(RDFGraphDataset): ...@@ -963,9 +1044,10 @@ class BGSDataset(RDFGraphDataset):
return (sbj, rel, obj) return (sbj, rel, obj)
def process_idx_file_line(self, line): def process_idx_file_line(self, line):
_, rock, label = line.strip().split('\t') _, rock, label = line.strip().split("\t")
return rock, label return rock, label
class AMDataset(RDFGraphDataset): class AMDataset(RDFGraphDataset):
"""AM dataset. for node classification task """AM dataset. for node classification task
...@@ -1025,29 +1107,40 @@ class AMDataset(RDFGraphDataset): ...@@ -1025,29 +1107,40 @@ class AMDataset(RDFGraphDataset):
>>> label = g.nodes[category].data['label'] >>> label = g.nodes[category].data['label']
""" """
entity_prefix = 'http://purl.org/collections/nl/am/' entity_prefix = "http://purl.org/collections/nl/am/"
relation_prefix = entity_prefix relation_prefix = entity_prefix
def __init__(self, def __init__(
print_every=10000, self,
insert_reverse=True, print_every=10000,
raw_dir=None, insert_reverse=True,
force_reload=False, raw_dir=None,
verbose=True, force_reload=False,
transform=None): verbose=True,
transform=None,
):
import rdflib as rdf import rdflib as rdf
self.objectCategory = rdf.term.URIRef("http://purl.org/collections/nl/am/objectCategory")
self.material = rdf.term.URIRef("http://purl.org/collections/nl/am/material") self.objectCategory = rdf.term.URIRef(
url = _get_dgl_url('dataset/rdf/am-hetero.zip') "http://purl.org/collections/nl/am/objectCategory"
name = 'am-hetero' )
predict_category = 'proxy' self.material = rdf.term.URIRef(
super(AMDataset, self).__init__(name, url, predict_category, "http://purl.org/collections/nl/am/material"
print_every=print_every, )
insert_reverse=insert_reverse, url = _get_dgl_url("dataset/rdf/am-hetero.zip")
raw_dir=raw_dir, name = "am-hetero"
force_reload=force_reload, predict_category = "proxy"
verbose=verbose, super(AMDataset, self).__init__(
transform=transform) name,
url,
predict_category,
print_every=print_every,
insert_reverse=insert_reverse,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
def __getitem__(self, idx): def __getitem__(self, idx):
r"""Gets the graph object r"""Gets the graph object
...@@ -1080,20 +1173,21 @@ class AMDataset(RDFGraphDataset): ...@@ -1080,20 +1173,21 @@ class AMDataset(RDFGraphDataset):
def parse_entity(self, term): def parse_entity(self, term):
import rdflib as rdf import rdflib as rdf
if isinstance(term, rdf.Literal): if isinstance(term, rdf.Literal):
return None return None
elif isinstance(term, rdf.BNode): elif isinstance(term, rdf.BNode):
return Entity(e_id=str(term), cls='_BNode') return Entity(e_id=str(term), cls="_BNode")
entstr = str(term) entstr = str(term)
if entstr.startswith(self.entity_prefix): if entstr.startswith(self.entity_prefix):
sp = entstr.split('/') sp = entstr.split("/")
assert len(sp) == 7, entstr assert len(sp) == 7, entstr
spp = sp[6].split('-') spp = sp[6].split("-")
if len(spp) == 2: if len(spp) == 2:
# instance # instance
cls, inst = spp cls, inst = spp
else: else:
cls = 'TYPE' cls = "TYPE"
inst = spp inst = spp
return Entity(e_id=inst, cls=cls) return Entity(e_id=inst, cls=cls)
else: else:
...@@ -1104,12 +1198,12 @@ class AMDataset(RDFGraphDataset): ...@@ -1104,12 +1198,12 @@ class AMDataset(RDFGraphDataset):
return None return None
relstr = str(term) relstr = str(term)
if relstr.startswith(self.relation_prefix): if relstr.startswith(self.relation_prefix):
sp = relstr.split('/') sp = relstr.split("/")
assert len(sp) == 7, relstr assert len(sp) == 7, relstr
cls = sp[6] cls = sp[6]
return Relation(cls=cls) return Relation(cls=cls)
else: else:
relstr = relstr.replace('.', '_') relstr = relstr.replace(".", "_")
return Relation(cls=relstr) return Relation(cls=relstr)
def process_tuple(self, raw_tuple, sbj, rel, obj): def process_tuple(self, raw_tuple, sbj, rel, obj):
...@@ -1118,5 +1212,5 @@ class AMDataset(RDFGraphDataset): ...@@ -1118,5 +1212,5 @@ class AMDataset(RDFGraphDataset):
return (sbj, rel, obj) return (sbj, rel, obj)
def process_idx_file_line(self, line): def process_idx_file_line(self, line):
proxy, _, label = line.strip().split('\t') proxy, _, label = line.strip().split("\t")
return proxy, label return proxy, label
"""Dataset for stochastic block model.""" """Dataset for stochastic block model."""
import math import math
import random
import os import os
import random
import numpy as np import numpy as np
import numpy.random as npr import numpy.random as npr
import scipy as sp import scipy as sp
from .dgl_dataset import DGLDataset
from ..convert import from_scipy
from .. import batch from .. import batch
from .utils import save_info, save_graphs, load_info, load_graphs from ..convert import from_scipy
from .dgl_dataset import DGLDataset
from .utils import load_graphs, load_info, save_graphs, save_info
def sbm(n_blocks, block_size, p, q, rng=None): def sbm(n_blocks, block_size, p, q, rng=None):
""" (Symmetric) Stochastic Block Model """(Symmetric) Stochastic Block Model
Parameters Parameters
---------- ----------
...@@ -44,20 +44,27 @@ def sbm(n_blocks, block_size, p, q, rng=None): ...@@ -44,20 +44,27 @@ def sbm(n_blocks, block_size, p, q, rng=None):
for i in range(n_blocks): for i in range(n_blocks):
for j in range(i, n_blocks): for j in range(i, n_blocks):
density = p if i == j else q density = p if i == j else q
block = sp.sparse.random(block_size, block_size, density, block = sp.sparse.random(
random_state=rng, data_rvs=lambda n: np.ones(n)) block_size,
block_size,
density,
random_state=rng,
data_rvs=lambda n: np.ones(n),
)
rows.append(block.row + i * block_size) rows.append(block.row + i * block_size)
cols.append(block.col + j * block_size) cols.append(block.col + j * block_size)
rows = np.hstack(rows) rows = np.hstack(rows)
cols = np.hstack(cols) cols = np.hstack(cols)
a = sp.sparse.coo_matrix((np.ones(rows.shape[0]), (rows, cols)), shape=(n, n)) a = sp.sparse.coo_matrix(
(np.ones(rows.shape[0]), (rows, cols)), shape=(n, n)
)
adj = sp.sparse.triu(a) + sp.sparse.triu(a, 1).transpose() adj = sp.sparse.triu(a) + sp.sparse.triu(a, 1).transpose()
return adj return adj
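As a rough usage sketch (not part of the diff): sbm() returns a SciPy sparse adjacency matrix that from_scipy can turn into a DGLGraph; the parameter values below are arbitrary and the import path for the helper is an assumption.

# Usage sketch for sbm(); assumes the helper is importable from dgl.data.sbm.
import numpy as np
import dgl
from dgl.data.sbm import sbm  # hypothetical import path for this module

rng = np.random.RandomState(0)
adj = sbm(n_blocks=5, block_size=20, p=0.9, q=0.1, rng=rng)  # 100-node symmetric SBM
g = dgl.from_scipy(adj)
print(g.num_nodes(), g.num_edges())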
class SBMMixtureDataset(DGLDataset): class SBMMixtureDataset(DGLDataset):
r""" Symmetric Stochastic Block Model Mixture r"""Symmetric Stochastic Block Model Mixture
Reference: Appendix C of `Supervised Community Detection with Hierarchical Graph Neural Networks <https://arxiv.org/abs/1705.08415>`_ Reference: Appendix C of `Supervised Community Detection with Hierarchical Graph Neural Networks <https://arxiv.org/abs/1705.08415>`_
...@@ -95,14 +102,17 @@ class SBMMixtureDataset(DGLDataset): ...@@ -95,14 +102,17 @@ class SBMMixtureDataset(DGLDataset):
>>> for graph, line_graph, graph_degrees, line_graph_degrees, pm_pd in dataloader: >>> for graph, line_graph, graph_degrees, line_graph_degrees, pm_pd in dataloader:
... # your code here ... # your code here
""" """
def __init__(self,
n_graphs, def __init__(
n_nodes, self,
n_communities, n_graphs,
k=2, n_nodes,
avg_deg=3, n_communities,
pq='Appendix_C', k=2,
rng=None): avg_deg=3,
pq="Appendix_C",
rng=None,
):
self._n_graphs = n_graphs self._n_graphs = n_graphs
self._n_nodes = n_nodes self._n_nodes = n_nodes
self._n_communities = n_communities self._n_communities = n_communities
...@@ -112,60 +122,92 @@ class SBMMixtureDataset(DGLDataset): ...@@ -112,60 +122,92 @@ class SBMMixtureDataset(DGLDataset):
self._avg_deg = avg_deg self._avg_deg = avg_deg
self._pq = pq self._pq = pq
self._rng = rng self._rng = rng
super(SBMMixtureDataset, self).__init__(name='sbmmixture', super(SBMMixtureDataset, self).__init__(
hash_key=(n_graphs, n_nodes, n_communities, k, avg_deg, pq, rng)) name="sbmmixture",
hash_key=(n_graphs, n_nodes, n_communities, k, avg_deg, pq, rng),
)
def process(self): def process(self):
pq = self._pq pq = self._pq
if type(pq) is list: if type(pq) is list:
assert len(pq) == self._n_graphs assert len(pq) == self._n_graphs
elif type(pq) is str: elif type(pq) is str:
generator = {'Appendix_C': self._appendix_c}[pq] generator = {"Appendix_C": self._appendix_c}[pq]
pq = [generator() for _ in range(self._n_graphs)] pq = [generator() for _ in range(self._n_graphs)]
else: else:
raise RuntimeError() raise RuntimeError()
self._graphs = [from_scipy(sbm(self._n_communities, self._block_size, *x)) for x in pq] self._graphs = [
self._line_graphs = [g.line_graph(backtracking=False) for g in self._graphs] from_scipy(sbm(self._n_communities, self._block_size, *x))
for x in pq
]
self._line_graphs = [
g.line_graph(backtracking=False) for g in self._graphs
]
in_degrees = lambda g: g.in_degrees().float() in_degrees = lambda g: g.in_degrees().float()
self._graph_degrees = [in_degrees(g) for g in self._graphs] self._graph_degrees = [in_degrees(g) for g in self._graphs]
self._line_graph_degrees = [in_degrees(lg) for lg in self._line_graphs] self._line_graph_degrees = [in_degrees(lg) for lg in self._line_graphs]
self._pm_pds = list(zip(*[g.edges() for g in self._graphs]))[0] self._pm_pds = list(zip(*[g.edges() for g in self._graphs]))[0]
def has_cache(self): def has_cache(self):
graph_path = os.path.join(self.save_path, 'graphs_{}.bin'.format(self.hash)) graph_path = os.path.join(
line_graph_path = os.path.join(self.save_path, 'line_graphs_{}.bin'.format(self.hash)) self.save_path, "graphs_{}.bin".format(self.hash)
info_path = os.path.join(self.save_path, 'info_{}.pkl'.format(self.hash)) )
return os.path.exists(graph_path) and \ line_graph_path = os.path.join(
os.path.exists(line_graph_path) and \ self.save_path, "line_graphs_{}.bin".format(self.hash)
os.path.exists(info_path) )
info_path = os.path.join(
self.save_path, "info_{}.pkl".format(self.hash)
)
return (
os.path.exists(graph_path)
and os.path.exists(line_graph_path)
and os.path.exists(info_path)
)
def save(self): def save(self):
graph_path = os.path.join(self.save_path, 'graphs_{}.bin'.format(self.hash)) graph_path = os.path.join(
line_graph_path = os.path.join(self.save_path, 'line_graphs_{}.bin'.format(self.hash)) self.save_path, "graphs_{}.bin".format(self.hash)
info_path = os.path.join(self.save_path, 'info_{}.pkl'.format(self.hash)) )
line_graph_path = os.path.join(
self.save_path, "line_graphs_{}.bin".format(self.hash)
)
info_path = os.path.join(
self.save_path, "info_{}.pkl".format(self.hash)
)
save_graphs(graph_path, self._graphs) save_graphs(graph_path, self._graphs)
save_graphs(line_graph_path, self._line_graphs) save_graphs(line_graph_path, self._line_graphs)
save_info(info_path, {'graph_degree': self._graph_degrees, save_info(
'line_graph_degree': self._line_graph_degrees, info_path,
'pm_pds': self._pm_pds}) {
"graph_degree": self._graph_degrees,
"line_graph_degree": self._line_graph_degrees,
"pm_pds": self._pm_pds,
},
)
def load(self): def load(self):
graph_path = os.path.join(self.save_path, 'graphs_{}.bin'.format(self.hash)) graph_path = os.path.join(
line_graph_path = os.path.join(self.save_path, 'line_graphs_{}.bin'.format(self.hash)) self.save_path, "graphs_{}.bin".format(self.hash)
info_path = os.path.join(self.save_path, 'info_{}.pkl'.format(self.hash)) )
line_graph_path = os.path.join(
self.save_path, "line_graphs_{}.bin".format(self.hash)
)
info_path = os.path.join(
self.save_path, "info_{}.pkl".format(self.hash)
)
self._graphs, _ = load_graphs(graph_path) self._graphs, _ = load_graphs(graph_path)
self._line_graphs, _ = load_graphs(line_graph_path) self._line_graphs, _ = load_graphs(line_graph_path)
info = load_info(info_path) info = load_info(info_path)
self._graph_degrees = info['graph_degree'] self._graph_degrees = info["graph_degree"]
self._line_graph_degrees = info['line_graph_degree'] self._line_graph_degrees = info["line_graph_degree"]
self._pm_pds = info['pm_pds'] self._pm_pds = info["pm_pds"]
def __len__(self): def __len__(self):
r"""Number of graphs in the dataset.""" r"""Number of graphs in the dataset."""
return len(self._graphs) return len(self._graphs)
def __getitem__(self, idx): def __getitem__(self, idx):
r""" Get one example by index r"""Get one example by index
Parameters Parameters
---------- ----------
...@@ -185,8 +227,13 @@ class SBMMixtureDataset(DGLDataset): ...@@ -185,8 +227,13 @@ class SBMMixtureDataset(DGLDataset):
pm_pd: numpy.ndarray pm_pd: numpy.ndarray
Edge indicator matrices Pm and Pd Edge indicator matrices Pm and Pd
""" """
return self._graphs[idx], self._line_graphs[idx], \ return (
self._graph_degrees[idx], self._line_graph_degrees[idx], self._pm_pds[idx] self._graphs[idx],
self._line_graphs[idx],
self._graph_degrees[idx],
self._line_graph_degrees[idx],
self._pm_pds[idx],
)
def _appendix_c(self): def _appendix_c(self):
q = npr.uniform(0, self._avg_deg - math.sqrt(self._avg_deg)) q = npr.uniform(0, self._avg_deg - math.sqrt(self._avg_deg))
...@@ -197,7 +244,7 @@ class SBMMixtureDataset(DGLDataset): ...@@ -197,7 +244,7 @@ class SBMMixtureDataset(DGLDataset):
return q, p return q, p
def collate_fn(self, x): def collate_fn(self, x):
r""" The `collate` function for dataloader r"""The `collate` function for dataloader
Parameters Parameters
---------- ----------
...@@ -233,7 +280,9 @@ class SBMMixtureDataset(DGLDataset): ...@@ -233,7 +280,9 @@ class SBMMixtureDataset(DGLDataset):
lg_batch = batch.batch(lg) lg_batch = batch.batch(lg)
degg_batch = np.concatenate(deg_g, axis=0) degg_batch = np.concatenate(deg_g, axis=0)
deglg_batch = np.concatenate(deg_lg, axis=0) deglg_batch = np.concatenate(deg_lg, axis=0)
pm_pd_batch = np.concatenate([x + i * self._n_nodes for i, x in enumerate(pm_pd)], axis=0) pm_pd_batch = np.concatenate(
[x + i * self._n_nodes for i, x in enumerate(pm_pd)], axis=0
)
return g_batch, lg_batch, degg_batch, deglg_batch, pm_pd_batch return g_batch, lg_batch, degg_batch, deglg_batch, pm_pd_batch
......
"""Synthetic graph datasets.""" """Synthetic graph datasets."""
import math import math
import networkx as nx
import numpy as np
import os import os
import pickle import pickle
import random import random
from .dgl_dataset import DGLBuiltinDataset import networkx as nx
from .utils import save_graphs, load_graphs, _get_dgl_url, download import numpy as np
from .. import backend as F from .. import backend as F
from ..batch import batch from ..batch import batch
from ..convert import graph from ..convert import graph
from ..transforms import reorder_graph from ..transforms import reorder_graph
from .dgl_dataset import DGLBuiltinDataset
from .utils import _get_dgl_url, download, load_graphs, save_graphs
class BAShapeDataset(DGLBuiltinDataset): class BAShapeDataset(DGLBuiltinDataset):
r"""BA-SHAPES dataset from `GNNExplainer: Generating Explanations for Graph Neural Networks r"""BA-SHAPES dataset from `GNNExplainer: Generating Explanations for Graph Neural Networks
...@@ -70,30 +72,37 @@ class BAShapeDataset(DGLBuiltinDataset): ...@@ -70,30 +72,37 @@ class BAShapeDataset(DGLBuiltinDataset):
>>> label = g.ndata['label'] >>> label = g.ndata['label']
>>> feat = g.ndata['feat'] >>> feat = g.ndata['feat']
""" """
def __init__(self,
num_base_nodes=300, def __init__(
num_base_edges_per_node=5, self,
num_motifs=80, num_base_nodes=300,
perturb_ratio=0.01, num_base_edges_per_node=5,
seed=None, num_motifs=80,
raw_dir=None, perturb_ratio=0.01,
force_reload=False, seed=None,
verbose=True, raw_dir=None,
transform=None): force_reload=False,
verbose=True,
transform=None,
):
self.num_base_nodes = num_base_nodes self.num_base_nodes = num_base_nodes
self.num_base_edges_per_node = num_base_edges_per_node self.num_base_edges_per_node = num_base_edges_per_node
self.num_motifs = num_motifs self.num_motifs = num_motifs
self.perturb_ratio = perturb_ratio self.perturb_ratio = perturb_ratio
self.seed = seed self.seed = seed
super(BAShapeDataset, self).__init__(name='BA-SHAPES', super(BAShapeDataset, self).__init__(
url=None, name="BA-SHAPES",
raw_dir=raw_dir, url=None,
force_reload=force_reload, raw_dir=raw_dir,
verbose=verbose, force_reload=force_reload,
transform=transform) verbose=verbose,
transform=transform,
)
def process(self): def process(self):
g = nx.barabasi_albert_graph(self.num_base_nodes, self.num_base_edges_per_node, self.seed) g = nx.barabasi_albert_graph(
self.num_base_nodes, self.num_base_edges_per_node, self.seed
)
edges = list(g.edges()) edges = list(g.edges())
src, dst = map(list, zip(*edges)) src, dst = map(list, zip(*edges))
n = self.num_base_nodes n = self.num_base_nodes
...@@ -111,7 +120,7 @@ class BAShapeDataset(DGLBuiltinDataset): ...@@ -111,7 +120,7 @@ class BAShapeDataset(DGLBuiltinDataset):
(n + 2, n + 3), (n + 2, n + 3),
(n + 3, n), (n + 3, n),
(n + 4, n), (n + 4, n),
(n + 4, n + 1) (n + 4, n + 1),
] ]
motif_src, motif_dst = map(list, zip(*motif_edges)) motif_src, motif_dst = map(list, zip(*motif_edges))
src.extend(motif_src) src.extend(motif_src)
...@@ -132,8 +141,9 @@ class BAShapeDataset(DGLBuiltinDataset): ...@@ -132,8 +141,9 @@ class BAShapeDataset(DGLBuiltinDataset):
# Perturb the graph by adding non-self-loop random edges # Perturb the graph by adding non-self-loop random edges
num_real_edges = g.num_edges() num_real_edges = g.num_edges()
max_ratio = (n * (n - 1) - num_real_edges) / num_real_edges max_ratio = (n * (n - 1) - num_real_edges) / num_real_edges
assert self.perturb_ratio <= max_ratio, \ assert (
'perturb_ratio cannot exceed {:.4f}'.format(max_ratio) self.perturb_ratio <= max_ratio
), "perturb_ratio cannot exceed {:.4f}".format(max_ratio)
num_random_edges = int(num_real_edges * self.perturb_ratio) num_random_edges = int(num_real_edges * self.perturb_ratio)
if self.seed is not None: if self.seed is not None:
...@@ -146,14 +156,20 @@ class BAShapeDataset(DGLBuiltinDataset): ...@@ -146,14 +156,20 @@ class BAShapeDataset(DGLBuiltinDataset):
break break
g.add_edges(u, v) g.add_edges(u, v)
g.ndata['label'] = F.tensor(node_labels, F.int64) g.ndata["label"] = F.tensor(node_labels, F.int64)
g.ndata['feat'] = F.ones((n, 1), F.float32, F.cpu()) g.ndata["feat"] = F.ones((n, 1), F.float32, F.cpu())
self._graph = reorder_graph( self._graph = reorder_graph(
g, node_permute_algo='rcmk', edge_permute_algo='dst', store_ids=False) g,
node_permute_algo="rcmk",
edge_permute_algo="dst",
store_ids=False,
)
@property @property
def graph_path(self): def graph_path(self):
return os.path.join(self.save_path, '{}_dgl_graph.bin'.format(self.name)) return os.path.join(
self.save_path, "{}_dgl_graph.bin".format(self.name)
)
def save(self): def save(self):
save_graphs(str(self.graph_path), self._graph) save_graphs(str(self.graph_path), self._graph)
...@@ -179,6 +195,7 @@ class BAShapeDataset(DGLBuiltinDataset): ...@@ -179,6 +195,7 @@ class BAShapeDataset(DGLBuiltinDataset):
def num_classes(self): def num_classes(self):
return 4 return 4
class BACommunityDataset(DGLBuiltinDataset): class BACommunityDataset(DGLBuiltinDataset):
r"""BA-COMMUNITY dataset from `GNNExplainer: Generating Explanations for Graph Neural Networks r"""BA-COMMUNITY dataset from `GNNExplainer: Generating Explanations for Graph Neural Networks
<https://arxiv.org/abs/1903.03894>`__ <https://arxiv.org/abs/1903.03894>`__
...@@ -241,29 +258,34 @@ class BACommunityDataset(DGLBuiltinDataset): ...@@ -241,29 +258,34 @@ class BACommunityDataset(DGLBuiltinDataset):
>>> label = g.ndata['label'] >>> label = g.ndata['label']
>>> feat = g.ndata['feat'] >>> feat = g.ndata['feat']
""" """
def __init__(self,
num_base_nodes=300, def __init__(
num_base_edges_per_node=4, self,
num_motifs=80, num_base_nodes=300,
perturb_ratio=0.01, num_base_edges_per_node=4,
num_inter_edges=350, num_motifs=80,
seed=None, perturb_ratio=0.01,
raw_dir=None, num_inter_edges=350,
force_reload=False, seed=None,
verbose=True, raw_dir=None,
transform=None): force_reload=False,
verbose=True,
transform=None,
):
self.num_base_nodes = num_base_nodes self.num_base_nodes = num_base_nodes
self.num_base_edges_per_node = num_base_edges_per_node self.num_base_edges_per_node = num_base_edges_per_node
self.num_motifs = num_motifs self.num_motifs = num_motifs
self.perturb_ratio = perturb_ratio self.perturb_ratio = perturb_ratio
self.num_inter_edges = num_inter_edges self.num_inter_edges = num_inter_edges
self.seed = seed self.seed = seed
super(BACommunityDataset, self).__init__(name='BA-COMMUNITY', super(BACommunityDataset, self).__init__(
url=None, name="BA-COMMUNITY",
raw_dir=raw_dir, url=None,
force_reload=force_reload, raw_dir=raw_dir,
verbose=verbose, force_reload=force_reload,
transform=transform) verbose=verbose,
transform=transform,
)
def process(self): def process(self):
if self.seed is not None: if self.seed is not None:
...@@ -271,47 +293,65 @@ class BACommunityDataset(DGLBuiltinDataset): ...@@ -271,47 +293,65 @@ class BACommunityDataset(DGLBuiltinDataset):
np.random.seed(self.seed) np.random.seed(self.seed)
# Construct two BA-SHAPES graphs # Construct two BA-SHAPES graphs
g1 = BAShapeDataset(self.num_base_nodes, g1 = BAShapeDataset(
self.num_base_edges_per_node, self.num_base_nodes,
self.num_motifs, self.num_base_edges_per_node,
self.perturb_ratio, self.num_motifs,
force_reload=True, self.perturb_ratio,
verbose=False)[0] force_reload=True,
g2 = BAShapeDataset(self.num_base_nodes, verbose=False,
self.num_base_edges_per_node, )[0]
self.num_motifs, g2 = BAShapeDataset(
self.perturb_ratio, self.num_base_nodes,
force_reload=True, self.num_base_edges_per_node,
verbose=False)[0] self.num_motifs,
self.perturb_ratio,
force_reload=True,
verbose=False,
)[0]
# Join them and randomly add edges between them # Join them and randomly add edges between them
g = batch([g1, g2]) g = batch([g1, g2])
num_nodes = g.num_nodes() // 2 num_nodes = g.num_nodes() // 2
src = np.random.randint(0, num_nodes, (self.num_inter_edges,)) src = np.random.randint(0, num_nodes, (self.num_inter_edges,))
dst = np.random.randint(num_nodes, 2 * num_nodes, (self.num_inter_edges,)) dst = np.random.randint(
num_nodes, 2 * num_nodes, (self.num_inter_edges,)
)
src = F.astype(F.zerocopy_from_numpy(src), g.idtype) src = F.astype(F.zerocopy_from_numpy(src), g.idtype)
dst = F.astype(F.zerocopy_from_numpy(dst), g.idtype) dst = F.astype(F.zerocopy_from_numpy(dst), g.idtype)
g.add_edges(src, dst) g.add_edges(src, dst)
g.ndata['label'] = F.cat([g1.ndata['label'], g2.ndata['label'] + 4], dim=0) g.ndata["label"] = F.cat(
[g1.ndata["label"], g2.ndata["label"] + 4], dim=0
)
# feature generation # feature generation
random_mu = [0.0] * 8 random_mu = [0.0] * 8
random_sigma = [1.0] * 8 random_sigma = [1.0] * 8
mu_1, sigma_1 = np.array([-1.0] * 2 + random_mu), np.array([0.5] * 2 + random_sigma) mu_1, sigma_1 = np.array([-1.0] * 2 + random_mu), np.array(
[0.5] * 2 + random_sigma
)
feat1 = np.random.multivariate_normal(mu_1, np.diag(sigma_1), num_nodes) feat1 = np.random.multivariate_normal(mu_1, np.diag(sigma_1), num_nodes)
mu_2, sigma_2 = np.array([1.0] * 2 + random_mu), np.array([0.5] * 2 + random_sigma) mu_2, sigma_2 = np.array([1.0] * 2 + random_mu), np.array(
[0.5] * 2 + random_sigma
)
feat2 = np.random.multivariate_normal(mu_2, np.diag(sigma_2), num_nodes) feat2 = np.random.multivariate_normal(mu_2, np.diag(sigma_2), num_nodes)
feat = np.concatenate([feat1, feat2]) feat = np.concatenate([feat1, feat2])
g.ndata['feat'] = F.zerocopy_from_numpy(feat) g.ndata["feat"] = F.zerocopy_from_numpy(feat)
self._graph = reorder_graph( self._graph = reorder_graph(
g, node_permute_algo='rcmk', edge_permute_algo='dst', store_ids=False) g,
node_permute_algo="rcmk",
edge_permute_algo="dst",
store_ids=False,
)
@property @property
def graph_path(self): def graph_path(self):
return os.path.join(self.save_path, '{}_dgl_graph.bin'.format(self.name)) return os.path.join(
self.save_path, "{}_dgl_graph.bin".format(self.name)
)
def save(self): def save(self):
save_graphs(str(self.graph_path), self._graph) save_graphs(str(self.graph_path), self._graph)
...@@ -337,6 +377,7 @@ class BACommunityDataset(DGLBuiltinDataset): ...@@ -337,6 +377,7 @@ class BACommunityDataset(DGLBuiltinDataset):
def num_classes(self): def num_classes(self):
return 8 return 8
class TreeCycleDataset(DGLBuiltinDataset): class TreeCycleDataset(DGLBuiltinDataset):
r"""TREE-CYCLES dataset from `GNNExplainer: Generating Explanations for Graph Neural Networks r"""TREE-CYCLES dataset from `GNNExplainer: Generating Explanations for Graph Neural Networks
<https://arxiv.org/abs/1903.03894>`__ <https://arxiv.org/abs/1903.03894>`__
...@@ -392,27 +433,32 @@ class TreeCycleDataset(DGLBuiltinDataset): ...@@ -392,27 +433,32 @@ class TreeCycleDataset(DGLBuiltinDataset):
>>> label = g.ndata['label'] >>> label = g.ndata['label']
>>> feat = g.ndata['feat'] >>> feat = g.ndata['feat']
""" """
def __init__(self,
tree_height=8, def __init__(
num_motifs=60, self,
cycle_size=6, tree_height=8,
perturb_ratio=0.01, num_motifs=60,
seed=None, cycle_size=6,
raw_dir=None, perturb_ratio=0.01,
force_reload=False, seed=None,
verbose=True, raw_dir=None,
transform=None): force_reload=False,
verbose=True,
transform=None,
):
self.tree_height = tree_height self.tree_height = tree_height
self.num_motifs = num_motifs self.num_motifs = num_motifs
self.cycle_size = cycle_size self.cycle_size = cycle_size
self.perturb_ratio = perturb_ratio self.perturb_ratio = perturb_ratio
self.seed = seed self.seed = seed
super(TreeCycleDataset, self).__init__(name='TREE-CYCLES', super(TreeCycleDataset, self).__init__(
url=None, name="TREE-CYCLES",
raw_dir=raw_dir, url=None,
force_reload=force_reload, raw_dir=raw_dir,
verbose=verbose, force_reload=force_reload,
transform=transform) verbose=verbose,
transform=transform,
)
def process(self): def process(self):
if self.seed is not None: if self.seed is not None:
...@@ -457,8 +503,9 @@ class TreeCycleDataset(DGLBuiltinDataset): ...@@ -457,8 +503,9 @@ class TreeCycleDataset(DGLBuiltinDataset):
# Perturb the graph by adding non-self-loop random edges # Perturb the graph by adding non-self-loop random edges
num_real_edges = g.num_edges() num_real_edges = g.num_edges()
max_ratio = (n * (n - 1) - num_real_edges) / num_real_edges max_ratio = (n * (n - 1) - num_real_edges) / num_real_edges
assert self.perturb_ratio <= max_ratio, \ assert (
'perturb_ratio cannot exceed {:.4f}'.format(max_ratio) self.perturb_ratio <= max_ratio
), "perturb_ratio cannot exceed {:.4f}".format(max_ratio)
num_random_edges = int(num_real_edges * self.perturb_ratio) num_random_edges = int(num_real_edges * self.perturb_ratio)
for _ in range(num_random_edges): for _ in range(num_random_edges):
...@@ -469,14 +516,20 @@ class TreeCycleDataset(DGLBuiltinDataset): ...@@ -469,14 +516,20 @@ class TreeCycleDataset(DGLBuiltinDataset):
break break
g.add_edges(u, v) g.add_edges(u, v)
g.ndata['label'] = F.tensor(node_labels, F.int64) g.ndata["label"] = F.tensor(node_labels, F.int64)
g.ndata['feat'] = F.ones((n, 1), F.float32, F.cpu()) g.ndata["feat"] = F.ones((n, 1), F.float32, F.cpu())
self._graph = reorder_graph( self._graph = reorder_graph(
g, node_permute_algo='rcmk', edge_permute_algo='dst', store_ids=False) g,
node_permute_algo="rcmk",
edge_permute_algo="dst",
store_ids=False,
)
@property @property
def graph_path(self): def graph_path(self):
return os.path.join(self.save_path, '{}_dgl_graph.bin'.format(self.name)) return os.path.join(
self.save_path, "{}_dgl_graph.bin".format(self.name)
)
def save(self): def save(self):
save_graphs(str(self.graph_path), self._graph) save_graphs(str(self.graph_path), self._graph)
...@@ -502,6 +555,7 @@ class TreeCycleDataset(DGLBuiltinDataset): ...@@ -502,6 +555,7 @@ class TreeCycleDataset(DGLBuiltinDataset):
def num_classes(self): def num_classes(self):
return 2 return 2
class TreeGridDataset(DGLBuiltinDataset): class TreeGridDataset(DGLBuiltinDataset):
r"""TREE-GRIDS dataset from `GNNExplainer: Generating Explanations for Graph Neural Networks r"""TREE-GRIDS dataset from `GNNExplainer: Generating Explanations for Graph Neural Networks
<https://arxiv.org/abs/1903.03894>`__ <https://arxiv.org/abs/1903.03894>`__
...@@ -557,27 +611,32 @@ class TreeGridDataset(DGLBuiltinDataset): ...@@ -557,27 +611,32 @@ class TreeGridDataset(DGLBuiltinDataset):
>>> label = g.ndata['label'] >>> label = g.ndata['label']
>>> feat = g.ndata['feat'] >>> feat = g.ndata['feat']
""" """
def __init__(self,
tree_height=8, def __init__(
num_motifs=80, self,
grid_size=3, tree_height=8,
perturb_ratio=0.1, num_motifs=80,
seed=None, grid_size=3,
raw_dir=None, perturb_ratio=0.1,
force_reload=False, seed=None,
verbose=True, raw_dir=None,
transform=None): force_reload=False,
verbose=True,
transform=None,
):
self.tree_height = tree_height self.tree_height = tree_height
self.num_motifs = num_motifs self.num_motifs = num_motifs
self.grid_size = grid_size self.grid_size = grid_size
self.perturb_ratio = perturb_ratio self.perturb_ratio = perturb_ratio
self.seed = seed self.seed = seed
super(TreeGridDataset, self).__init__(name='TREE-GRIDS', super(TreeGridDataset, self).__init__(
url=None, name="TREE-GRIDS",
raw_dir=raw_dir, url=None,
force_reload=force_reload, raw_dir=raw_dir,
verbose=verbose, force_reload=force_reload,
transform=transform) verbose=verbose,
transform=transform,
)
def process(self): def process(self):
if self.seed is not None: if self.seed is not None:
...@@ -619,8 +678,9 @@ class TreeGridDataset(DGLBuiltinDataset): ...@@ -619,8 +678,9 @@ class TreeGridDataset(DGLBuiltinDataset):
# Perturb the graph by adding non-self-loop random edges # Perturb the graph by adding non-self-loop random edges
num_real_edges = g.num_edges() num_real_edges = g.num_edges()
max_ratio = (n * (n - 1) - num_real_edges) / num_real_edges max_ratio = (n * (n - 1) - num_real_edges) / num_real_edges
assert self.perturb_ratio <= max_ratio, \ assert (
'perturb_ratio cannot exceed {:.4f}'.format(max_ratio) self.perturb_ratio <= max_ratio
), "perturb_ratio cannot exceed {:.4f}".format(max_ratio)
num_random_edges = int(num_real_edges * self.perturb_ratio) num_random_edges = int(num_real_edges * self.perturb_ratio)
for _ in range(num_random_edges): for _ in range(num_random_edges):
...@@ -631,14 +691,20 @@ class TreeGridDataset(DGLBuiltinDataset): ...@@ -631,14 +691,20 @@ class TreeGridDataset(DGLBuiltinDataset):
break break
g.add_edges(u, v) g.add_edges(u, v)
g.ndata['label'] = F.tensor(node_labels, F.int64) g.ndata["label"] = F.tensor(node_labels, F.int64)
g.ndata['feat'] = F.ones((n, 1), F.float32, F.cpu()) g.ndata["feat"] = F.ones((n, 1), F.float32, F.cpu())
self._graph = reorder_graph( self._graph = reorder_graph(
g, node_permute_algo='rcmk', edge_permute_algo='dst', store_ids=False) g,
node_permute_algo="rcmk",
edge_permute_algo="dst",
store_ids=False,
)
@property @property
def graph_path(self): def graph_path(self):
return os.path.join(self.save_path, '{}_dgl_graph.bin'.format(self.name)) return os.path.join(
self.save_path, "{}_dgl_graph.bin".format(self.name)
)
def save(self): def save(self):
save_graphs(str(self.graph_path), self._graph) save_graphs(str(self.graph_path), self._graph)
...@@ -664,6 +730,7 @@ class TreeGridDataset(DGLBuiltinDataset): ...@@ -664,6 +730,7 @@ class TreeGridDataset(DGLBuiltinDataset):
def num_classes(self): def num_classes(self):
return 2 return 2
class BA2MotifDataset(DGLBuiltinDataset): class BA2MotifDataset(DGLBuiltinDataset):
r"""BA-2motifs dataset from `Parameterized Explainer for Graph Neural Network r"""BA-2motifs dataset from `Parameterized Explainer for Graph Neural Network
<https://arxiv.org/abs/2011.04573>`__ <https://arxiv.org/abs/2011.04573>`__
...@@ -705,26 +772,27 @@ class BA2MotifDataset(DGLBuiltinDataset): ...@@ -705,26 +772,27 @@ class BA2MotifDataset(DGLBuiltinDataset):
>>> g, label = dataset[0] >>> g, label = dataset[0]
>>> feat = g.ndata['feat'] >>> feat = g.ndata['feat']
""" """
def __init__(self,
raw_dir=None, def __init__(
force_reload=False, self, raw_dir=None, force_reload=False, verbose=True, transform=None
verbose=True, ):
transform=None): super(BA2MotifDataset, self).__init__(
super(BA2MotifDataset, self).__init__(name='BA-2motifs', name="BA-2motifs",
url=_get_dgl_url('dataset/BA-2motif.pkl'), url=_get_dgl_url("dataset/BA-2motif.pkl"),
raw_dir=raw_dir, raw_dir=raw_dir,
force_reload=force_reload, force_reload=force_reload,
verbose=verbose, verbose=verbose,
transform=transform) transform=transform,
)
def download(self): def download(self):
r""" Automatically download data.""" r"""Automatically download data."""
file_path = os.path.join(self.raw_dir, self.name + '.pkl') file_path = os.path.join(self.raw_dir, self.name + ".pkl")
download(self.url, path=file_path) download(self.url, path=file_path)
def process(self): def process(self):
file_path = os.path.join(self.raw_dir, self.name + '.pkl') file_path = os.path.join(self.raw_dir, self.name + ".pkl")
with open(file_path, 'rb') as f: with open(file_path, "rb") as f:
adjs, features, labels = pickle.load(f) adjs, features, labels = pickle.load(f)
self.graphs = [] self.graphs = []
...@@ -732,15 +800,17 @@ class BA2MotifDataset(DGLBuiltinDataset): ...@@ -732,15 +800,17 @@ class BA2MotifDataset(DGLBuiltinDataset):
for i in range(len(adjs)): for i in range(len(adjs)):
g = graph(adjs[i].nonzero()) g = graph(adjs[i].nonzero())
g.ndata['feat'] = F.zerocopy_from_numpy(features[i]) g.ndata["feat"] = F.zerocopy_from_numpy(features[i])
self.graphs.append(g) self.graphs.append(g)
@property @property
def graph_path(self): def graph_path(self):
return os.path.join(self.save_path, '{}_dgl_graph.bin'.format(self.name)) return os.path.join(
self.save_path, "{}_dgl_graph.bin".format(self.name)
)
def save(self): def save(self):
label_dict = {'labels': self.labels} label_dict = {"labels": self.labels}
save_graphs(str(self.graph_path), self.graphs, label_dict) save_graphs(str(self.graph_path), self.graphs, label_dict)
def has_cache(self): def has_cache(self):
...@@ -748,7 +818,7 @@ class BA2MotifDataset(DGLBuiltinDataset): ...@@ -748,7 +818,7 @@ class BA2MotifDataset(DGLBuiltinDataset):
def load(self): def load(self):
self.graphs, label_dict = load_graphs(str(self.graph_path)) self.graphs, label_dict = load_graphs(str(self.graph_path))
self.labels = label_dict['labels'] self.labels = label_dict["labels"]
def __getitem__(self, idx): def __getitem__(self, idx):
g = self.graphs[idx] g = self.graphs[idx]
......
"""For Tensor Serialization""" """For Tensor Serialization"""
from __future__ import absolute_import from __future__ import absolute_import
from ..ndarray import NDArray
from .._ffi.function import _init_api
from .. import backend as F from .. import backend as F
from .._ffi.function import _init_api
from ..ndarray import NDArray
__all__ = ['save_tensors', "load_tensors"] __all__ = ["save_tensors", "load_tensors"]
_init_api("dgl.data.tensor_serialize") _init_api("dgl.data.tensor_serialize")
...@@ -12,11 +13,11 @@ _init_api("dgl.data.tensor_serialize") ...@@ -12,11 +13,11 @@ _init_api("dgl.data.tensor_serialize")
def save_tensors(filename, tensor_dict): def save_tensors(filename, tensor_dict):
""" """
Save dict of tensors to file Save dict of tensors to file
Parameters Parameters
---------- ----------
filename : str filename : str
File name to store dict of tensors. File name to store dict of tensors.
tensor_dict: dict of dgl NDArray or backend tensor tensor_dict: dict of dgl NDArray or backend tensor
Python dict using string as key and tensor as value Python dict using string as key and tensor as value
...@@ -36,19 +37,20 @@ def save_tensors(filename, tensor_dict): ...@@ -36,19 +37,20 @@ def save_tensors(filename, tensor_dict):
nd_dict[key] = value nd_dict[key] = value
else: else:
raise Exception( raise Exception(
"Dict value has to be backend tensor or dgl ndarray") "Dict value has to be backend tensor or dgl ndarray"
)
return _CAPI_SaveNDArrayDict(filename, nd_dict, is_empty_dict) return _CAPI_SaveNDArrayDict(filename, nd_dict, is_empty_dict)
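A minimal round-trip sketch for save_tensors/load_tensors (the file name is arbitrary and assumes a working DGL backend):

from dgl import backend as F
from dgl.data.utils import save_tensors, load_tensors

tensors = {"feat": F.ones((4, 3), F.float32, F.cpu())}
save_tensors("demo_tensors.bin", tensors)          # arbitrary file name
reloaded = load_tensors("demo_tensors.bin")        # dict of backend tensors
reloaded_nd = load_tensors("demo_tensors.bin", return_dgl_ndarray=True)  # dict of dgl NDArrays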
def load_tensors(filename, return_dgl_ndarray=False): def load_tensors(filename, return_dgl_ndarray=False):
""" """
load dict of tensors from file load dict of tensors from file
Parameters Parameters
---------- ----------
filename : str filename : str
File name to load dict of tensors. File name to load dict of tensors.
return_dgl_ndarray: bool return_dgl_ndarray: bool
Whether return dict of dgl NDArrays or backend tensors Whether return dict of dgl NDArrays or backend tensors
......
"""Dataset utilities.""" """Dataset utilities."""
from __future__ import absolute_import from __future__ import absolute_import
import errno
import hashlib
import os import os
import pickle
import sys import sys
import hashlib
import warnings import warnings
import requests
import pickle
import errno
import numpy as np
import pickle import numpy as np
import errno import requests
from .graph_serialize import save_graphs, load_graphs, load_labels
from .tensor_serialize import save_tensors, load_tensors
from .. import backend as F from .. import backend as F
from .graph_serialize import load_graphs, load_labels, save_graphs
__all__ = ['loadtxt','download', 'check_sha1', 'extract_archive', from .tensor_serialize import load_tensors, save_tensors
'get_download_dir', 'Subset', 'split_dataset', 'save_graphs',
'load_graphs', 'load_labels', 'save_tensors', 'load_tensors', __all__ = [
'add_nodepred_split', "loadtxt",
"download",
"check_sha1",
"extract_archive",
"get_download_dir",
"Subset",
"split_dataset",
"save_graphs",
"load_graphs",
"load_labels",
"save_tensors",
"load_tensors",
"add_nodepred_split",
] ]
def loadtxt(path, delimiter, dtype=None): def loadtxt(path, delimiter, dtype=None):
try: try:
import pandas as pd import pandas as pd
df = pd.read_csv(path, delimiter=delimiter, header=None) df = pd.read_csv(path, delimiter=delimiter, header=None)
return df.values return df.values
except ImportError: except ImportError:
warnings.warn("Pandas is not installed, now using numpy.loadtxt to load data, " warnings.warn(
"which could be extremely slow. Accelerate by installing pandas") "Pandas is not installed, now using numpy.loadtxt to load data, "
"which could be extremely slow. Accelerate by installing pandas"
)
return np.loadtxt(path, delimiter=delimiter) return np.loadtxt(path, delimiter=delimiter)
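A quick sketch of loadtxt on a comma-delimited numeric file (the path is hypothetical):

from dgl.data.utils import loadtxt

arr = loadtxt("edges.csv", delimiter=",")  # hypothetical file of numeric columns
print(arr.shape)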
def _get_dgl_url(file_url): def _get_dgl_url(file_url):
"""Get DGL online url for download.""" """Get DGL online url for download."""
dgl_repo_url = 'https://data.dgl.ai/' dgl_repo_url = "https://data.dgl.ai/"
repo_url = os.environ.get('DGL_REPO', dgl_repo_url) repo_url = os.environ.get("DGL_REPO", dgl_repo_url)
if repo_url[-1] != '/': if repo_url[-1] != "/":
repo_url = repo_url + '/' repo_url = repo_url + "/"
return repo_url + file_url return repo_url + file_url
...@@ -70,23 +82,35 @@ def split_dataset(dataset, frac_list=None, shuffle=False, random_state=None): ...@@ -70,23 +82,35 @@ def split_dataset(dataset, frac_list=None, shuffle=False, random_state=None):
Subsets for training, validation and test. Subsets for training, validation and test.
""" """
from itertools import accumulate from itertools import accumulate
if frac_list is None: if frac_list is None:
frac_list = [0.8, 0.1, 0.1] frac_list = [0.8, 0.1, 0.1]
frac_list = np.asarray(frac_list) frac_list = np.asarray(frac_list)
assert np.allclose(np.sum(frac_list), 1.), \ assert np.allclose(
'Expect frac_list sum to 1, got {:.4f}'.format(np.sum(frac_list)) np.sum(frac_list), 1.0
), "Expect frac_list sum to 1, got {:.4f}".format(np.sum(frac_list))
num_data = len(dataset) num_data = len(dataset)
lengths = (num_data * frac_list).astype(int) lengths = (num_data * frac_list).astype(int)
lengths[-1] = num_data - np.sum(lengths[:-1]) lengths[-1] = num_data - np.sum(lengths[:-1])
if shuffle: if shuffle:
indices = np.random.RandomState( indices = np.random.RandomState(seed=random_state).permutation(num_data)
seed=random_state).permutation(num_data)
else: else:
indices = np.arange(num_data) indices = np.arange(num_data)
return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(accumulate(lengths), lengths)] return [
Subset(dataset, indices[offset - length : offset])
for offset, length in zip(accumulate(lengths), lengths)
def download(url, path=None, overwrite=True, sha1_hash=None, retries=5, verify_ssl=True, log=True): ]
def download(
url,
path=None,
overwrite=True,
sha1_hash=None,
retries=5,
verify_ssl=True,
log=True,
):
"""Download a given URL. """Download a given URL.
Code borrowed from mxnet/gluon/utils.py Code borrowed from mxnet/gluon/utils.py
...@@ -117,45 +141,54 @@ def download(url, path=None, overwrite=True, sha1_hash=None, retries=5, verify_s ...@@ -117,45 +141,54 @@ def download(url, path=None, overwrite=True, sha1_hash=None, retries=5, verify_s
The file path of the downloaded file. The file path of the downloaded file.
""" """
if path is None: if path is None:
fname = url.split('/')[-1] fname = url.split("/")[-1]
# Empty filenames are invalid # Empty filenames are invalid
assert fname, 'Can\'t construct file-name from this URL. ' \ assert fname, (
'Please set the `path` option manually.' "Can't construct file-name from this URL. "
"Please set the `path` option manually."
)
else: else:
path = os.path.expanduser(path) path = os.path.expanduser(path)
if os.path.isdir(path): if os.path.isdir(path):
fname = os.path.join(path, url.split('/')[-1]) fname = os.path.join(path, url.split("/")[-1])
else: else:
fname = path fname = path
assert retries >= 0, "Number of retries should be at least 0" assert retries >= 0, "Number of retries should be at least 0"
if not verify_ssl: if not verify_ssl:
warnings.warn( warnings.warn(
'Unverified HTTPS request is being made (verify_ssl=False). ' "Unverified HTTPS request is being made (verify_ssl=False). "
'Adding certificate verification is strongly advised.') "Adding certificate verification is strongly advised."
)
if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
if (
overwrite
or not os.path.exists(fname)
or (sha1_hash and not check_sha1(fname, sha1_hash))
):
dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname))) dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
if not os.path.exists(dirname): if not os.path.exists(dirname):
os.makedirs(dirname) os.makedirs(dirname)
while retries+1 > 0: while retries + 1 > 0:
# Disable pylint too broad Exception # Disable pylint too broad Exception
# pylint: disable=W0703 # pylint: disable=W0703
try: try:
if log: if log:
print('Downloading %s from %s...' % (fname, url)) print("Downloading %s from %s..." % (fname, url))
r = requests.get(url, stream=True, verify=verify_ssl) r = requests.get(url, stream=True, verify=verify_ssl)
if r.status_code != 200: if r.status_code != 200:
raise RuntimeError("Failed downloading url %s" % url) raise RuntimeError("Failed downloading url %s" % url)
with open(fname, 'wb') as f: with open(fname, "wb") as f:
for chunk in r.iter_content(chunk_size=1024): for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks if chunk: # filter out keep-alive new chunks
f.write(chunk) f.write(chunk)
if sha1_hash and not check_sha1(fname, sha1_hash): if sha1_hash and not check_sha1(fname, sha1_hash):
raise UserWarning('File {} is downloaded but the content hash does not match.' raise UserWarning(
' The repo may be outdated or download may be incomplete. ' "File {} is downloaded but the content hash does not match."
'If the "repo_url" is overridden, consider switching to ' " The repo may be outdated or download may be incomplete. "
'the default repo.'.format(fname)) 'If the "repo_url" is overridden, consider switching to '
"the default repo.".format(fname)
)
break break
except Exception as e: except Exception as e:
retries -= 1 retries -= 1
...@@ -163,8 +196,11 @@ def download(url, path=None, overwrite=True, sha1_hash=None, retries=5, verify_s ...@@ -163,8 +196,11 @@ def download(url, path=None, overwrite=True, sha1_hash=None, retries=5, verify_s
raise e raise e
else: else:
if log: if log:
print("download failed, retrying, {} attempt{} left" print(
.format(retries, 's' if retries > 1 else '')) "download failed, retrying, {} attempt{} left".format(
retries, "s" if retries > 1 else ""
)
)
return fname return fname
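A usage sketch for download (the URL follows DGL's public data mirror pattern; treat the specific asset name as an assumption):

import os
from dgl.data.utils import download, get_download_dir

path = os.path.join(get_download_dir(), "cora_v2.zip")
fname = download("https://data.dgl.ai/dataset/cora_v2.zip", path=path, overwrite=False)
print("saved to", fname)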
...@@ -187,7 +223,7 @@ def check_sha1(filename, sha1_hash): ...@@ -187,7 +223,7 @@ def check_sha1(filename, sha1_hash):
Whether the file content matches the expected hash. Whether the file content matches the expected hash.
""" """
sha1 = hashlib.sha1() sha1 = hashlib.sha1()
with open(filename, 'rb') as f: with open(filename, "rb") as f:
while True: while True:
data = f.read(1048576) data = f.read(1048576)
if not data: if not data:
...@@ -212,24 +248,31 @@ def extract_archive(file, target_dir, overwrite=False): ...@@ -212,24 +248,31 @@ def extract_archive(file, target_dir, overwrite=False):
""" """
if os.path.exists(target_dir) and not overwrite: if os.path.exists(target_dir) and not overwrite:
return return
print('Extracting file to {}'.format(target_dir)) print("Extracting file to {}".format(target_dir))
if file.endswith('.tar.gz') or file.endswith('.tar') or file.endswith('.tgz'): if (
file.endswith(".tar.gz")
or file.endswith(".tar")
or file.endswith(".tgz")
):
import tarfile import tarfile
with tarfile.open(file, 'r') as archive:
with tarfile.open(file, "r") as archive:
archive.extractall(path=target_dir) archive.extractall(path=target_dir)
elif file.endswith('.gz'): elif file.endswith(".gz"):
import gzip import gzip
import shutil import shutil
with gzip.open(file, 'rb') as f_in:
with gzip.open(file, "rb") as f_in:
target_file = os.path.join(target_dir, os.path.basename(file)[:-3]) target_file = os.path.join(target_dir, os.path.basename(file)[:-3])
with open(target_file, 'wb') as f_out: with open(target_file, "wb") as f_out:
shutil.copyfileobj(f_in, f_out) shutil.copyfileobj(f_in, f_out)
elif file.endswith('.zip'): elif file.endswith(".zip"):
import zipfile import zipfile
with zipfile.ZipFile(file, 'r') as archive:
with zipfile.ZipFile(file, "r") as archive:
archive.extractall(path=target_dir) archive.extractall(path=target_dir)
else: else:
raise Exception('Unrecognized file type: ' + file) raise Exception("Unrecognized file type: " + file)
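Continuing the sketch above, a downloaded zip can then be unpacked (assumes the archive from the download example exists):

import os
from dgl.data.utils import extract_archive, get_download_dir

zip_path = os.path.join(get_download_dir(), "cora_v2.zip")
extract_archive(zip_path, os.path.join(get_download_dir(), "cora_v2"), overwrite=False)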
def get_download_dir(): def get_download_dir():
...@@ -240,12 +283,13 @@ def get_download_dir(): ...@@ -240,12 +283,13 @@ def get_download_dir():
dirname : str dirname : str
Path to the download directory Path to the download directory
""" """
default_dir = os.path.join(os.path.expanduser('~'), '.dgl') default_dir = os.path.join(os.path.expanduser("~"), ".dgl")
dirname = os.environ.get('DGL_DOWNLOAD_DIR', default_dir) dirname = os.environ.get("DGL_DOWNLOAD_DIR", default_dir)
if not os.path.exists(dirname): if not os.path.exists(dirname):
os.makedirs(dirname) os.makedirs(dirname)
return dirname return dirname
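For reference, a minimal usage sketch of the download helpers touched above (not part of this diff; the URL and file names are placeholders):

import os
from dgl.data.utils import download, extract_archive, get_download_dir

url = "https://example.com/toy_dataset.zip"      # placeholder URL (assumption)
root = get_download_dir()                        # ~/.dgl or $DGL_DOWNLOAD_DIR
zip_path = os.path.join(root, "toy_dataset.zip")

download(url, path=zip_path)                     # sha1_hash / retries are optional kwargs
extract_archive(zip_path, os.path.join(root, "toy_dataset"))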
def makedirs(path): def makedirs(path):
try: try:
os.makedirs(os.path.expanduser(os.path.normpath(path))) os.makedirs(os.path.expanduser(os.path.normpath(path)))
...@@ -253,8 +297,9 @@ def makedirs(path): ...@@ -253,8 +297,9 @@ def makedirs(path):
if e.errno != errno.EEXIST and os.path.isdir(path): if e.errno != errno.EEXIST and os.path.isdir(path):
raise e raise e
def save_info(path, info): def save_info(path, info):
""" Save dataset related information into disk. """Save dataset related information into disk.
Parameters Parameters
---------- ----------
...@@ -263,12 +308,12 @@ def save_info(path, info): ...@@ -263,12 +308,12 @@ def save_info(path, info):
info : dict info : dict
A python dict storing information to save on disk. A python dict storing information to save on disk.
""" """
with open(path, "wb" ) as pf: with open(path, "wb") as pf:
pickle.dump(info, pf) pickle.dump(info, pf)
def load_info(path): def load_info(path):
""" Load dataset related information from disk. """Load dataset related information from disk.
Parameters Parameters
---------- ----------
...@@ -284,16 +329,28 @@ def load_info(path): ...@@ -284,16 +329,28 @@ def load_info(path):
info = pickle.load(pf) info = pickle.load(pf)
return info return info
def deprecate_property(old, new): def deprecate_property(old, new):
warnings.warn('Property {} will be deprecated, please use {} instead.'.format(old, new)) warnings.warn(
"Property {} will be deprecated, please use {} instead.".format(
old, new
)
)
def deprecate_function(old, new): def deprecate_function(old, new):
warnings.warn('Function {} will be deprecated, please use {} instead.'.format(old, new)) warnings.warn(
"Function {} will be deprecated, please use {} instead.".format(
old, new
)
)
def deprecate_class(old, new): def deprecate_class(old, new):
warnings.warn('Class {} will be deprecated, please use {} instead.'.format(old, new)) warnings.warn(
"Class {} will be deprecated, please use {} instead.".format(old, new)
)
def idx2mask(idx, len): def idx2mask(idx, len):
"""Create mask.""" """Create mask."""
...@@ -301,6 +358,7 @@ def idx2mask(idx, len): ...@@ -301,6 +358,7 @@ def idx2mask(idx, len):
mask[idx] = 1 mask[idx] = 1
return mask return mask
def generate_mask_tensor(mask): def generate_mask_tensor(mask):
"""Generate mask tensor according to different backend """Generate mask tensor according to different backend
For torch and tensorflow, it will create a bool tensor For torch and tensorflow, it will create a bool tensor
...@@ -310,12 +368,14 @@ def generate_mask_tensor(mask): ...@@ -310,12 +368,14 @@ def generate_mask_tensor(mask):
mask: numpy ndarray mask: numpy ndarray
input mask tensor input mask tensor
""" """
assert isinstance(mask, np.ndarray), "input for generate_mask_tensor" \ assert isinstance(mask, np.ndarray), (
"should be an numpy ndarray" "input for generate_mask_tensor" "should be an numpy ndarray"
if F.backend_name == 'mxnet': )
return F.tensor(mask, dtype=F.data_type_dict['float32']) if F.backend_name == "mxnet":
return F.tensor(mask, dtype=F.data_type_dict["float32"])
else: else:
return F.tensor(mask, dtype=F.data_type_dict['bool']) return F.tensor(mask, dtype=F.data_type_dict["bool"])
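A quick sketch of how idx2mask and generate_mask_tensor above combine to turn an index list into a boolean mask (the printed value assumes a PyTorch backend):

import numpy as np
from dgl.data.utils import generate_mask_tensor, idx2mask

train_idx = np.array([0, 2, 5])
mask = generate_mask_tensor(idx2mask(train_idx, 8))
# -> tensor([ True, False,  True, False, False,  True, False, False])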
class Subset(object): class Subset(object):
"""Subset of a dataset at specified indices """Subset of a dataset at specified indices
...@@ -354,6 +414,7 @@ class Subset(object): ...@@ -354,6 +414,7 @@ class Subset(object):
""" """
return len(self.indices) return len(self.indices)
def add_nodepred_split(dataset, ratio, ntype=None): def add_nodepred_split(dataset, ratio, ntype=None):
"""Split the given dataset into training, validation and test sets for """Split the given dataset into training, validation and test sets for
transductive node prediction task. transductive node prediction task.
...@@ -384,16 +445,24 @@ def add_nodepred_split(dataset, ratio, ntype=None): ...@@ -384,16 +445,24 @@ def add_nodepred_split(dataset, ratio, ntype=None):
True True
""" """
if len(ratio) != 3: if len(ratio) != 3:
raise ValueError(f'Split ratio must be a float triplet but got {ratio}.') raise ValueError(
f"Split ratio must be a float triplet but got {ratio}."
)
for i in range(len(dataset)): for i in range(len(dataset)):
g = dataset[i] g = dataset[i]
n = g.num_nodes(ntype) n = g.num_nodes(ntype)
idx = np.arange(0, n) idx = np.arange(0, n)
np.random.shuffle(idx) np.random.shuffle(idx)
n_train, n_val, n_test = int(n * ratio[0]), int(n * ratio[1]), int(n * ratio[2]) n_train, n_val, n_test = (
int(n * ratio[0]),
int(n * ratio[1]),
int(n * ratio[2]),
)
train_mask = generate_mask_tensor(idx2mask(idx[:n_train], n)) train_mask = generate_mask_tensor(idx2mask(idx[:n_train], n))
val_mask = generate_mask_tensor(idx2mask(idx[n_train:n_train + n_val], n)) val_mask = generate_mask_tensor(
test_mask = generate_mask_tensor(idx2mask(idx[n_train + n_val:], n)) idx2mask(idx[n_train : n_train + n_val], n)
g.nodes[ntype].data['train_mask'] = train_mask )
g.nodes[ntype].data['val_mask'] = val_mask test_mask = generate_mask_tensor(idx2mask(idx[n_train + n_val :], n))
g.nodes[ntype].data['test_mask'] = test_mask g.nodes[ntype].data["train_mask"] = train_mask
g.nodes[ntype].data["val_mask"] = val_mask
g.nodes[ntype].data["test_mask"] = test_mask
"""Wiki-CS Dataset""" """Wiki-CS Dataset"""
import itertools import itertools
import os
import json import json
import os
import numpy as np import numpy as np
from .. import backend as F from .. import backend as F
from ..convert import graph from ..convert import graph
from ..transforms import reorder_graph, to_bidirected
from .dgl_dataset import DGLBuiltinDataset from .dgl_dataset import DGLBuiltinDataset
from ..transforms import to_bidirected, reorder_graph from .utils import _get_dgl_url, generate_mask_tensor, load_graphs, save_graphs
from .utils import generate_mask_tensor, load_graphs, save_graphs, _get_dgl_url
class WikiCSDataset(DGLBuiltinDataset): class WikiCSDataset(DGLBuiltinDataset):
...@@ -73,55 +75,64 @@ class WikiCSDataset(DGLBuiltinDataset): ...@@ -73,55 +75,64 @@ class WikiCSDataset(DGLBuiltinDataset):
(11701,) (11701,)
""" """
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None): def __init__(
_url = _get_dgl_url('dataset/wiki_cs.zip') self, raw_dir=None, force_reload=False, verbose=False, transform=None
super(WikiCSDataset, self).__init__(name='wiki_cs', ):
raw_dir=raw_dir, _url = _get_dgl_url("dataset/wiki_cs.zip")
url=_url, super(WikiCSDataset, self).__init__(
force_reload=force_reload, name="wiki_cs",
verbose=verbose, raw_dir=raw_dir,
transform=transform) url=_url,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
def process(self): def process(self):
"""process raw data to graph, labels and masks""" """process raw data to graph, labels and masks"""
with open(os.path.join(self.raw_path, 'data.json')) as f: with open(os.path.join(self.raw_path, "data.json")) as f:
data = json.load(f) data = json.load(f)
features = F.tensor(np.array(data['features']), dtype=F.float32) features = F.tensor(np.array(data["features"]), dtype=F.float32)
labels = F.tensor(np.array(data['labels']), dtype=F.int64) labels = F.tensor(np.array(data["labels"]), dtype=F.int64)
train_masks = np.array(data['train_masks'], dtype=bool).T train_masks = np.array(data["train_masks"], dtype=bool).T
val_masks = np.array(data['val_masks'], dtype=bool).T val_masks = np.array(data["val_masks"], dtype=bool).T
stopping_masks = np.array(data['stopping_masks'], dtype=bool).T stopping_masks = np.array(data["stopping_masks"], dtype=bool).T
test_mask = np.array(data['test_mask'], dtype=bool) test_mask = np.array(data["test_mask"], dtype=bool)
edges = [[(i, j) for j in js] for i, js in enumerate(data['links'])] edges = [[(i, j) for j in js] for i, js in enumerate(data["links"])]
edges = np.array(list(itertools.chain(*edges))) edges = np.array(list(itertools.chain(*edges)))
src, dst = edges[:, 0], edges[:, 1] src, dst = edges[:, 0], edges[:, 1]
g = graph((src, dst)) g = graph((src, dst))
g = to_bidirected(g) g = to_bidirected(g)
g.ndata['feat'] = features g.ndata["feat"] = features
g.ndata['label'] = labels g.ndata["label"] = labels
g.ndata['train_mask'] = generate_mask_tensor(train_masks) g.ndata["train_mask"] = generate_mask_tensor(train_masks)
g.ndata['val_mask'] = generate_mask_tensor(val_masks) g.ndata["val_mask"] = generate_mask_tensor(val_masks)
g.ndata['stopping_mask'] = generate_mask_tensor(stopping_masks) g.ndata["stopping_mask"] = generate_mask_tensor(stopping_masks)
g.ndata['test_mask'] = generate_mask_tensor(test_mask) g.ndata["test_mask"] = generate_mask_tensor(test_mask)
g = reorder_graph(g, node_permute_algo='rcmk', edge_permute_algo='dst', store_ids=False) g = reorder_graph(
g,
node_permute_algo="rcmk",
edge_permute_algo="dst",
store_ids=False,
)
self._graph = g self._graph = g
def has_cache(self): def has_cache(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin') graph_path = os.path.join(self.save_path, "dgl_graph.bin")
return os.path.exists(graph_path) return os.path.exists(graph_path)
def save(self): def save(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin') graph_path = os.path.join(self.save_path, "dgl_graph.bin")
save_graphs(graph_path, self._graph) save_graphs(graph_path, self._graph)
def load(self): def load(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin') graph_path = os.path.join(self.save_path, "dgl_graph.bin")
g, _ = load_graphs(graph_path) g, _ = load_graphs(graph_path)
self._graph = g[0] self._graph = g[0]
...@@ -134,7 +145,7 @@ class WikiCSDataset(DGLBuiltinDataset): ...@@ -134,7 +145,7 @@ class WikiCSDataset(DGLBuiltinDataset):
return 1 return 1
def __getitem__(self, idx): def __getitem__(self, idx):
r""" Get graph object r"""Get graph object
Parameters Parameters
---------- ----------
......
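For context, a hedged loading sketch for the WikiCSDataset class above (shapes follow the dataset's published statistics and are not taken from this diff):

from dgl.data import WikiCSDataset

dataset = WikiCSDataset()
g = dataset[0]
feat = g.ndata["feat"]              # node features, roughly (11701, 300)
train_mask = g.ndata["train_mask"]  # 2-D: one column per official split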
"""Yelp Dataset""" """Yelp Dataset"""
import os
import json import json
import os
import numpy as np import numpy as np
import scipy.sparse as sp import scipy.sparse as sp
from .. import backend as F from .. import backend as F
from ..convert import from_scipy from ..convert import from_scipy
from ..transforms import reorder_graph from ..transforms import reorder_graph
from .dgl_dataset import DGLBuiltinDataset from .dgl_dataset import DGLBuiltinDataset
from .utils import generate_mask_tensor, load_graphs, save_graphs, _get_dgl_url from .utils import _get_dgl_url, generate_mask_tensor, load_graphs, save_graphs
class YelpDataset(DGLBuiltinDataset): class YelpDataset(DGLBuiltinDataset):
r"""Yelp dataset for node classification from `GraphSAINT: Graph Sampling Based Inductive r"""Yelp dataset for node classification from `GraphSAINT: Graph Sampling Based Inductive
Learning Method <https://arxiv.org/abs/1907.04931>`_ Learning Method <https://arxiv.org/abs/1907.04931>`_
The task of this dataset is categorizing types of businesses based on customer reviews and The task of this dataset is categorizing types of businesses based on customer reviews and
friendship. friendship.
Yelp dataset statistics: Yelp dataset statistics:
- Nodes: 716,847 - Nodes: 716,847
- Edges: 13,954,819 - Edges: 13,954,819
- Number of classes: 100 (Multi-class) - Number of classes: 100 (Multi-class)
- Node feature size: 300 - Node feature size: 300
Parameters Parameters
---------- ----------
raw_dir : str raw_dir : str
Raw file directory to download/contains the input data directory. Raw file directory to download/contains the input data directory.
Default: ~/.dgl/ Default: ~/.dgl/
force_reload : bool force_reload : bool
Whether to reload the dataset. Whether to reload the dataset.
Default: False Default: False
verbose : bool verbose : bool
Whether to print out progress information. Whether to print out progress information.
Default: False Default: False
transform : callable, optional transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access. transformed before every access.
reorder : bool reorder : bool
Whether to reorder the graph using :func:`~dgl.reorder_graph`. Whether to reorder the graph using :func:`~dgl.reorder_graph`.
Default: False. Default: False.
Attributes Attributes
---------- ----------
num_classes : int num_classes : int
Number of node classes Number of node classes
Examples Examples
-------- --------
>>> dataset = YelpDataset() >>> dataset = YelpDataset()
>>> dataset.num_classes >>> dataset.num_classes
100 100
>>> g = dataset[0] >>> g = dataset[0]
>>> # get node feature >>> # get node feature
>>> feat = g.ndata['feat'] >>> feat = g.ndata['feat']
>>> # get node labels >>> # get node labels
>>> labels = g.ndata['label'] >>> labels = g.ndata['label']
>>> # get data split >>> # get data split
>>> train_mask = g.ndata['train_mask'] >>> train_mask = g.ndata['train_mask']
>>> val_mask = g.ndata['val_mask'] >>> val_mask = g.ndata['val_mask']
>>> test_mask = g.ndata['test_mask'] >>> test_mask = g.ndata['test_mask']
""" """
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None, def __init__(
reorder=False): self,
_url = _get_dgl_url('dataset/yelp.zip') raw_dir=None,
force_reload=False,
verbose=False,
transform=None,
reorder=False,
):
_url = _get_dgl_url("dataset/yelp.zip")
self._reorder = reorder self._reorder = reorder
super(YelpDataset, self).__init__(name='yelp', super(YelpDataset, self).__init__(
raw_dir=raw_dir, name="yelp",
url=_url, raw_dir=raw_dir,
force_reload=force_reload, url=_url,
verbose=verbose, force_reload=force_reload,
transform=transform) verbose=verbose,
transform=transform,
)
def process(self): def process(self):
"""process raw data to graph, labels and masks""" """process raw data to graph, labels and masks"""
coo_adj = sp.load_npz(os.path.join(self.raw_path, "adj_full.npz")) coo_adj = sp.load_npz(os.path.join(self.raw_path, "adj_full.npz"))
g = from_scipy(coo_adj) g = from_scipy(coo_adj)
features = np.load(os.path.join(self.raw_path, 'feats.npy')) features = np.load(os.path.join(self.raw_path, "feats.npy"))
features = F.tensor(features, dtype=F.float32) features = F.tensor(features, dtype=F.float32)
y = [-1] * features.shape[0] y = [-1] * features.shape[0]
with open(os.path.join(self.raw_path, 'class_map.json')) as f: with open(os.path.join(self.raw_path, "class_map.json")) as f:
class_map = json.load(f) class_map = json.load(f)
for key, item in class_map.items(): for key, item in class_map.items():
y[int(key)] = item y[int(key)] = item
labels = F.tensor(np.array(y), dtype=F.int64) labels = F.tensor(np.array(y), dtype=F.int64)
with open(os.path.join(self.raw_path, 'role.json')) as f: with open(os.path.join(self.raw_path, "role.json")) as f:
role = json.load(f) role = json.load(f)
train_mask = np.zeros(features.shape[0], dtype=bool) train_mask = np.zeros(features.shape[0], dtype=bool)
train_mask[role['tr']] = True train_mask[role["tr"]] = True
val_mask = np.zeros(features.shape[0], dtype=bool) val_mask = np.zeros(features.shape[0], dtype=bool)
val_mask[role['va']] = True val_mask[role["va"]] = True
test_mask = np.zeros(features.shape[0], dtype=bool) test_mask = np.zeros(features.shape[0], dtype=bool)
test_mask[role['te']] = True test_mask[role["te"]] = True
g.ndata['feat'] = features g.ndata["feat"] = features
g.ndata['label'] = labels g.ndata["label"] = labels
g.ndata['train_mask'] = generate_mask_tensor(train_mask) g.ndata["train_mask"] = generate_mask_tensor(train_mask)
g.ndata['val_mask'] = generate_mask_tensor(val_mask) g.ndata["val_mask"] = generate_mask_tensor(val_mask)
g.ndata['test_mask'] = generate_mask_tensor(test_mask) g.ndata["test_mask"] = generate_mask_tensor(test_mask)
if self._reorder: if self._reorder:
self._graph = reorder_graph( self._graph = reorder_graph(
g, node_permute_algo='rcmk', edge_permute_algo='dst', store_ids=False) g,
node_permute_algo="rcmk",
edge_permute_algo="dst",
store_ids=False,
)
else: else:
self._graph = g self._graph = g
def has_cache(self): def has_cache(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin') graph_path = os.path.join(self.save_path, "dgl_graph.bin")
return os.path.exists(graph_path) return os.path.exists(graph_path)
def save(self): def save(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin') graph_path = os.path.join(self.save_path, "dgl_graph.bin")
save_graphs(graph_path, self._graph) save_graphs(graph_path, self._graph)
def load(self): def load(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin') graph_path = os.path.join(self.save_path, "dgl_graph.bin")
g, _ = load_graphs(graph_path) g, _ = load_graphs(graph_path)
self._graph = g[0] self._graph = g[0]
...@@ -136,7 +150,7 @@ class YelpDataset(DGLBuiltinDataset): ...@@ -136,7 +150,7 @@ class YelpDataset(DGLBuiltinDataset):
return 1 return 1
def __getitem__(self, idx): def __getitem__(self, idx):
r""" Get graph object r"""Get graph object
Parameters Parameters
---------- ----------
......
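Similarly, a minimal sketch of consuming YelpDataset as reformatted above; passing reorder=True exercises the RCMK reordering branch in process():

from dgl.data import YelpDataset

dataset = YelpDataset(reorder=True)
g = dataset[0]
feat = g.ndata["feat"]     # 300-dimensional node features
labels = g.ndata["label"]  # 100-way multi-label targets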
"""Package for dataloaders and samplers.""" """Package for dataloaders and samplers."""
from .. import backend as F from .. import backend as F
from .neighbor_sampler import * from . import negative_sampler
from .base import *
from .cluster_gcn import * from .cluster_gcn import *
from .graphsaint import * from .graphsaint import *
from .neighbor_sampler import *
from .shadow import * from .shadow import *
from .base import *
from . import negative_sampler if F.get_preferred_backend() == "pytorch":
if F.get_preferred_backend() == 'pytorch':
from .dataloader import * from .dataloader import *
from .dist_dataloader import * from .dist_dataloader import *
"""Cluster-GCN samplers.""" """Cluster-GCN samplers."""
import os import os
import pickle import pickle
import numpy as np import numpy as np
from .. import backend as F from .. import backend as F
from ..base import DGLError from ..base import DGLError
from ..partition import metis_partition_assignment from ..partition import metis_partition_assignment
from .base import set_node_lazy_features, set_edge_lazy_features, Sampler from .base import Sampler, set_edge_lazy_features, set_node_lazy_features
class ClusterGCNSampler(Sampler): class ClusterGCNSampler(Sampler):
"""Cluster sampler from `Cluster-GCN: An Efficient Algorithm for Training """Cluster sampler from `Cluster-GCN: An Efficient Algorithm for Training
...@@ -59,35 +61,60 @@ class ClusterGCNSampler(Sampler): ...@@ -59,35 +61,60 @@ class ClusterGCNSampler(Sampler):
>>> for subg in dataloader: >>> for subg in dataloader:
... train_on(subg) ... train_on(subg)
""" """
def __init__(self, g, k, cache_path='cluster_gcn.pkl', balance_ntypes=None,
balance_edges=False, mode='k-way', prefetch_ndata=None, def __init__(
prefetch_edata=None, output_device=None): self,
g,
k,
cache_path="cluster_gcn.pkl",
balance_ntypes=None,
balance_edges=False,
mode="k-way",
prefetch_ndata=None,
prefetch_edata=None,
output_device=None,
):
super().__init__() super().__init__()
if os.path.exists(cache_path): if os.path.exists(cache_path):
try: try:
with open(cache_path, 'rb') as f: with open(cache_path, "rb") as f:
self.partition_offset, self.partition_node_ids = pickle.load(f) (
self.partition_offset,
self.partition_node_ids,
) = pickle.load(f)
except (EOFError, TypeError, ValueError): except (EOFError, TypeError, ValueError):
raise DGLError( raise DGLError(
f'The contents in the cache file {cache_path} is invalid. ' f"The contents in the cache file {cache_path} is invalid. "
f'Please remove the cache file {cache_path} or specify another path.') f"Please remove the cache file {cache_path} or specify another path."
)
if len(self.partition_offset) != k + 1: if len(self.partition_offset) != k + 1:
raise DGLError( raise DGLError(
f'Number of partitions in the cache does not match the value of k. ' f"Number of partitions in the cache does not match the value of k. "
f'Please remove the cache file {cache_path} or specify another path.') f"Please remove the cache file {cache_path} or specify another path."
)
if len(self.partition_node_ids) != g.num_nodes(): if len(self.partition_node_ids) != g.num_nodes():
raise DGLError( raise DGLError(
f'Number of nodes in the cache does not match the given graph. ' f"Number of nodes in the cache does not match the given graph. "
f'Please remove the cache file {cache_path} or specify another path.') f"Please remove the cache file {cache_path} or specify another path."
)
else: else:
partition_ids = metis_partition_assignment( partition_ids = metis_partition_assignment(
g, k, balance_ntypes=balance_ntypes, balance_edges=balance_edges, mode=mode) g,
k,
balance_ntypes=balance_ntypes,
balance_edges=balance_edges,
mode=mode,
)
partition_ids = F.asnumpy(partition_ids) partition_ids = F.asnumpy(partition_ids)
partition_node_ids = np.argsort(partition_ids) partition_node_ids = np.argsort(partition_ids)
partition_size = F.zerocopy_from_numpy(np.bincount(partition_ids, minlength=k)) partition_size = F.zerocopy_from_numpy(
partition_offset = F.zerocopy_from_numpy(np.insert(np.cumsum(partition_size), 0, 0)) np.bincount(partition_ids, minlength=k)
)
partition_offset = F.zerocopy_from_numpy(
np.insert(np.cumsum(partition_size), 0, 0)
)
partition_node_ids = F.zerocopy_from_numpy(partition_node_ids) partition_node_ids = F.zerocopy_from_numpy(partition_node_ids)
with open(cache_path, 'wb') as f: with open(cache_path, "wb") as f:
pickle.dump((partition_offset, partition_node_ids), f) pickle.dump((partition_offset, partition_node_ids), f)
self.partition_offset = partition_offset self.partition_offset = partition_offset
self.partition_node_ids = partition_node_ids self.partition_node_ids = partition_node_ids
...@@ -96,7 +123,7 @@ class ClusterGCNSampler(Sampler): ...@@ -96,7 +123,7 @@ class ClusterGCNSampler(Sampler):
self.prefetch_edata = prefetch_edata or [] self.prefetch_edata = prefetch_edata or []
self.output_device = output_device self.output_device = output_device
def sample(self, g, partition_ids): # pylint: disable=arguments-differ def sample(self, g, partition_ids): # pylint: disable=arguments-differ
"""Sampling function. """Sampling function.
Parameters Parameters
...@@ -111,10 +138,18 @@ class ClusterGCNSampler(Sampler): ...@@ -111,10 +138,18 @@ class ClusterGCNSampler(Sampler):
DGLGraph DGLGraph
The sampled subgraph. The sampled subgraph.
""" """
node_ids = F.cat([ node_ids = F.cat(
self.partition_node_ids[self.partition_offset[i]:self.partition_offset[i+1]] [
for i in F.asnumpy(partition_ids)], 0) self.partition_node_ids[
sg = g.subgraph(node_ids, relabel_nodes=True, output_device=self.output_device) self.partition_offset[i] : self.partition_offset[i + 1]
]
for i in F.asnumpy(partition_ids)
],
0,
)
sg = g.subgraph(
node_ids, relabel_nodes=True, output_device=self.output_device
)
set_node_lazy_features(sg, self.prefetch_ndata) set_node_lazy_features(sg, self.prefetch_ndata)
set_edge_lazy_features(sg, self.prefetch_edata) set_edge_lazy_features(sg, self.prefetch_edata)
return sg return sg
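A hedged end-to-end sketch of ClusterGCNSampler with dgl.dataloading.DataLoader, mirroring the class docstring; the toy graph, feature name, and partition count are illustrative only:

import torch
import dgl
from dgl.dataloading import ClusterGCNSampler, DataLoader

g = dgl.rand_graph(1000, 5000)
g.ndata["feat"] = torch.randn(1000, 16)
num_parts = 20
# writes a partition cache to ./cluster_gcn.pkl by default
sampler = ClusterGCNSampler(g, num_parts, prefetch_ndata=["feat"])
dataloader = DataLoader(
    g, torch.arange(num_parts), sampler, batch_size=4, shuffle=True
)
for subg in dataloader:
    pass  # train_on(subg)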
"""GraphSAINT samplers.""" """GraphSAINT samplers."""
from ..base import DGLError from ..base import DGLError
from ..random import choice from ..random import choice
from ..sampling import random_walk, pack_traces from ..sampling import pack_traces, random_walk
from .base import set_node_lazy_features, set_edge_lazy_features, Sampler from .base import Sampler, set_edge_lazy_features, set_node_lazy_features
try: try:
import torch import torch
except ImportError: except ImportError:
pass pass
class SAINTSampler(Sampler): class SAINTSampler(Sampler):
"""Random node/edge/walk sampler from """Random node/edge/walk sampler from
`GraphSAINT: Graph Sampling Based Inductive Learning Method `GraphSAINT: Graph Sampling Based Inductive Learning Method
...@@ -65,18 +66,28 @@ class SAINTSampler(Sampler): ...@@ -65,18 +66,28 @@ class SAINTSampler(Sampler):
>>> for subg in dataloader: >>> for subg in dataloader:
... train_on(subg) ... train_on(subg)
""" """
def __init__(self, mode, budget, cache=True, prefetch_ndata=None,
prefetch_edata=None, output_device='cpu'): def __init__(
self,
mode,
budget,
cache=True,
prefetch_ndata=None,
prefetch_edata=None,
output_device="cpu",
):
super().__init__() super().__init__()
self.budget = budget self.budget = budget
if mode == 'node': if mode == "node":
self.sampler = self.node_sampler self.sampler = self.node_sampler
elif mode == 'edge': elif mode == "edge":
self.sampler = self.edge_sampler self.sampler = self.edge_sampler
elif mode == 'walk': elif mode == "walk":
self.sampler = self.walk_sampler self.sampler = self.walk_sampler
else: else:
raise DGLError(f"Expect mode to be 'node', 'edge' or 'walk', got {mode}.") raise DGLError(
f"Expect mode to be 'node', 'edge' or 'walk', got {mode}."
)
self.cache = cache self.cache = cache
self.prob = None self.prob = None
...@@ -95,8 +106,11 @@ class SAINTSampler(Sampler): ...@@ -95,8 +106,11 @@ class SAINTSampler(Sampler):
prob = g.out_degrees().float().clamp(min=1) prob = g.out_degrees().float().clamp(min=1)
if self.cache: if self.cache:
self.prob = prob self.prob = prob
return torch.multinomial(prob, num_samples=self.budget, return (
replacement=True).unique().type(g.idtype) torch.multinomial(prob, num_samples=self.budget, replacement=True)
.unique()
.type(g.idtype)
)
def edge_sampler(self, g): def edge_sampler(self, g):
"""Node ID sampler for random edge sampler""" """Node ID sampler for random edge sampler"""
...@@ -107,11 +121,13 @@ class SAINTSampler(Sampler): ...@@ -107,11 +121,13 @@ class SAINTSampler(Sampler):
in_deg = g.in_degrees().float().clamp(min=1) in_deg = g.in_degrees().float().clamp(min=1)
out_deg = g.out_degrees().float().clamp(min=1) out_deg = g.out_degrees().float().clamp(min=1)
# We can reduce the sample space by half if graphs are always symmetric. # We can reduce the sample space by half if graphs are always symmetric.
prob = 1. / in_deg[dst.long()] + 1. / out_deg[src.long()] prob = 1.0 / in_deg[dst.long()] + 1.0 / out_deg[src.long()]
prob /= prob.sum() prob /= prob.sum()
if self.cache: if self.cache:
self.prob = prob self.prob = prob
sampled_edges = torch.unique(choice(len(prob), size=self.budget, prob=prob)) sampled_edges = torch.unique(
choice(len(prob), size=self.budget, prob=prob)
)
sampled_nodes = torch.cat([src[sampled_edges], dst[sampled_edges]]) sampled_nodes = torch.cat([src[sampled_edges], dst[sampled_edges]])
return sampled_nodes.unique().type(g.idtype) return sampled_nodes.unique().type(g.idtype)
...@@ -139,7 +155,9 @@ class SAINTSampler(Sampler): ...@@ -139,7 +155,9 @@ class SAINTSampler(Sampler):
The sampled subgraph. The sampled subgraph.
""" """
node_ids = self.sampler(g) node_ids = self.sampler(g)
sg = g.subgraph(node_ids, relabel_nodes=True, output_device=self.output_device) sg = g.subgraph(
node_ids, relabel_nodes=True, output_device=self.output_device
)
set_node_lazy_features(sg, self.prefetch_ndata) set_node_lazy_features(sg, self.prefetch_ndata)
set_edge_lazy_features(sg, self.prefetch_edata) set_edge_lazy_features(sg, self.prefetch_edata)
return sg return sg
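And a corresponding sketch for SAINTSampler in node mode; the number of mini-batches (torch.arange(...)) and the budget are arbitrary values chosen for illustration:

import torch
import dgl
from dgl.dataloading import SAINTSampler, DataLoader

g = dgl.rand_graph(1000, 5000)
sampler = SAINTSampler(mode="node", budget=100)   # ~100 sampled nodes per subgraph
dataloader = DataLoader(g, torch.arange(50), sampler, num_workers=0)
for subg in dataloader:
    pass  # train_on(subg)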
"""Negative samplers""" """Negative samplers"""
from collections.abc import Mapping from collections.abc import Mapping
from .. import backend as F from .. import backend as F
class _BaseNegativeSampler(object): class _BaseNegativeSampler(object):
def _generate(self, g, eids, canonical_etype): def _generate(self, g, eids, canonical_etype):
raise NotImplementedError raise NotImplementedError
...@@ -25,12 +27,14 @@ class _BaseNegativeSampler(object): ...@@ -25,12 +27,14 @@ class _BaseNegativeSampler(object):
eids = {g.to_canonical_etype(k): v for k, v in eids.items()} eids = {g.to_canonical_etype(k): v for k, v in eids.items()}
neg_pair = {k: self._generate(g, v, k) for k, v in eids.items()} neg_pair = {k: self._generate(g, v, k) for k, v in eids.items()}
else: else:
assert len(g.canonical_etypes) == 1, \ assert (
'please specify a dict of etypes and ids for graphs with multiple edge types' len(g.canonical_etypes) == 1
), "please specify a dict of etypes and ids for graphs with multiple edge types"
neg_pair = self._generate(g, eids, g.canonical_etypes[0]) neg_pair = self._generate(g, eids, g.canonical_etypes[0])
return neg_pair return neg_pair
class PerSourceUniform(_BaseNegativeSampler): class PerSourceUniform(_BaseNegativeSampler):
"""Negative sampler that randomly chooses negative destination nodes """Negative sampler that randomly chooses negative destination nodes
for each source node according to a uniform distribution. for each source node according to a uniform distribution.
...@@ -52,6 +56,7 @@ class PerSourceUniform(_BaseNegativeSampler): ...@@ -52,6 +56,7 @@ class PerSourceUniform(_BaseNegativeSampler):
>>> neg_sampler(g, torch.tensor([0, 1])) >>> neg_sampler(g, torch.tensor([0, 1]))
(tensor([0, 0, 1, 1]), tensor([1, 0, 2, 3])) (tensor([0, 0, 1, 1]), tensor([1, 0, 2, 3]))
""" """
def __init__(self, k): def __init__(self, k):
self.k = k self.k = k
...@@ -66,9 +71,11 @@ class PerSourceUniform(_BaseNegativeSampler): ...@@ -66,9 +71,11 @@ class PerSourceUniform(_BaseNegativeSampler):
dst = F.randint(shape, dtype, ctx, 0, g.num_nodes(vtype)) dst = F.randint(shape, dtype, ctx, 0, g.num_nodes(vtype))
return src, dst return src, dst
# Alias # Alias
Uniform = PerSourceUniform Uniform = PerSourceUniform
class GlobalUniform(_BaseNegativeSampler): class GlobalUniform(_BaseNegativeSampler):
"""Negative sampler that randomly chooses negative source-destination pairs according """Negative sampler that randomly chooses negative source-destination pairs according
to a uniform distribution. to a uniform distribution.
...@@ -104,6 +111,7 @@ class GlobalUniform(_BaseNegativeSampler): ...@@ -104,6 +111,7 @@ class GlobalUniform(_BaseNegativeSampler):
>>> neg_sampler(g, torch.LongTensor([0, 1])) >>> neg_sampler(g, torch.LongTensor([0, 1]))
(tensor([0, 1, 3, 2]), tensor([2, 0, 2, 1])) (tensor([0, 1, 3, 2]), tensor([2, 0, 2, 1]))
""" """
def __init__(self, k, exclude_self_loops=True, replace=False): def __init__(self, k, exclude_self_loops=True, replace=False):
self.k = k self.k = k
self.exclude_self_loops = exclude_self_loops self.exclude_self_loops = exclude_self_loops
...@@ -111,4 +119,8 @@ class GlobalUniform(_BaseNegativeSampler): ...@@ -111,4 +119,8 @@ class GlobalUniform(_BaseNegativeSampler):
def _generate(self, g, eids, canonical_etype): def _generate(self, g, eids, canonical_etype):
return g.global_uniform_negative_sampling( return g.global_uniform_negative_sampling(
len(eids) * self.k, self.exclude_self_loops, self.replace, canonical_etype) len(eids) * self.k,
self.exclude_self_loops,
self.replace,
canonical_etype,
)
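A toy comparison of the two negative samplers above; the graph and edge IDs are made up for illustration:

import torch
import dgl
from dgl.dataloading.negative_sampler import GlobalUniform, PerSourceUniform

g = dgl.graph(([0, 1, 2], [1, 2, 3]))
per_src = PerSourceUniform(k=2)                     # 2 negatives per positive edge
glob = GlobalUniform(k=2, exclude_self_loops=True)
print(per_src(g, torch.tensor([0, 1])))             # (src, dst) negative pairs
print(glob(g, torch.tensor([0, 1])))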
""" """
This package contains DistGNN and Libra based graph partitioning tools. This package contains DistGNN and Libra based graph partitioning tools.
""" """
from . import partition from . import partition, tools
from . import tools
...@@ -18,17 +18,21 @@ from Xie et al. ...@@ -18,17 +18,21 @@ from Xie et al.
# Nesreen K. Ahmed <nesreen.k.ahmed@intel.com> # Nesreen K. Ahmed <nesreen.k.ahmed@intel.com>
# \cite Distributed Power-law Graph Computing: Theoretical and Empirical Analysis # \cite Distributed Power-law Graph Computing: Theoretical and Empirical Analysis
import json
import os import os
import time import time
import json
import torch as th import torch as th
from dgl import DGLGraph from dgl import DGLGraph
from dgl.sparse import libra_vertex_cut
from dgl.sparse import libra2dgl_build_dict
from dgl.sparse import libra2dgl_set_lr
from dgl.sparse import libra2dgl_build_adjlist
from dgl.data.utils import save_graphs, save_tensors
from dgl.base import DGLError from dgl.base import DGLError
from dgl.data.utils import save_graphs, save_tensors
from dgl.sparse import (
libra2dgl_build_adjlist,
libra2dgl_build_dict,
libra2dgl_set_lr,
libra_vertex_cut,
)
def libra_partition(num_community, G, resultdir): def libra_partition(num_community, G, resultdir):
...@@ -55,8 +59,8 @@ def libra_partition(num_community, G, resultdir): ...@@ -55,8 +59,8 @@ def libra_partition(num_community, G, resultdir):
3. The folder also contains a json file which contains partitions' information. 3. The folder also contains a json file which contains partitions' information.
""" """
num_nodes = G.number_of_nodes() # number of nodes num_nodes = G.number_of_nodes() # number of nodes
num_edges = G.number_of_edges() # number of edges num_edges = G.number_of_edges() # number of edges
print("Number of nodes in the graph: ", num_nodes) print("Number of nodes in the graph: ", num_nodes)
print("Number of edges in the graph: ", num_edges) print("Number of edges in the graph: ", num_edges)
...@@ -79,12 +83,23 @@ def libra_partition(num_community, G, resultdir): ...@@ -79,12 +83,23 @@ def libra_partition(num_community, G, resultdir):
## call to C/C++ code ## call to C/C++ code
out = th.zeros(u_t.shape[0], dtype=th.int32) out = th.zeros(u_t.shape[0], dtype=th.int32)
libra_vertex_cut(num_community, node_degree, edgenum_unassigned, community_weights, libra_vertex_cut(
u_t, v_t, weight_, out, num_nodes, num_edges, resultdir) num_community,
node_degree,
edgenum_unassigned,
community_weights,
u_t,
v_t,
weight_,
out,
num_nodes,
num_edges,
resultdir,
)
print("Max partition size: ", int(community_weights.max())) print("Max partition size: ", int(community_weights.max()))
print(" ** Converting libra partitions to dgl graphs **") print(" ** Converting libra partitions to dgl graphs **")
fsize = int(community_weights.max()) + 1024 ## max edges in partition fsize = int(community_weights.max()) + 1024 ## max edges in partition
# print("fsize: ", fsize, flush=True) # print("fsize: ", fsize, flush=True)
node_map = th.zeros(num_community, dtype=th.int64) node_map = th.zeros(num_community, dtype=th.int64)
...@@ -110,8 +125,20 @@ def libra_partition(num_community, G, resultdir): ...@@ -110,8 +125,20 @@ def libra_partition(num_community, G, resultdir):
## building node, partition dictionary ## building node, partition dictionary
## Assign local node ids and mapping to global node ids ## Assign local node ids and mapping to global node ids
ret = libra2dgl_build_dict(a_t, b_t, indices, ldt_key, gdt_key, gdt_value, ret = libra2dgl_build_dict(
node_map, offset, num_community, i, fsize, resultdir) a_t,
b_t,
indices,
ldt_key,
gdt_key,
gdt_value,
node_map,
offset,
num_community,
i,
fsize,
resultdir,
)
num_nodes_partition = int(ret[0]) num_nodes_partition = int(ret[0])
num_edges_partition = int(ret[1]) num_edges_partition = int(ret[1])
...@@ -123,23 +150,25 @@ def libra_partition(num_community, G, resultdir): ...@@ -123,23 +150,25 @@ def libra_partition(num_community, G, resultdir):
## fixing lr - 1-level tree for the split-nodes ## fixing lr - 1-level tree for the split-nodes
libra2dgl_set_lr(gdt_key, gdt_value, lrtensor, num_community, num_nodes) libra2dgl_set_lr(gdt_key, gdt_value, lrtensor, num_community, num_nodes)
######################################################## ########################################################
#graph_name = dataset # graph_name = dataset
graph_name = resultdir.split("_")[-1].split("/")[0] graph_name = resultdir.split("_")[-1].split("/")[0]
part_method = 'Libra' part_method = "Libra"
num_parts = num_community ## number of partitions/communities num_parts = num_community ## number of partitions/communities
num_hops = 0 num_hops = 0
node_map_val = node_map.tolist() node_map_val = node_map.tolist()
edge_map_val = 0 edge_map_val = 0
out_path = resultdir out_path = resultdir
part_metadata = {'graph_name': graph_name, part_metadata = {
'num_nodes': G.number_of_nodes(), "graph_name": graph_name,
'num_edges': G.number_of_edges(), "num_nodes": G.number_of_nodes(),
'part_method': part_method, "num_edges": G.number_of_edges(),
'num_parts': num_parts, "part_method": part_method,
'halo_hops': num_hops, "num_parts": num_parts,
'node_map': node_map_val, "halo_hops": num_hops,
'edge_map': edge_map_val} "node_map": node_map_val,
"edge_map": edge_map_val,
}
############################################################ ############################################################
for i in range(num_community): for i in range(num_community):
...@@ -151,18 +180,18 @@ def libra_partition(num_community, G, resultdir): ...@@ -151,18 +180,18 @@ def libra_partition(num_community, G, resultdir):
ldt = ldt_ar[0] ldt = ldt_ar[0]
try: try:
feat = G.ndata['feat'] feat = G.ndata["feat"]
except KeyError: except KeyError:
feat = G.ndata['features'] feat = G.ndata["features"]
try: try:
labels = G.ndata['label'] labels = G.ndata["label"]
except KeyError: except KeyError:
labels = G.ndata['labels'] labels = G.ndata["labels"]
trainm = G.ndata['train_mask'].int() trainm = G.ndata["train_mask"].int()
testm = G.ndata['test_mask'].int() testm = G.ndata["test_mask"].int()
valm = G.ndata['val_mask'].int() valm = G.ndata["val_mask"].int()
feat_size = feat.shape[1] feat_size = feat.shape[1]
gfeat = th.zeros([num_nodes_partition, feat_size], dtype=feat.dtype) gfeat = th.zeros([num_nodes_partition, feat_size], dtype=feat.dtype)
...@@ -174,21 +203,41 @@ def libra_partition(num_community, G, resultdir): ...@@ -174,21 +203,41 @@ def libra_partition(num_community, G, resultdir):
## build remote node databse per local node ## build remote node databse per local node
## gather feats, train, test, val, and labels for each partition ## gather feats, train, test, val, and labels for each partition
libra2dgl_build_adjlist(feat, gfeat, adj, inner_node, ldt, gdt_key, libra2dgl_build_adjlist(
gdt_value, node_map, lr_t, lrtensor, num_nodes_partition, feat,
num_community, i, feat_size, labels, trainm, testm, valm, gfeat,
glabels, gtrainm, gtestm, gvalm, feat.shape[0]) adj,
inner_node,
ldt,
g.ndata['adj'] = adj ## database of remote clones gdt_key,
g.ndata['inner_node'] = inner_node ## split node '0' else '1' gdt_value,
g.ndata['feat'] = gfeat ## gathered features node_map,
g.ndata['lf'] = lr_t ## 1-level tree among split nodes lr_t,
lrtensor,
g.ndata['label'] = glabels num_nodes_partition,
g.ndata['train_mask'] = gtrainm num_community,
g.ndata['test_mask'] = gtestm i,
g.ndata['val_mask'] = gvalm feat_size,
labels,
trainm,
testm,
valm,
glabels,
gtrainm,
gtestm,
gvalm,
feat.shape[0],
)
g.ndata["adj"] = adj ## database of remote clones
g.ndata["inner_node"] = inner_node ## split node '0' else '1'
g.ndata["feat"] = gfeat ## gathered features
g.ndata["lf"] = lr_t ## 1-level tree among split nodes
g.ndata["label"] = glabels
g.ndata["train_mask"] = gtrainm
g.ndata["test_mask"] = gtestm
g.ndata["val_mask"] = gvalm
# Validation code, run only small graphs # Validation code, run only small graphs
# for l in range(num_nodes_partition): # for l in range(num_nodes_partition):
...@@ -207,9 +256,11 @@ def libra_partition(num_community, G, resultdir): ...@@ -207,9 +256,11 @@ def libra_partition(num_community, G, resultdir):
node_feat_file = os.path.join(part_dir, "node_feat.dgl") node_feat_file = os.path.join(part_dir, "node_feat.dgl")
edge_feat_file = os.path.join(part_dir, "edge_feat.dgl") edge_feat_file = os.path.join(part_dir, "edge_feat.dgl")
part_graph_file = os.path.join(part_dir, "graph.dgl") part_graph_file = os.path.join(part_dir, "graph.dgl")
part_metadata['part-{}'.format(part_id)] = {'node_feats': node_feat_file, part_metadata["part-{}".format(part_id)] = {
'edge_feats': edge_feat_file, "node_feats": node_feat_file,
'part_graph': part_graph_file} "edge_feats": edge_feat_file,
"part_graph": part_graph_file,
}
os.makedirs(part_dir, mode=0o775, exist_ok=True) os.makedirs(part_dir, mode=0o775, exist_ok=True)
save_tensors(node_feat_file, part.ndata) save_tensors(node_feat_file, part.ndata)
save_graphs(part_graph_file, [part]) save_graphs(part_graph_file, [part])
...@@ -219,7 +270,7 @@ def libra_partition(num_community, G, resultdir): ...@@ -219,7 +270,7 @@ def libra_partition(num_community, G, resultdir):
del ldt del ldt
del ldt_ar[0] del ldt_ar[0]
with open('{}/{}.json'.format(out_path, graph_name), 'w') as outfile: with open("{}/{}.json".format(out_path, graph_name), "w") as outfile:
json.dump(part_metadata, outfile, sort_keys=True, indent=4) json.dump(part_metadata, outfile, sort_keys=True, indent=4)
print("Conversion libra2dgl completed !!!") print("Conversion libra2dgl completed !!!")
...@@ -270,7 +321,9 @@ def partition_graph(num_community, G, resultdir): ...@@ -270,7 +321,9 @@ def partition_graph(num_community, G, resultdir):
raise DGLError("Error: Could not create directory: ", resultdir) raise DGLError("Error: Could not create directory: ", resultdir)
tic = time.time() tic = time.time()
print("####################################################################") print(
"####################################################################"
)
print("Executing parititons: ", num_community) print("Executing parititons: ", num_community)
ltic = time.time() ltic = time.time()
try: try:
...@@ -283,9 +336,18 @@ def partition_graph(num_community, G, resultdir): ...@@ -283,9 +336,18 @@ def partition_graph(num_community, G, resultdir):
libra_partition(num_community, G, resultdir) libra_partition(num_community, G, resultdir)
ltoc = time.time() ltoc = time.time()
print("Time taken by {} partitions {:0.4f} sec".format(num_community, ltoc - ltic)) print(
"Time taken by {} partitions {:0.4f} sec".format(
num_community, ltoc - ltic
)
)
print() print()
toc = time.time() toc = time.time()
print("Generated ", num_community, " partitions in {:0.4f} sec".format(toc - tic), flush=True) print(
"Generated ",
num_community,
" partitions in {:0.4f} sec".format(toc - tic),
flush=True,
)
print("Partitioning completed successfully !!!") print("Partitioning completed successfully !!!")
...@@ -7,13 +7,16 @@ Copyright (c) 2021 Intel Corporation ...@@ -7,13 +7,16 @@ Copyright (c) 2021 Intel Corporation
import os import os
import random import random
import requests import requests
from scipy.io import mmread
import torch as th import torch as th
from scipy.io import mmread
import dgl import dgl
from dgl.base import DGLError from dgl.base import DGLError
from dgl.data.utils import load_graphs, save_graphs, save_tensors from dgl.data.utils import load_graphs, save_graphs, save_tensors
def rep_per_node(prefix, num_community): def rep_per_node(prefix, num_community):
""" """
Used on Libra partitioned data. Used on Libra partitioned data.
...@@ -24,17 +27,17 @@ def rep_per_node(prefix, num_community): ...@@ -24,17 +27,17 @@ def rep_per_node(prefix, num_community):
prefix: Partition folder location (contains replicationlist.csv) prefix: Partition folder location (contains replicationlist.csv)
num_community: number of partitions or communities num_community: number of partitions or communities
""" """
ifile = os.path.join(prefix, 'replicationlist.csv') ifile = os.path.join(prefix, "replicationlist.csv")
fhandle = open(ifile, "r") fhandle = open(ifile, "r")
r_dt = {} r_dt = {}
fline = fhandle.readline() ## reading first line, contains the comment. fline = fhandle.readline() ## reading first line, contains the comment.
print(fline) print(fline)
for line in fhandle: for line in fhandle:
if line[0] == '#': if line[0] == "#":
raise DGLError("[Bug] Read Hash char in rep_per_node func.") raise DGLError("[Bug] Read Hash char in rep_per_node func.")
node = line.strip('\n') node = line.strip("\n")
if r_dt.get(node, -100) == -100: if r_dt.get(node, -100) == -100:
r_dt[node] = 1 r_dt[node] = 1
else: else:
...@@ -44,7 +47,9 @@ def rep_per_node(prefix, num_community): ...@@ -44,7 +47,9 @@ def rep_per_node(prefix, num_community):
## sanity checks ## sanity checks
for v in r_dt.values(): for v in r_dt.values():
if v >= num_community: if v >= num_community:
raise DGLError("[Bug] Unexpected event in rep_per_node() in tools.py.") raise DGLError(
"[Bug] Unexpected event in rep_per_node() in tools.py."
)
return r_dt return r_dt
...@@ -61,7 +66,9 @@ def download_proteins(): ...@@ -61,7 +66,9 @@ def download_proteins():
try: try:
req = requests.get(url) req = requests.get(url)
except: except:
raise DGLError("Error: Failed to download Proteins dataset!! Aborting..") raise DGLError(
"Error: Failed to download Proteins dataset!! Aborting.."
)
with open("proteins.mtx", "wb") as handle: with open("proteins.mtx", "wb") as handle:
handle.write(req.content) handle.write(req.content)
...@@ -73,7 +80,7 @@ def proteins_mtx2dgl(): ...@@ -73,7 +80,7 @@ def proteins_mtx2dgl():
""" """
print("Converting mtx2dgl..") print("Converting mtx2dgl..")
print("This might a take while..") print("This might a take while..")
a_mtx = mmread('proteins.mtx') a_mtx = mmread("proteins.mtx")
coo = a_mtx.tocoo() coo = a_mtx.tocoo()
u = th.tensor(coo.row, dtype=th.int64) u = th.tensor(coo.row, dtype=th.int64)
v = th.tensor(coo.col, dtype=th.int64) v = th.tensor(coo.col, dtype=th.int64)
...@@ -82,7 +89,7 @@ def proteins_mtx2dgl(): ...@@ -82,7 +89,7 @@ def proteins_mtx2dgl():
g.add_edges(u, v) g.add_edges(u, v)
n = g.number_of_nodes() n = g.number_of_nodes()
feat_size = 128 ## arbitrary number feat_size = 128 ## arbitrary number
feats = th.empty([n, feat_size], dtype=th.float32) feats = th.empty([n, feat_size], dtype=th.float32)
## arbitrary numbers ## arbitrary numbers
...@@ -108,11 +115,11 @@ def proteins_mtx2dgl(): ...@@ -108,11 +115,11 @@ def proteins_mtx2dgl():
for i in range(n): for i in range(n):
label[i] = random.choice(range(nlabels)) label[i] = random.choice(range(nlabels))
g.ndata['feat'] = feats g.ndata["feat"] = feats
g.ndata['train_mask'] = train_mask g.ndata["train_mask"] = train_mask
g.ndata['test_mask'] = test_mask g.ndata["test_mask"] = test_mask
g.ndata['val_mask'] = val_mask g.ndata["val_mask"] = val_mask
g.ndata['label'] = label g.ndata["label"] = label
return g return g
......