Unverified Commit a208e886 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Black auto fix. (#4680)



* [Misc] Black auto fix.

* fix pylint disable
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 29434e65
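
The diff below is the output of the Black auto-formatter over the dataset and serialization modules: single quotes become double quotes, long call signatures are exploded one argument per line with trailing commas, and long expressions are wrapped in parentheses. A minimal sketch of reproducing such a fix locally, assuming Black is installed; the sample source and default line length here are illustrative assumptions, the project's actual setting would live in its pyproject.toml:

    # Hypothetical reproduction of a Black auto-fix; not taken from this PR.
    import black

    source = "x = {'a':1, 'b':2}\n"
    # format_str applies the same rewrite the CLI performs on files.
    formatted = black.format_str(source, mode=black.FileMode())
    print(formatted)  # x = {"a": 1, "b": 2}

The same result comes from the command-line entry point (for example, running black over python/dgl/), which is presumably how this commit was generated.
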
......@@ -6,14 +6,22 @@ https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip
"""
import os
import numpy as np
from .. import backend as F
from .dgl_dataset import DGLBuiltinDataset
from .utils import loadtxt, save_graphs, load_graphs, save_info, load_info, download, extract_archive
from ..utils import retry_method_with_fix
from ..convert import graph as dgl_graph
from ..utils import retry_method_with_fix
from .dgl_dataset import DGLBuiltinDataset
from .utils import (
download,
extract_archive,
load_graphs,
load_info,
loadtxt,
save_graphs,
save_info,
)
class GINDataset(DGLBuiltinDataset):
......@@ -81,12 +89,20 @@ class GINDataset(DGLBuiltinDataset):
edata_schemes={})
"""
def __init__(self, name, self_loop, degree_as_nlabel=False,
raw_dir=None, force_reload=False, verbose=False, transform=None):
def __init__(
self,
name,
self_loop,
degree_as_nlabel=False,
raw_dir=None,
force_reload=False,
verbose=False,
transform=None,
):
self._name = name # MUTAG
gin_url = 'https://raw.githubusercontent.com/weihua916/powerful-gnns/master/dataset.zip'
self.ds_name = 'nig'
gin_url = "https://raw.githubusercontent.com/weihua916/powerful-gnns/master/dataset.zip"
self.ds_name = "nig"
self.self_loop = self_loop
self.graphs = []
......@@ -114,18 +130,23 @@ class GINDataset(DGLBuiltinDataset):
self.nattrs_flag = False
self.nlabels_flag = False
super(GINDataset, self).__init__(name=name, url=gin_url, hash_key=(name, self_loop, degree_as_nlabel),
raw_dir=raw_dir, force_reload=force_reload,
verbose=verbose, transform=transform)
super(GINDataset, self).__init__(
name=name,
url=gin_url,
hash_key=(name, self_loop, degree_as_nlabel),
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
@property
def raw_path(self):
return os.path.join(self.raw_dir, 'GINDataset')
return os.path.join(self.raw_dir, "GINDataset")
def download(self):
r""" Automatically download data and extract it.
"""
zip_file_path = os.path.join(self.raw_dir, 'GINDataset.zip')
r"""Automatically download data and extract it."""
zip_file_path = os.path.join(self.raw_dir, "GINDataset.zip")
download(self.url, path=zip_file_path)
extract_archive(zip_file_path, self.raw_path)
......@@ -153,21 +174,26 @@ class GINDataset(DGLBuiltinDataset):
return g, self.labels[idx]
def _file_path(self):
return os.path.join(self.raw_dir, "GINDataset", 'dataset', self.name, "{}.txt".format(self.name))
return os.path.join(
self.raw_dir,
"GINDataset",
"dataset",
self.name,
"{}.txt".format(self.name),
)
def process(self):
""" Loads input dataset from dataset/NAME/NAME.txt file
"""
"""Loads input dataset from dataset/NAME/NAME.txt file"""
if self.verbose:
print('loading data...')
print("loading data...")
self.file = self._file_path()
with open(self.file, 'r') as f:
with open(self.file, "r") as f:
# line_1 == N, total number of graphs
self.N = int(f.readline().strip())
for i in range(self.N):
if (i + 1) % 10 == 0 and self.verbose is True:
print('processing graph {}...'.format(i + 1))
print("processing graph {}...".format(i + 1))
grow = f.readline().strip().split()
# line_2 == [n_nodes, l] is equal to
......@@ -201,7 +227,7 @@ class GINDataset(DGLBuiltinDataset):
nattr = [float(w) for w in nrow[tmp:]]
nattrs.append(nattr)
else:
raise Exception('edge number is incorrect!')
raise Exception("edge number is incorrect!")
# relabel nodes if it has labels
# if it doesn't have node labels, then every nrow[0]==0
......@@ -221,17 +247,18 @@ class GINDataset(DGLBuiltinDataset):
if (j + 1) % 10 == 0 and self.verbose is True:
print(
'processing node {} of graph {}...'.format(
j + 1, i + 1))
print('this node has {} edgs.'.format(
nrow[1]))
"processing node {} of graph {}...".format(
j + 1, i + 1
)
)
print("this node has {} edgs.".format(nrow[1]))
if nattrs != []:
nattrs = np.stack(nattrs)
g.ndata['attr'] = F.tensor(nattrs, F.float32)
g.ndata["attr"] = F.tensor(nattrs, F.float32)
self.nattrs_flag = True
g.ndata['label'] = F.tensor(nlabels)
g.ndata["label"] = F.tensor(nlabels)
if len(self.nlabel_dict) > 1:
self.nlabels_flag = True
......@@ -247,51 +274,59 @@ class GINDataset(DGLBuiltinDataset):
# if no attr
if not self.nattrs_flag:
if self.verbose:
print('there are no node features in this dataset!')
print("there are no node features in this dataset!")
# generate node attr by node degree
if self.degree_as_nlabel:
if self.verbose:
print('generate node features by node degree...')
print("generate node features by node degree...")
for g in self.graphs:
# actually this label shouldn't be updated
# in case users want to keep it
# but usually no features means no labels, fine.
g.ndata['label'] = g.in_degrees()
g.ndata["label"] = g.in_degrees()
# extracting unique node labels
# in case the labels/degrees are not continuous number
nlabel_set = set([])
for g in self.graphs:
nlabel_set = nlabel_set.union(
set([F.as_scalar(nl) for nl in g.ndata['label']]))
set([F.as_scalar(nl) for nl in g.ndata["label"]])
)
nlabel_set = list(nlabel_set)
is_label_valid = all([label in self.nlabel_dict for label in nlabel_set])
if is_label_valid and len(nlabel_set) == np.max(nlabel_set) + 1 and np.min(nlabel_set) == 0:
is_label_valid = all(
[label in self.nlabel_dict for label in nlabel_set]
)
if (
is_label_valid
and len(nlabel_set) == np.max(nlabel_set) + 1
and np.min(nlabel_set) == 0
):
# Note this is different from the author's implementation. In weihua916's implementation,
# the labels are relabeled anyway. But here we didn't relabel it if the labels are contiguous
# to make it consistent with the original dataset
label2idx = self.nlabel_dict
else:
label2idx = {
nlabel_set[i]: i
for i in range(len(nlabel_set))
}
label2idx = {nlabel_set[i]: i for i in range(len(nlabel_set))}
# generate node attr by node label
for g in self.graphs:
attr = np.zeros((
g.number_of_nodes(), len(label2idx)))
attr[range(g.number_of_nodes()), [label2idx[nl]
for nl in F.asnumpy(g.ndata['label']).tolist()]] = 1
g.ndata['attr'] = F.tensor(attr, F.float32)
attr = np.zeros((g.number_of_nodes(), len(label2idx)))
attr[
range(g.number_of_nodes()),
[
label2idx[nl]
for nl in F.asnumpy(g.ndata["label"]).tolist()
],
] = 1
g.ndata["attr"] = F.tensor(attr, F.float32)
# after load, get the #classes and #dim
self.gclasses = len(self.glabel_dict)
self.nclasses = len(self.nlabel_dict)
self.eclasses = len(self.elabel_dict)
self.dim_nfeats = len(self.graphs[0].ndata['attr'][0])
self.dim_nfeats = len(self.graphs[0].ndata["attr"][0])
if self.verbose:
print('Done.')
print("Done.")
print(
"""
-------- Data Statistics --------'
......@@ -306,64 +341,83 @@ class GINDataset(DGLBuiltinDataset):
Avg. of #Edges: %.2f
Graph Relabeled: %s
Node Relabeled: %s
Degree Relabeled(If degree_as_nlabel=True): %s \n """ % (
self.N, self.gclasses, self.n, self.nclasses,
self.dim_nfeats, self.m, self.eclasses,
self.n / self.N, self.m / self.N, self.glabel_dict,
self.nlabel_dict, self.ndegree_dict))
Degree Relabeled(If degree_as_nlabel=True): %s \n """
% (
self.N,
self.gclasses,
self.n,
self.nclasses,
self.dim_nfeats,
self.m,
self.eclasses,
self.n / self.N,
self.m / self.N,
self.glabel_dict,
self.nlabel_dict,
self.ndegree_dict,
)
)
def save(self):
graph_path = os.path.join(
self.save_path, 'gin_{}_{}.bin'.format(self.name, self.hash))
self.save_path, "gin_{}_{}.bin".format(self.name, self.hash)
)
info_path = os.path.join(
self.save_path, 'gin_{}_{}.pkl'.format(self.name, self.hash))
label_dict = {'labels': self.labels}
info_dict = {'N': self.N,
'n': self.n,
'm': self.m,
'self_loop': self.self_loop,
'gclasses': self.gclasses,
'nclasses': self.nclasses,
'eclasses': self.eclasses,
'dim_nfeats': self.dim_nfeats,
'degree_as_nlabel': self.degree_as_nlabel,
'glabel_dict': self.glabel_dict,
'nlabel_dict': self.nlabel_dict,
'elabel_dict': self.elabel_dict,
'ndegree_dict': self.ndegree_dict}
self.save_path, "gin_{}_{}.pkl".format(self.name, self.hash)
)
label_dict = {"labels": self.labels}
info_dict = {
"N": self.N,
"n": self.n,
"m": self.m,
"self_loop": self.self_loop,
"gclasses": self.gclasses,
"nclasses": self.nclasses,
"eclasses": self.eclasses,
"dim_nfeats": self.dim_nfeats,
"degree_as_nlabel": self.degree_as_nlabel,
"glabel_dict": self.glabel_dict,
"nlabel_dict": self.nlabel_dict,
"elabel_dict": self.elabel_dict,
"ndegree_dict": self.ndegree_dict,
}
save_graphs(str(graph_path), self.graphs, label_dict)
save_info(str(info_path), info_dict)
def load(self):
graph_path = os.path.join(
self.save_path, 'gin_{}_{}.bin'.format(self.name, self.hash))
self.save_path, "gin_{}_{}.bin".format(self.name, self.hash)
)
info_path = os.path.join(
self.save_path, 'gin_{}_{}.pkl'.format(self.name, self.hash))
self.save_path, "gin_{}_{}.pkl".format(self.name, self.hash)
)
graphs, label_dict = load_graphs(str(graph_path))
info_dict = load_info(str(info_path))
self.graphs = graphs
self.labels = label_dict['labels']
self.N = info_dict['N']
self.n = info_dict['n']
self.m = info_dict['m']
self.self_loop = info_dict['self_loop']
self.gclasses = info_dict['gclasses']
self.nclasses = info_dict['nclasses']
self.eclasses = info_dict['eclasses']
self.dim_nfeats = info_dict['dim_nfeats']
self.glabel_dict = info_dict['glabel_dict']
self.nlabel_dict = info_dict['nlabel_dict']
self.elabel_dict = info_dict['elabel_dict']
self.ndegree_dict = info_dict['ndegree_dict']
self.degree_as_nlabel = info_dict['degree_as_nlabel']
self.labels = label_dict["labels"]
self.N = info_dict["N"]
self.n = info_dict["n"]
self.m = info_dict["m"]
self.self_loop = info_dict["self_loop"]
self.gclasses = info_dict["gclasses"]
self.nclasses = info_dict["nclasses"]
self.eclasses = info_dict["eclasses"]
self.dim_nfeats = info_dict["dim_nfeats"]
self.glabel_dict = info_dict["glabel_dict"]
self.nlabel_dict = info_dict["nlabel_dict"]
self.elabel_dict = info_dict["elabel_dict"]
self.ndegree_dict = info_dict["ndegree_dict"]
self.degree_as_nlabel = info_dict["degree_as_nlabel"]
def has_cache(self):
graph_path = os.path.join(
self.save_path, 'gin_{}_{}.bin'.format(self.name, self.hash))
self.save_path, "gin_{}_{}.bin".format(self.name, self.hash)
)
info_path = os.path.join(
self.save_path, 'gin_{}_{}.pkl'.format(self.name, self.hash))
self.save_path, "gin_{}_{}.pkl".format(self.name, self.hash)
)
if os.path.exists(graph_path) and os.path.exists(info_path):
return True
return False
......
"""For Graph Serialization"""
from __future__ import absolute_import
import os
from ..base import dgl_warning, DGLError
from ..heterograph import DGLHeteroGraph
from .._ffi.object import ObjectBase, register_object
from .._ffi.function import _init_api
from .. import backend as F
from .._ffi.function import _init_api
from .._ffi.object import ObjectBase, register_object
from ..base import DGLError, dgl_warning
from ..heterograph import DGLHeteroGraph
from .heterograph_serialize import save_heterographs
_init_api("dgl.data.graph_serialize")
__all__ = ['save_graphs', "load_graphs", "load_labels"]
__all__ = ["save_graphs", "load_graphs", "load_labels"]
@register_object("graph_serialize.StorageMetaData")
......@@ -26,15 +28,18 @@ class StorageMetaData(ObjectBase):
def is_local_path(filepath):
return not (filepath.startswith("hdfs://") or
filepath.startswith("viewfs://") or
filepath.startswith("s3://"))
return not (
filepath.startswith("hdfs://")
or filepath.startswith("viewfs://")
or filepath.startswith("s3://")
)
def check_local_file_exists(filename):
if is_local_path(filename) and not os.path.exists(filename):
raise DGLError("File {} does not exist.".format(filename))
@register_object("graph_serialize.GraphData")
class GraphData(ObjectBase):
"""GraphData Object"""
......@@ -43,7 +48,9 @@ class GraphData(ObjectBase):
def create(g):
"""Create GraphData"""
# TODO(zihao): support serialize batched graph in the future.
assert g.batch_size == 1, "Batched DGLGraph is not supported for serialization"
assert (
g.batch_size == 1
), "Batched DGLGraph is not supported for serialization"
ghandle = g._graph
if len(g.ndata) != 0:
node_tensors = dict()
......@@ -64,8 +71,8 @@ class GraphData(ObjectBase):
def get_graph(self):
"""Get DGLHeteroGraph from GraphData"""
ghandle = _CAPI_GDataGraphHandle(self)
hgi =_CAPI_DGLAsHeteroGraph(ghandle)
g = DGLHeteroGraph(hgi, ['_U'], ['_E'])
hgi = _CAPI_DGLAsHeteroGraph(ghandle)
g = DGLHeteroGraph(hgi, ["_U"], ["_E"])
node_tensors_items = _CAPI_GDataNodeTensors(self).items()
edge_tensors_items = _CAPI_GDataEdgeTensors(self).items()
for k, v in node_tensors_items:
......@@ -120,18 +127,22 @@ def save_graphs(filename, g_list, labels=None):
# if it is local file, do some sanity check
if is_local_path(filename):
if os.path.isdir(filename):
raise DGLError("Filename {} is an existing directory.".format(filename))
raise DGLError(
"Filename {} is an existing directory.".format(filename)
)
f_path = os.path.dirname(filename)
if f_path and not os.path.exists(f_path):
os.makedirs(f_path)
g_sample = g_list[0] if isinstance(g_list, list) else g_list
if type(g_sample) == DGLHeteroGraph: # Doesn't support DGLHeteroGraph's derived class
if (
type(g_sample) == DGLHeteroGraph
): # Doesn't support DGLHeteroGraph's derived class
save_heterographs(filename, g_list, labels)
else:
raise DGLError(
"Invalid argument g_list. Must be a DGLGraph or a list of DGLGraphs.")
"Invalid argument g_list. Must be a DGLGraph or a list of DGLGraphs."
)
def load_graphs(filename, idx_list=None):
......@@ -176,7 +187,8 @@ def load_graphs(filename, idx_list=None):
if version == 1:
dgl_warning(
"You are loading a graph file saved by old version of dgl. \
Please consider saving it again with the current format.")
Please consider saving it again with the current format."
)
return load_graph_v1(filename, idx_list)
elif version == 2:
return load_graph_v2(filename, idx_list)
......@@ -195,7 +207,7 @@ def load_graph_v2(filename, idx_list=None):
def load_graph_v1(filename, idx_list=None):
""""Internal functions for loading DGLGraphs (V0)."""
""" "Internal functions for loading DGLGraphs (V0)."""
if idx_list is None:
idx_list = []
assert isinstance(idx_list, list)
......@@ -206,6 +218,7 @@ def load_graph_v1(filename, idx_list=None):
return [gdata.get_graph() for gdata in metadata.graph_data], label_dict
def load_labels(filename):
"""
Load label dict from file
......
"""For HeteroGraph Serialization"""
from __future__ import absolute_import
from ..heterograph import DGLHeteroGraph
from ..frame import Frame
from .._ffi.object import ObjectBase, register_object
from .._ffi.function import _init_api
from .. import backend as F
from .._ffi.function import _init_api
from .._ffi.object import ObjectBase, register_object
from ..container import convert_to_strmap
from ..frame import Frame
from ..heterograph import DGLHeteroGraph
_init_api("dgl.data.heterograph_serialize")
......@@ -24,9 +25,14 @@ def save_heterographs(filename, g_list, labels):
labels = {}
if isinstance(g_list, DGLHeteroGraph):
g_list = [g_list]
assert all([type(g) == DGLHeteroGraph for g in g_list]), "Invalid DGLHeteroGraph in g_list argument"
assert all(
[type(g) == DGLHeteroGraph for g in g_list]
), "Invalid DGLHeteroGraph in g_list argument"
gdata_list = [HeteroGraphData.create(g) for g in g_list]
_CAPI_SaveHeteroGraphData(filename, gdata_list, tensor_dict_to_ndarray_dict(labels))
_CAPI_SaveHeteroGraphData(
filename, gdata_list, tensor_dict_to_ndarray_dict(labels)
)
@register_object("heterograph_serialize.HeteroGraphData")
class HeteroGraphData(ObjectBase):
......@@ -40,7 +46,9 @@ class HeteroGraphData(ObjectBase):
edata_list.append(tensor_dict_to_ndarray_dict(g.edges[etype].data))
for ntype in g.ntypes:
ndata_list.append(tensor_dict_to_ndarray_dict(g.nodes[ntype].data))
return _CAPI_MakeHeteroGraphData(g._graph, ndata_list, edata_list, g.ntypes, g.etypes)
return _CAPI_MakeHeteroGraphData(
g._graph, ndata_list, edata_list, g.ntypes, g.etypes
)
def get_graph(self):
ntensor_list = list(_CAPI_GetNDataFromHeteroGraphData(self))
......@@ -51,11 +59,17 @@ class HeteroGraphData(ObjectBase):
nframes = []
eframes = []
for ntid, ntensor in enumerate(ntensor_list):
ndict = {ntensor[i]: F.zerocopy_from_dgl_ndarray(ntensor[i+1]) for i in range(0, len(ntensor), 2)}
ndict = {
ntensor[i]: F.zerocopy_from_dgl_ndarray(ntensor[i + 1])
for i in range(0, len(ntensor), 2)
}
nframes.append(Frame(ndict, num_rows=gidx.number_of_nodes(ntid)))
for etid, etensor in enumerate(etensor_list):
edict = {etensor[i]: F.zerocopy_from_dgl_ndarray(etensor[i+1]) for i in range(0, len(etensor), 2)}
edict = {
etensor[i]: F.zerocopy_from_dgl_ndarray(etensor[i + 1])
for i in range(0, len(etensor), 2)
}
eframes.append(Frame(edict, num_rows=gidx.number_of_edges(etid)))
return DGLHeteroGraph(gidx, ntype_names, etype_names, nframes, eframes)
"""ICEWS18 dataset for temporal graph"""
import numpy as np
import os
from .dgl_dataset import DGLBuiltinDataset
from .utils import loadtxt, _get_dgl_url, save_graphs, load_graphs
from ..convert import graph as dgl_graph
import numpy as np
from .. import backend as F
from ..convert import graph as dgl_graph
from .dgl_dataset import DGLBuiltinDataset
from .utils import _get_dgl_url, load_graphs, loadtxt, save_graphs
class ICEWS18Dataset(DGLBuiltinDataset):
r""" ICEWS18 dataset for temporal graph
r"""ICEWS18 dataset for temporal graph
Integrated Crisis Early Warning System (ICEWS18)
......@@ -65,21 +66,33 @@ class ICEWS18Dataset(DGLBuiltinDataset):
....
>>>
"""
def __init__(self, mode='train', raw_dir=None, force_reload=False, verbose=False, transform=None):
def __init__(
self,
mode="train",
raw_dir=None,
force_reload=False,
verbose=False,
transform=None,
):
mode = mode.lower()
assert mode in ['train', 'valid', 'test'], "Mode not valid"
assert mode in ["train", "valid", "test"], "Mode not valid"
self.mode = mode
_url = _get_dgl_url('dataset/icews18.zip')
super(ICEWS18Dataset, self).__init__(name='ICEWS18',
url=_url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform)
_url = _get_dgl_url("dataset/icews18.zip")
super(ICEWS18Dataset, self).__init__(
name="ICEWS18",
url=_url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
def process(self):
data = loadtxt(os.path.join(self.save_path, '{}.txt'.format(self.mode)),
delimiter='\t').astype(np.int64)
data = loadtxt(
os.path.join(self.save_path, "{}.txt".format(self.mode)),
delimiter="\t",
).astype(np.int64)
num_nodes = 23033
# The source code is not released, but the paper indicates there're
# totally 137 samples. The cutoff below has exactly 137 samples.
......@@ -92,23 +105,31 @@ class ICEWS18Dataset(DGLBuiltinDataset):
edges = data[row_mask][:, [0, 2]]
rate = data[row_mask][:, 1]
g = dgl_graph((edges[:, 0], edges[:, 1]))
g.edata['rel_type'] = F.tensor(rate.reshape(-1, 1), dtype=F.data_type_dict['int64'])
g.edata["rel_type"] = F.tensor(
rate.reshape(-1, 1), dtype=F.data_type_dict["int64"]
)
self._graphs.append(g)
def has_cache(self):
graph_path = os.path.join(self.save_path, '{}_dgl_graph.bin'.format(self.mode))
graph_path = os.path.join(
self.save_path, "{}_dgl_graph.bin".format(self.mode)
)
return os.path.exists(graph_path)
def save(self):
graph_path = os.path.join(self.save_path, '{}_dgl_graph.bin'.format(self.mode))
graph_path = os.path.join(
self.save_path, "{}_dgl_graph.bin".format(self.mode)
)
save_graphs(graph_path, self._graphs)
def load(self):
graph_path = os.path.join(self.save_path, '{}_dgl_graph.bin'.format(self.mode))
graph_path = os.path.join(
self.save_path, "{}_dgl_graph.bin".format(self.mode)
)
self._graphs = load_graphs(graph_path)[0]
def __getitem__(self, idx):
r""" Get graph by index
r"""Get graph by index
Parameters
----------
......
"""A mini synthetic dataset for graph classification benchmark."""
import math, os
import math
import os
import networkx as nx
import numpy as np
from .dgl_dataset import DGLDataset
from .utils import save_graphs, load_graphs, makedirs
from .. import backend as F
from ..convert import from_networkx
from ..transforms import add_self_loop
from .dgl_dataset import DGLDataset
from .utils import load_graphs, makedirs, save_graphs
__all__ = ["MiniGCDataset"]
__all__ = ['MiniGCDataset']
class MiniGCDataset(DGLDataset):
"""The synthetic graph classification dataset class.
......@@ -78,17 +81,30 @@ class MiniGCDataset(DGLDataset):
edata_schemes={})
"""
def __init__(self, num_graphs, min_num_v, max_num_v, seed=0,
save_graph=True, force_reload=False, verbose=False, transform=None):
def __init__(
self,
num_graphs,
min_num_v,
max_num_v,
seed=0,
save_graph=True,
force_reload=False,
verbose=False,
transform=None,
):
self.num_graphs = num_graphs
self.min_num_v = min_num_v
self.max_num_v = max_num_v
self.seed = seed
self.save_graph = save_graph
super(MiniGCDataset, self).__init__(name="minigc", hash_key=(num_graphs, min_num_v, max_num_v, seed),
force_reload=force_reload,
verbose=verbose, transform=transform)
super(MiniGCDataset, self).__init__(
name="minigc",
hash_key=(num_graphs, min_num_v, max_num_v, seed),
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
def process(self):
self.graphs = []
......@@ -119,7 +135,9 @@ class MiniGCDataset(DGLDataset):
return g, self.labels[idx]
def has_cache(self):
graph_path = os.path.join(self.save_path, 'dgl_graph_{}.bin'.format(self.hash))
graph_path = os.path.join(
self.save_path, "dgl_graph_{}.bin".format(self.hash)
)
if os.path.exists(graph_path):
return True
......@@ -128,13 +146,17 @@ class MiniGCDataset(DGLDataset):
def save(self):
"""save the graph list and the labels"""
if self.save_graph:
graph_path = os.path.join(self.save_path, 'dgl_graph_{}.bin'.format(self.hash))
save_graphs(str(graph_path), self.graphs, {'labels': self.labels})
graph_path = os.path.join(
self.save_path, "dgl_graph_{}.bin".format(self.hash)
)
save_graphs(str(graph_path), self.graphs, {"labels": self.labels})
def load(self):
graphs, label_dict = load_graphs(os.path.join(self.save_path, 'dgl_graph_{}.bin'.format(self.hash)))
graphs, label_dict = load_graphs(
os.path.join(self.save_path, "dgl_graph_{}.bin".format(self.hash))
)
self.graphs = graphs
self.labels = label_dict['labels']
self.labels = label_dict["labels"]
@property
def num_classes(self):
......@@ -199,9 +221,11 @@ class MiniGCDataset(DGLDataset):
def _gen_grid(self, n):
for _ in range(n):
num_v = np.random.randint(self.min_num_v, self.max_num_v)
assert num_v >= 4, 'We require a grid graph to contain at least two ' \
'rows and two columns, thus 4 nodes, got {:d} ' \
'nodes'.format(num_v)
assert num_v >= 4, (
"We require a grid graph to contain at least two "
"rows and two columns, thus 4 nodes, got {:d} "
"nodes".format(num_v)
)
n_rows = np.random.randint(2, num_v // 2)
n_cols = num_v // n_rows
g = nx.grid_graph([n_rows, n_cols])
......
""" PPIDataset for inductive learning. """
import json
import numpy as np
import os
import networkx as nx
import numpy as np
from networkx.readwrite import json_graph
import os
from .dgl_dataset import DGLBuiltinDataset
from .utils import _get_dgl_url, save_graphs, save_info, load_info, load_graphs
from .. import backend as F
from ..convert import from_networkx
from .dgl_dataset import DGLBuiltinDataset
from .utils import _get_dgl_url, load_graphs, load_info, save_graphs, save_info
class PPIDataset(DGLBuiltinDataset):
r""" Protein-Protein Interaction dataset for inductive node classification
r"""Protein-Protein Interaction dataset for inductive node classification
A toy Protein-Protein Interaction network dataset. The dataset contains
24 graphs. The average number of nodes per graph is 2372. Each node has
......@@ -65,36 +66,55 @@ class PPIDataset(DGLBuiltinDataset):
.... # your code here
>>>
"""
def __init__(self, mode='train', raw_dir=None, force_reload=False,
verbose=False, transform=None):
assert mode in ['train', 'valid', 'test']
def __init__(
self,
mode="train",
raw_dir=None,
force_reload=False,
verbose=False,
transform=None,
):
assert mode in ["train", "valid", "test"]
self.mode = mode
_url = _get_dgl_url('dataset/ppi.zip')
super(PPIDataset, self).__init__(name='ppi',
url=_url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform)
_url = _get_dgl_url("dataset/ppi.zip")
super(PPIDataset, self).__init__(
name="ppi",
url=_url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
def process(self):
graph_file = os.path.join(self.save_path, '{}_graph.json'.format(self.mode))
label_file = os.path.join(self.save_path, '{}_labels.npy'.format(self.mode))
feat_file = os.path.join(self.save_path, '{}_feats.npy'.format(self.mode))
graph_id_file = os.path.join(self.save_path, '{}_graph_id.npy'.format(self.mode))
graph_file = os.path.join(
self.save_path, "{}_graph.json".format(self.mode)
)
label_file = os.path.join(
self.save_path, "{}_labels.npy".format(self.mode)
)
feat_file = os.path.join(
self.save_path, "{}_feats.npy".format(self.mode)
)
graph_id_file = os.path.join(
self.save_path, "{}_graph_id.npy".format(self.mode)
)
g_data = json.load(open(graph_file))
self._labels = np.load(label_file)
self._feats = np.load(feat_file)
self.graph = from_networkx(nx.DiGraph(json_graph.node_link_graph(g_data)))
self.graph = from_networkx(
nx.DiGraph(json_graph.node_link_graph(g_data))
)
graph_id = np.load(graph_id_file)
# lo, hi means the range of graph ids for different portion of the dataset,
# 20 graphs for training, 2 for validation and 2 for testing.
lo, hi = 1, 21
if self.mode == 'valid':
if self.mode == "valid":
lo, hi = 21, 23
elif self.mode == 'test':
elif self.mode == "test":
lo, hi = 23, 25
graph_masks = []
......@@ -103,34 +123,60 @@ class PPIDataset(DGLBuiltinDataset):
g_mask = np.where(graph_id == g_id)[0]
graph_masks.append(g_mask)
g = self.graph.subgraph(g_mask)
g.ndata['feat'] = F.tensor(self._feats[g_mask], dtype=F.data_type_dict['float32'])
g.ndata['label'] = F.tensor(self._labels[g_mask], dtype=F.data_type_dict['float32'])
g.ndata["feat"] = F.tensor(
self._feats[g_mask], dtype=F.data_type_dict["float32"]
)
g.ndata["label"] = F.tensor(
self._labels[g_mask], dtype=F.data_type_dict["float32"]
)
self.graphs.append(g)
def has_cache(self):
graph_list_path = os.path.join(self.save_path, '{}_dgl_graph_list.bin'.format(self.mode))
g_path = os.path.join(self.save_path, '{}_dgl_graph.bin'.format(self.mode))
info_path = os.path.join(self.save_path, '{}_info.pkl'.format(self.mode))
return os.path.exists(graph_list_path) and os.path.exists(g_path) and os.path.exists(info_path)
graph_list_path = os.path.join(
self.save_path, "{}_dgl_graph_list.bin".format(self.mode)
)
g_path = os.path.join(
self.save_path, "{}_dgl_graph.bin".format(self.mode)
)
info_path = os.path.join(
self.save_path, "{}_info.pkl".format(self.mode)
)
return (
os.path.exists(graph_list_path)
and os.path.exists(g_path)
and os.path.exists(info_path)
)
def save(self):
graph_list_path = os.path.join(self.save_path, '{}_dgl_graph_list.bin'.format(self.mode))
g_path = os.path.join(self.save_path, '{}_dgl_graph.bin'.format(self.mode))
info_path = os.path.join(self.save_path, '{}_info.pkl'.format(self.mode))
graph_list_path = os.path.join(
self.save_path, "{}_dgl_graph_list.bin".format(self.mode)
)
g_path = os.path.join(
self.save_path, "{}_dgl_graph.bin".format(self.mode)
)
info_path = os.path.join(
self.save_path, "{}_info.pkl".format(self.mode)
)
save_graphs(graph_list_path, self.graphs)
save_graphs(g_path, self.graph)
save_info(info_path, {'labels': self._labels, 'feats': self._feats})
save_info(info_path, {"labels": self._labels, "feats": self._feats})
def load(self):
graph_list_path = os.path.join(self.save_path, '{}_dgl_graph_list.bin'.format(self.mode))
g_path = os.path.join(self.save_path, '{}_dgl_graph.bin'.format(self.mode))
info_path = os.path.join(self.save_path, '{}_info.pkl'.format(self.mode))
graph_list_path = os.path.join(
self.save_path, "{}_dgl_graph_list.bin".format(self.mode)
)
g_path = os.path.join(
self.save_path, "{}_dgl_graph.bin".format(self.mode)
)
info_path = os.path.join(
self.save_path, "{}_info.pkl".format(self.mode)
)
self.graphs = load_graphs(graph_list_path)[0]
g, _ = load_graphs(g_path)
self.graph = g[0]
info = load_info(info_path)
self._labels = info['labels']
self._feats = info['feats']
self._labels = info["labels"]
self._feats = info["feats"]
@property
def num_labels(self):
......@@ -163,8 +209,7 @@ class PPIDataset(DGLBuiltinDataset):
class LegacyPPIDataset(PPIDataset):
"""Legacy version of PPI Dataset
"""
"""Legacy version of PPI Dataset"""
def __getitem__(self, item):
"""Get the item^th sample.
......@@ -183,4 +228,4 @@ class LegacyPPIDataset(PPIDataset):
g = self.graphs[item]
else:
g = self._transform(self.graphs[item])
return g, g.ndata['feat'], g.ndata['label']
return g, g.ndata["feat"], g.ndata["label"]
......@@ -3,30 +3,39 @@ Datasets from "A Collection of Benchmark Datasets for
Systematic Evaluations of Machine Learning on
the Semantic Web"
"""
import os
from collections import OrderedDict
import itertools
import abc
import itertools
import os
import re
from collections import OrderedDict
import networkx as nx
import numpy as np
import dgl
import dgl.backend as F
from .dgl_dataset import DGLBuiltinDataset
from .utils import save_graphs, load_graphs, save_info, load_info, _get_dgl_url
from .utils import generate_mask_tensor, idx2mask
__all__ = ['AIFBDataset', 'MUTAGDataset', 'BGSDataset', 'AMDataset']
from .dgl_dataset import DGLBuiltinDataset
from .utils import (
_get_dgl_url,
generate_mask_tensor,
idx2mask,
load_graphs,
load_info,
save_graphs,
save_info,
)
__all__ = ["AIFBDataset", "MUTAGDataset", "BGSDataset", "AMDataset"]
# Dictionary for renaming reserved node/edge type names to the ones
# that are allowed by nn.Module.
RENAME_DICT = {
'type' : 'rdftype',
'rev-type' : 'rev-rdftype',
"type": "rdftype",
"rev-type": "rev-rdftype",
}
class Entity:
"""Class for entities
Parameters
......@@ -36,12 +45,14 @@ class Entity:
cls : str
Type of this entity
"""
def __init__(self, e_id, cls):
self.id = e_id
self.cls = cls
def __str__(self):
return '{}/{}'.format(self.cls, self.id)
return "{}/{}".format(self.cls, self.id)
class Relation:
"""Class for relations
......@@ -50,12 +61,14 @@ class Relation:
cls : str
Type of this relation
"""
def __init__(self, cls):
self.cls = cls
def __str__(self):
return str(self.cls)
class RDFGraphDataset(DGLBuiltinDataset):
"""Base graph dataset class from RDF tuples.
......@@ -101,22 +114,31 @@ class RDFGraphDataset(DGLBuiltinDataset):
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
"""
def __init__(self, name, url, predict_category,
print_every=10000,
insert_reverse=True,
raw_dir=None,
force_reload=False,
verbose=True,
transform=None):
def __init__(
self,
name,
url,
predict_category,
print_every=10000,
insert_reverse=True,
raw_dir=None,
force_reload=False,
verbose=True,
transform=None,
):
self._insert_reverse = insert_reverse
self._print_every = print_every
self._predict_category = predict_category
super(RDFGraphDataset, self).__init__(name, url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform)
super(RDFGraphDataset, self).__init__(
name,
url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
def process(self):
raw_tuples = self.load_raw_tuples(self.raw_path)
......@@ -135,17 +157,18 @@ class RDFGraphDataset(DGLBuiltinDataset):
Loaded rdf data
"""
import rdflib as rdf
raw_rdf_graphs = []
for _, filename in enumerate(os.listdir(root_path)):
fmt = None
if filename.endswith('nt'):
fmt = 'nt'
elif filename.endswith('n3'):
fmt = 'n3'
if filename.endswith("nt"):
fmt = "nt"
elif filename.endswith("n3"):
fmt = "n3"
if fmt is None:
continue
g = rdf.Graph()
print('Parsing file %s ...' % filename)
print("Parsing file %s ..." % filename)
g.parse(os.path.join(root_path, filename), format=fmt)
raw_rdf_graphs.append(g)
return itertools.chain(*raw_rdf_graphs)
......@@ -175,11 +198,16 @@ class RDFGraphDataset(DGLBuiltinDataset):
for i, (sbj, pred, obj) in enumerate(sorted_tuples):
if self.verbose and i % self._print_every == 0:
print('Processed %d tuples, found %d valid tuples.' % (i, len(src)))
print(
"Processed %d tuples, found %d valid tuples."
% (i, len(src))
)
sbjent = self.parse_entity(sbj)
rel = self.parse_relation(pred)
objent = self.parse_entity(obj)
processed = self.process_tuple((sbj, pred, obj), sbjent, rel, objent)
processed = self.process_tuple(
(sbj, pred, obj), sbjent, rel, objent
)
if processed is None:
# ignored
continue
......@@ -189,7 +217,7 @@ class RDFGraphDataset(DGLBuiltinDataset):
relclsid = _get_id(rel_classes, rel.cls)
mg.add_edge(sbjent.cls, objent.cls, key=rel.cls)
if self._insert_reverse:
mg.add_edge(objent.cls, sbjent.cls, key='rev-%s' % rel.cls)
mg.add_edge(objent.cls, sbjent.cls, key="rev-%s" % rel.cls)
# instance graph
src_id = _get_id(entities, str(sbjent))
if len(entities) > len(ntid): # found new entity
......@@ -211,39 +239,47 @@ class RDFGraphDataset(DGLBuiltinDataset):
# add reverse edge with reverse relation
if self._insert_reverse:
if self.verbose:
print('Adding reverse edges ...')
print("Adding reverse edges ...")
newsrc = np.hstack([src, dst])
newdst = np.hstack([dst, src])
src = newsrc
dst = newdst
etid = np.hstack([etid, etid + len(etypes)])
etypes.extend(['rev-%s' % t for t in etypes])
etypes.extend(["rev-%s" % t for t in etypes])
hg = self.build_graph(mg, src, dst, ntid, etid, ntypes, etypes)
if self.verbose:
print('Load training/validation/testing split ...')
print("Load training/validation/testing split ...")
idmap = F.asnumpy(hg.nodes[self.predict_category].data[dgl.NID])
glb2lcl = {glbid : lclid for lclid, glbid in enumerate(idmap)}
glb2lcl = {glbid: lclid for lclid, glbid in enumerate(idmap)}
def findidfn(ent):
if ent not in entities:
return None
else:
return glb2lcl[entities[ent]]
self._hg = hg
train_idx, test_idx, labels, num_classes = self.load_data_split(findidfn, root_path)
train_mask = idx2mask(train_idx, self._hg.number_of_nodes(self.predict_category))
test_mask = idx2mask(test_idx, self._hg.number_of_nodes(self.predict_category))
labels = F.tensor(labels, F.data_type_dict['int64'])
self._hg = hg
train_idx, test_idx, labels, num_classes = self.load_data_split(
findidfn, root_path
)
train_mask = idx2mask(
train_idx, self._hg.number_of_nodes(self.predict_category)
)
test_mask = idx2mask(
test_idx, self._hg.number_of_nodes(self.predict_category)
)
labels = F.tensor(labels, F.data_type_dict["int64"])
train_mask = generate_mask_tensor(train_mask)
test_mask = generate_mask_tensor(test_mask)
self._hg.nodes[self.predict_category].data['train_mask'] = train_mask
self._hg.nodes[self.predict_category].data['test_mask'] = test_mask
self._hg.nodes[self.predict_category].data["train_mask"] = train_mask
self._hg.nodes[self.predict_category].data["test_mask"] = test_mask
# TODO(minjie): Deprecate 'labels', use 'label' for consistency.
self._hg.nodes[self.predict_category].data['labels'] = labels
self._hg.nodes[self.predict_category].data['label'] = labels
self._hg.nodes[self.predict_category].data["labels"] = labels
self._hg.nodes[self.predict_category].data["label"] = labels
self._num_classes = num_classes
def build_graph(self, mg, src, dst, ntid, etid, ntypes, etypes):
......@@ -272,13 +308,13 @@ class RDFGraphDataset(DGLBuiltinDataset):
"""
# create homo graph
if self.verbose:
print('Creating one whole graph ...')
print("Creating one whole graph ...")
g = dgl.graph((src, dst))
g.ndata[dgl.NTYPE] = F.tensor(ntid)
g.edata[dgl.ETYPE] = F.tensor(etid)
if self.verbose:
print('Total #nodes:', g.number_of_nodes())
print('Total #edges:', g.number_of_edges())
print("Total #nodes:", g.number_of_nodes())
print("Total #edges:", g.number_of_edges())
# rename names such as 'type' so that they an be used as keys
# to nn.ModuleDict
......@@ -290,15 +326,12 @@ class RDFGraphDataset(DGLBuiltinDataset):
# convert to heterograph
if self.verbose:
print('Convert to heterograph ...')
hg = dgl.to_heterogeneous(g,
ntypes,
etypes,
metagraph=mg)
print("Convert to heterograph ...")
hg = dgl.to_heterogeneous(g, ntypes, etypes, metagraph=mg)
if self.verbose:
print('#Node types:', len(hg.ntypes))
print('#Canonical edge types:', len(hg.etypes))
print('#Unique edge type names:', len(set(hg.etypes)))
print("#Node types:", len(hg.ntypes))
print("#Canonical edge types:", len(hg.etypes))
print("#Unique edge type names:", len(set(hg.etypes)))
return hg
def load_data_split(self, ent2id, root_path):
......@@ -323,13 +356,18 @@ class RDFGraphDataset(DGLBuiltinDataset):
Number of classes
"""
label_dict = {}
labels = np.zeros((self._hg.number_of_nodes(self.predict_category),)) - 1
labels = (
np.zeros((self._hg.number_of_nodes(self.predict_category),)) - 1
)
train_idx = self.parse_idx_file(
os.path.join(root_path, 'trainingSet.tsv'),
ent2id, label_dict, labels)
os.path.join(root_path, "trainingSet.tsv"),
ent2id,
label_dict,
labels,
)
test_idx = self.parse_idx_file(
os.path.join(root_path, 'testSet.tsv'),
ent2id, label_dict, labels)
os.path.join(root_path, "testSet.tsv"), ent2id, label_dict, labels
)
train_idx = np.array(train_idx)
test_idx = np.array(test_idx)
labels = np.array(labels)
......@@ -356,16 +394,19 @@ class RDFGraphDataset(DGLBuiltinDataset):
Entity idss
"""
idx = []
with open(filename, 'r') as f:
with open(filename, "r") as f:
for i, line in enumerate(f):
if i == 0:
continue # first line is the header
sample, label = self.process_idx_file_line(line)
#person, _, label = line.strip().split('\t')
# person, _, label = line.strip().split('\t')
ent = self.parse_entity(sample)
entid = ent2id(str(ent))
if entid is None:
print('Warning: entity "%s" does not have any valid links associated. Ignored.' % str(ent))
print(
'Warning: entity "%s" does not have any valid links associated. Ignored.'
% str(ent)
)
else:
idx.append(entid)
lblid = _get_id(label_dict, label)
......@@ -374,46 +415,44 @@ class RDFGraphDataset(DGLBuiltinDataset):
def has_cache(self):
"""check if there is a processed data"""
graph_path = os.path.join(self.save_path,
self.save_name + '.bin')
info_path = os.path.join(self.save_path,
self.save_name + '.pkl')
if os.path.exists(graph_path) and \
os.path.exists(info_path):
graph_path = os.path.join(self.save_path, self.save_name + ".bin")
info_path = os.path.join(self.save_path, self.save_name + ".pkl")
if os.path.exists(graph_path) and os.path.exists(info_path):
return True
return False
def save(self):
"""save the graph list and the labels"""
graph_path = os.path.join(self.save_path,
self.save_name + '.bin')
info_path = os.path.join(self.save_path,
self.save_name + '.pkl')
graph_path = os.path.join(self.save_path, self.save_name + ".bin")
info_path = os.path.join(self.save_path, self.save_name + ".pkl")
save_graphs(str(graph_path), self._hg)
save_info(str(info_path), {'num_classes': self.num_classes,
'predict_category': self.predict_category})
save_info(
str(info_path),
{
"num_classes": self.num_classes,
"predict_category": self.predict_category,
},
)
def load(self):
"""load the graph list and the labels from disk"""
graph_path = os.path.join(self.save_path,
self.save_name + '.bin')
info_path = os.path.join(self.save_path,
self.save_name + '.pkl')
graph_path = os.path.join(self.save_path, self.save_name + ".bin")
info_path = os.path.join(self.save_path, self.save_name + ".pkl")
graphs, _ = load_graphs(str(graph_path))
info = load_info(str(info_path))
self._num_classes = info['num_classes']
self._predict_category = info['predict_category']
self._num_classes = info["num_classes"]
self._predict_category = info["predict_category"]
self._hg = graphs[0]
# For backward compatibility
if 'label' not in self._hg.nodes[self.predict_category].data:
self._hg.nodes[self.predict_category].data['label'] = \
self._hg.nodes[self.predict_category].data['labels']
if "label" not in self._hg.nodes[self.predict_category].data:
self._hg.nodes[self.predict_category].data[
"label"
] = self._hg.nodes[self.predict_category].data["labels"]
def __getitem__(self, idx):
r"""Gets the graph object
"""
r"""Gets the graph object"""
g = self._hg
if self._transform is not None:
g = self._transform(g)
......@@ -425,7 +464,7 @@ class RDFGraphDataset(DGLBuiltinDataset):
@property
def save_name(self):
return self.name + '_dgl_graph'
return self.name + "_dgl_graph"
@property
def predict_category(self):
......@@ -504,6 +543,7 @@ class RDFGraphDataset(DGLBuiltinDataset):
"""
pass
def _get_id(dict, key):
id = dict.get(key, None)
if id is None:
......@@ -511,6 +551,7 @@ def _get_id(dict, key):
dict[key] = id
return id
class AIFBDataset(RDFGraphDataset):
r"""AIFB dataset for node classification task
......@@ -566,29 +607,40 @@ class AIFBDataset(RDFGraphDataset):
>>> label = g.nodes[category].data['label']
"""
entity_prefix = 'http://www.aifb.uni-karlsruhe.de/'
relation_prefix = 'http://swrc.ontoware.org/'
def __init__(self,
print_every=10000,
insert_reverse=True,
raw_dir=None,
force_reload=False,
verbose=True,
transform=None):
entity_prefix = "http://www.aifb.uni-karlsruhe.de/"
relation_prefix = "http://swrc.ontoware.org/"
def __init__(
self,
print_every=10000,
insert_reverse=True,
raw_dir=None,
force_reload=False,
verbose=True,
transform=None,
):
import rdflib as rdf
self.employs = rdf.term.URIRef("http://swrc.ontoware.org/ontology#employs")
self.affiliation = rdf.term.URIRef("http://swrc.ontoware.org/ontology#affiliation")
url = _get_dgl_url('dataset/rdf/aifb-hetero.zip')
name = 'aifb-hetero'
predict_category = 'Personen'
super(AIFBDataset, self).__init__(name, url, predict_category,
print_every=print_every,
insert_reverse=insert_reverse,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform)
self.employs = rdf.term.URIRef(
"http://swrc.ontoware.org/ontology#employs"
)
self.affiliation = rdf.term.URIRef(
"http://swrc.ontoware.org/ontology#affiliation"
)
url = _get_dgl_url("dataset/rdf/aifb-hetero.zip")
name = "aifb-hetero"
predict_category = "Personen"
super(AIFBDataset, self).__init__(
name,
url,
predict_category,
print_every=print_every,
insert_reverse=insert_reverse,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
def __getitem__(self, idx):
r"""Gets the graph object
......@@ -621,13 +673,14 @@ class AIFBDataset(RDFGraphDataset):
def parse_entity(self, term):
import rdflib as rdf
if isinstance(term, rdf.Literal):
return Entity(e_id=str(term), cls="_Literal")
if isinstance(term, rdf.BNode):
return None
entstr = str(term)
if entstr.startswith(self.entity_prefix):
sp = entstr.split('/')
sp = entstr.split("/")
return Entity(e_id=sp[5], cls=sp[3])
else:
return None
......@@ -637,9 +690,9 @@ class AIFBDataset(RDFGraphDataset):
return None
relstr = str(term)
if relstr.startswith(self.relation_prefix):
return Relation(cls=relstr.split('/')[3])
return Relation(cls=relstr.split("/")[3])
else:
relstr = relstr.split('/')[-1]
relstr = relstr.split("/")[-1]
return Relation(cls=relstr)
def process_tuple(self, raw_tuple, sbj, rel, obj):
......@@ -648,9 +701,10 @@ class AIFBDataset(RDFGraphDataset):
return (sbj, rel, obj)
def process_idx_file_line(self, line):
person, _, label = line.strip().split('\t')
person, _, label = line.strip().split("\t")
return person, label
class MUTAGDataset(RDFGraphDataset):
r"""MUTAG dataset for node classification task
......@@ -707,32 +761,47 @@ class MUTAGDataset(RDFGraphDataset):
d_entity = re.compile("d[0-9]")
bond_entity = re.compile("bond[0-9]")
entity_prefix = 'http://dl-learner.org/carcinogenesis#'
entity_prefix = "http://dl-learner.org/carcinogenesis#"
relation_prefix = entity_prefix
def __init__(self,
print_every=10000,
insert_reverse=True,
raw_dir=None,
force_reload=False,
verbose=True,
transform=None):
def __init__(
self,
print_every=10000,
insert_reverse=True,
raw_dir=None,
force_reload=False,
verbose=True,
transform=None,
):
import rdflib as rdf
self.is_mutagenic = rdf.term.URIRef("http://dl-learner.org/carcinogenesis#isMutagenic")
self.rdf_type = rdf.term.URIRef("http://www.w3.org/1999/02/22-rdf-syntax-ns#type")
self.rdf_subclassof = rdf.term.URIRef("http://www.w3.org/2000/01/rdf-schema#subClassOf")
self.rdf_domain = rdf.term.URIRef("http://www.w3.org/2000/01/rdf-schema#domain")
url = _get_dgl_url('dataset/rdf/mutag-hetero.zip')
name = 'mutag-hetero'
predict_category = 'd'
super(MUTAGDataset, self).__init__(name, url, predict_category,
print_every=print_every,
insert_reverse=insert_reverse,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform)
self.is_mutagenic = rdf.term.URIRef(
"http://dl-learner.org/carcinogenesis#isMutagenic"
)
self.rdf_type = rdf.term.URIRef(
"http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
)
self.rdf_subclassof = rdf.term.URIRef(
"http://www.w3.org/2000/01/rdf-schema#subClassOf"
)
self.rdf_domain = rdf.term.URIRef(
"http://www.w3.org/2000/01/rdf-schema#domain"
)
url = _get_dgl_url("dataset/rdf/mutag-hetero.zip")
name = "mutag-hetero"
predict_category = "d"
super(MUTAGDataset, self).__init__(
name,
url,
predict_category,
print_every=print_every,
insert_reverse=insert_reverse,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
def __getitem__(self, idx):
r"""Gets the graph object
......@@ -765,17 +834,18 @@ class MUTAGDataset(RDFGraphDataset):
def parse_entity(self, term):
import rdflib as rdf
if isinstance(term, rdf.Literal):
return Entity(e_id=str(term), cls="_Literal")
elif isinstance(term, rdf.BNode):
return None
entstr = str(term)
if entstr.startswith(self.entity_prefix):
inst = entstr[len(self.entity_prefix):]
inst = entstr[len(self.entity_prefix) :]
if self.d_entity.match(inst):
cls = 'd'
cls = "d"
elif self.bond_entity.match(inst):
cls = 'bond'
cls = "bond"
else:
cls = None
return Entity(e_id=inst, cls=cls)
......@@ -787,20 +857,20 @@ class MUTAGDataset(RDFGraphDataset):
return None
relstr = str(term)
if relstr.startswith(self.relation_prefix):
cls = relstr[len(self.relation_prefix):]
cls = relstr[len(self.relation_prefix) :]
return Relation(cls=cls)
else:
relstr = relstr.split('/')[-1]
relstr = relstr.split("/")[-1]
return Relation(cls=relstr)
def process_tuple(self, raw_tuple, sbj, rel, obj):
if sbj is None or rel is None or obj is None:
return None
if not raw_tuple[1].startswith('http://dl-learner.org/carcinogenesis#'):
obj.cls = 'SCHEMA'
if not raw_tuple[1].startswith("http://dl-learner.org/carcinogenesis#"):
obj.cls = "SCHEMA"
if sbj.cls is None:
sbj.cls = 'SCHEMA'
sbj.cls = "SCHEMA"
if obj.cls is None:
obj.cls = rel.cls
......@@ -809,9 +879,10 @@ class MUTAGDataset(RDFGraphDataset):
return (sbj, rel, obj)
def process_idx_file_line(self, line):
bond, _, label = line.strip().split('\t')
bond, _, label = line.strip().split("\t")
return bond, label
class BGSDataset(RDFGraphDataset):
r"""BGS dataset for node classification task
......@@ -869,29 +940,38 @@ class BGSDataset(RDFGraphDataset):
>>> label = g.nodes[category].data['label']
"""
entity_prefix = 'http://data.bgs.ac.uk/'
status_prefix = 'http://data.bgs.ac.uk/ref/CurrentStatus'
relation_prefix = 'http://data.bgs.ac.uk/ref'
def __init__(self,
print_every=10000,
insert_reverse=True,
raw_dir=None,
force_reload=False,
verbose=True,
transform=None):
entity_prefix = "http://data.bgs.ac.uk/"
status_prefix = "http://data.bgs.ac.uk/ref/CurrentStatus"
relation_prefix = "http://data.bgs.ac.uk/ref"
def __init__(
self,
print_every=10000,
insert_reverse=True,
raw_dir=None,
force_reload=False,
verbose=True,
transform=None,
):
import rdflib as rdf
url = _get_dgl_url('dataset/rdf/bgs-hetero.zip')
name = 'bgs-hetero'
predict_category = 'Lexicon/NamedRockUnit'
self.lith = rdf.term.URIRef("http://data.bgs.ac.uk/ref/Lexicon/hasLithogenesis")
super(BGSDataset, self).__init__(name, url, predict_category,
print_every=print_every,
insert_reverse=insert_reverse,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform)
url = _get_dgl_url("dataset/rdf/bgs-hetero.zip")
name = "bgs-hetero"
predict_category = "Lexicon/NamedRockUnit"
self.lith = rdf.term.URIRef(
"http://data.bgs.ac.uk/ref/Lexicon/hasLithogenesis"
)
super(BGSDataset, self).__init__(
name,
url,
predict_category,
print_every=print_every,
insert_reverse=insert_reverse,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
def __getitem__(self, idx):
r"""Gets the graph object
......@@ -924,6 +1004,7 @@ class BGSDataset(RDFGraphDataset):
def parse_entity(self, term):
import rdflib as rdf
if isinstance(term, rdf.Literal):
return None
elif isinstance(term, rdf.BNode):
......@@ -932,11 +1013,11 @@ class BGSDataset(RDFGraphDataset):
if entstr.startswith(self.status_prefix):
return None
if entstr.startswith(self.entity_prefix):
sp = entstr.split('/')
sp = entstr.split("/")
if len(sp) != 7:
return None
# instance
cls = '%s/%s' % (sp[4], sp[5])
cls = "%s/%s" % (sp[4], sp[5])
inst = sp[6]
return Entity(e_id=inst, cls=cls)
else:
......@@ -947,14 +1028,14 @@ class BGSDataset(RDFGraphDataset):
return None
relstr = str(term)
if relstr.startswith(self.relation_prefix):
sp = relstr.split('/')
sp = relstr.split("/")
if len(sp) < 6:
return None
assert len(sp) == 6, relstr
cls = '%s/%s' % (sp[4], sp[5])
cls = "%s/%s" % (sp[4], sp[5])
return Relation(cls=cls)
else:
relstr = relstr.replace('.', '_')
relstr = relstr.replace(".", "_")
return Relation(cls=relstr)
def process_tuple(self, raw_tuple, sbj, rel, obj):
......@@ -963,9 +1044,10 @@ class BGSDataset(RDFGraphDataset):
return (sbj, rel, obj)
def process_idx_file_line(self, line):
_, rock, label = line.strip().split('\t')
_, rock, label = line.strip().split("\t")
return rock, label
class AMDataset(RDFGraphDataset):
"""AM dataset. for node classification task
......@@ -1025,29 +1107,40 @@ class AMDataset(RDFGraphDataset):
>>> label = g.nodes[category].data['label']
"""
entity_prefix = 'http://purl.org/collections/nl/am/'
entity_prefix = "http://purl.org/collections/nl/am/"
relation_prefix = entity_prefix
def __init__(self,
print_every=10000,
insert_reverse=True,
raw_dir=None,
force_reload=False,
verbose=True,
transform=None):
def __init__(
self,
print_every=10000,
insert_reverse=True,
raw_dir=None,
force_reload=False,
verbose=True,
transform=None,
):
import rdflib as rdf
self.objectCategory = rdf.term.URIRef("http://purl.org/collections/nl/am/objectCategory")
self.material = rdf.term.URIRef("http://purl.org/collections/nl/am/material")
url = _get_dgl_url('dataset/rdf/am-hetero.zip')
name = 'am-hetero'
predict_category = 'proxy'
super(AMDataset, self).__init__(name, url, predict_category,
print_every=print_every,
insert_reverse=insert_reverse,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform)
self.objectCategory = rdf.term.URIRef(
"http://purl.org/collections/nl/am/objectCategory"
)
self.material = rdf.term.URIRef(
"http://purl.org/collections/nl/am/material"
)
url = _get_dgl_url("dataset/rdf/am-hetero.zip")
name = "am-hetero"
predict_category = "proxy"
super(AMDataset, self).__init__(
name,
url,
predict_category,
print_every=print_every,
insert_reverse=insert_reverse,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
def __getitem__(self, idx):
r"""Gets the graph object
......@@ -1080,20 +1173,21 @@ class AMDataset(RDFGraphDataset):
def parse_entity(self, term):
import rdflib as rdf
if isinstance(term, rdf.Literal):
return None
elif isinstance(term, rdf.BNode):
return Entity(e_id=str(term), cls='_BNode')
return Entity(e_id=str(term), cls="_BNode")
entstr = str(term)
if entstr.startswith(self.entity_prefix):
sp = entstr.split('/')
sp = entstr.split("/")
assert len(sp) == 7, entstr
spp = sp[6].split('-')
spp = sp[6].split("-")
if len(spp) == 2:
# instance
cls, inst = spp
else:
cls = 'TYPE'
cls = "TYPE"
inst = spp
return Entity(e_id=inst, cls=cls)
else:
......@@ -1104,12 +1198,12 @@ class AMDataset(RDFGraphDataset):
return None
relstr = str(term)
if relstr.startswith(self.relation_prefix):
sp = relstr.split('/')
sp = relstr.split("/")
assert len(sp) == 7, relstr
cls = sp[6]
return Relation(cls=cls)
else:
relstr = relstr.replace('.', '_')
relstr = relstr.replace(".", "_")
return Relation(cls=relstr)
def process_tuple(self, raw_tuple, sbj, rel, obj):
......@@ -1118,5 +1212,5 @@ class AMDataset(RDFGraphDataset):
return (sbj, rel, obj)
def process_idx_file_line(self, line):
proxy, _, label = line.strip().split('\t')
proxy, _, label = line.strip().split("\t")
return proxy, label
"""Dataset for stochastic block model."""
import math
import random
import os
import random
import numpy as np
import numpy.random as npr
import scipy as sp
from .dgl_dataset import DGLDataset
from ..convert import from_scipy
from .. import batch
from .utils import save_info, save_graphs, load_info, load_graphs
from ..convert import from_scipy
from .dgl_dataset import DGLDataset
from .utils import load_graphs, load_info, save_graphs, save_info
def sbm(n_blocks, block_size, p, q, rng=None):
""" (Symmetric) Stochastic Block Model
"""(Symmetric) Stochastic Block Model
Parameters
----------
......@@ -44,20 +44,27 @@ def sbm(n_blocks, block_size, p, q, rng=None):
for i in range(n_blocks):
for j in range(i, n_blocks):
density = p if i == j else q
block = sp.sparse.random(block_size, block_size, density,
random_state=rng, data_rvs=lambda n: np.ones(n))
block = sp.sparse.random(
block_size,
block_size,
density,
random_state=rng,
data_rvs=lambda n: np.ones(n),
)
rows.append(block.row + i * block_size)
cols.append(block.col + j * block_size)
rows = np.hstack(rows)
cols = np.hstack(cols)
a = sp.sparse.coo_matrix((np.ones(rows.shape[0]), (rows, cols)), shape=(n, n))
a = sp.sparse.coo_matrix(
(np.ones(rows.shape[0]), (rows, cols)), shape=(n, n)
)
adj = sp.sparse.triu(a) + sp.sparse.triu(a, 1).transpose()
return adj
class SBMMixtureDataset(DGLDataset):
r""" Symmetric Stochastic Block Model Mixture
r"""Symmetric Stochastic Block Model Mixture
Reference: Appendix C of `Supervised Community Detection with Hierarchical Graph Neural Networks <https://arxiv.org/abs/1705.08415>`_
......@@ -95,14 +102,17 @@ class SBMMixtureDataset(DGLDataset):
>>> for graph, line_graph, graph_degrees, line_graph_degrees, pm_pd in dataloader:
... # your code here
"""
def __init__(self,
n_graphs,
n_nodes,
n_communities,
k=2,
avg_deg=3,
pq='Appendix_C',
rng=None):
def __init__(
self,
n_graphs,
n_nodes,
n_communities,
k=2,
avg_deg=3,
pq="Appendix_C",
rng=None,
):
self._n_graphs = n_graphs
self._n_nodes = n_nodes
self._n_communities = n_communities
......@@ -112,60 +122,92 @@ class SBMMixtureDataset(DGLDataset):
self._avg_deg = avg_deg
self._pq = pq
self._rng = rng
super(SBMMixtureDataset, self).__init__(name='sbmmixture',
hash_key=(n_graphs, n_nodes, n_communities, k, avg_deg, pq, rng))
super(SBMMixtureDataset, self).__init__(
name="sbmmixture",
hash_key=(n_graphs, n_nodes, n_communities, k, avg_deg, pq, rng),
)
def process(self):
pq = self._pq
if type(pq) is list:
assert len(pq) == self._n_graphs
elif type(pq) is str:
generator = {'Appendix_C': self._appendix_c}[pq]
generator = {"Appendix_C": self._appendix_c}[pq]
pq = [generator() for _ in range(self._n_graphs)]
else:
raise RuntimeError()
self._graphs = [from_scipy(sbm(self._n_communities, self._block_size, *x)) for x in pq]
self._line_graphs = [g.line_graph(backtracking=False) for g in self._graphs]
self._graphs = [
from_scipy(sbm(self._n_communities, self._block_size, *x))
for x in pq
]
self._line_graphs = [
g.line_graph(backtracking=False) for g in self._graphs
]
in_degrees = lambda g: g.in_degrees().float()
self._graph_degrees = [in_degrees(g) for g in self._graphs]
self._line_graph_degrees = [in_degrees(lg) for lg in self._line_graphs]
self._pm_pds = list(zip(*[g.edges() for g in self._graphs]))[0]
def has_cache(self):
graph_path = os.path.join(self.save_path, 'graphs_{}.bin'.format(self.hash))
line_graph_path = os.path.join(self.save_path, 'line_graphs_{}.bin'.format(self.hash))
info_path = os.path.join(self.save_path, 'info_{}.pkl'.format(self.hash))
return os.path.exists(graph_path) and \
os.path.exists(line_graph_path) and \
os.path.exists(info_path)
graph_path = os.path.join(
self.save_path, "graphs_{}.bin".format(self.hash)
)
line_graph_path = os.path.join(
self.save_path, "line_graphs_{}.bin".format(self.hash)
)
info_path = os.path.join(
self.save_path, "info_{}.pkl".format(self.hash)
)
return (
os.path.exists(graph_path)
and os.path.exists(line_graph_path)
and os.path.exists(info_path)
)
def save(self):
graph_path = os.path.join(self.save_path, 'graphs_{}.bin'.format(self.hash))
line_graph_path = os.path.join(self.save_path, 'line_graphs_{}.bin'.format(self.hash))
info_path = os.path.join(self.save_path, 'info_{}.pkl'.format(self.hash))
graph_path = os.path.join(
self.save_path, "graphs_{}.bin".format(self.hash)
)
line_graph_path = os.path.join(
self.save_path, "line_graphs_{}.bin".format(self.hash)
)
info_path = os.path.join(
self.save_path, "info_{}.pkl".format(self.hash)
)
save_graphs(graph_path, self._graphs)
save_graphs(line_graph_path, self._line_graphs)
save_info(info_path, {'graph_degree': self._graph_degrees,
'line_graph_degree': self._line_graph_degrees,
'pm_pds': self._pm_pds})
save_info(
info_path,
{
"graph_degree": self._graph_degrees,
"line_graph_degree": self._line_graph_degrees,
"pm_pds": self._pm_pds,
},
)
def load(self):
graph_path = os.path.join(self.save_path, 'graphs_{}.bin'.format(self.hash))
line_graph_path = os.path.join(self.save_path, 'line_graphs_{}.bin'.format(self.hash))
info_path = os.path.join(self.save_path, 'info_{}.pkl'.format(self.hash))
graph_path = os.path.join(
self.save_path, "graphs_{}.bin".format(self.hash)
)
line_graph_path = os.path.join(
self.save_path, "line_graphs_{}.bin".format(self.hash)
)
info_path = os.path.join(
self.save_path, "info_{}.pkl".format(self.hash)
)
self._graphs, _ = load_graphs(graph_path)
self._line_graphs, _ = load_graphs(line_graph_path)
info = load_info(info_path)
self._graph_degrees = info['graph_degree']
self._line_graph_degrees = info['line_graph_degree']
self._pm_pds = info['pm_pds']
self._graph_degrees = info["graph_degree"]
self._line_graph_degrees = info["line_graph_degree"]
self._pm_pds = info["pm_pds"]
def __len__(self):
r"""Number of graphs in the dataset."""
return len(self._graphs)
def __getitem__(self, idx):
r""" Get one example by index
r"""Get one example by index
Parameters
----------
......@@ -185,8 +227,13 @@ class SBMMixtureDataset(DGLDataset):
pm_pd: numpy.ndarray
Edge indicator matrices Pm and Pd
"""
return self._graphs[idx], self._line_graphs[idx], \
self._graph_degrees[idx], self._line_graph_degrees[idx], self._pm_pds[idx]
return (
self._graphs[idx],
self._line_graphs[idx],
self._graph_degrees[idx],
self._line_graph_degrees[idx],
self._pm_pds[idx],
)
def _appendix_c(self):
q = npr.uniform(0, self._avg_deg - math.sqrt(self._avg_deg))
......@@ -197,7 +244,7 @@ class SBMMixtureDataset(DGLDataset):
return q, p
def collate_fn(self, x):
r""" The `collate` function for dataloader
r"""The `collate` function for dataloader
Parameters
----------
......@@ -233,7 +280,9 @@ class SBMMixtureDataset(DGLDataset):
lg_batch = batch.batch(lg)
degg_batch = np.concatenate(deg_g, axis=0)
deglg_batch = np.concatenate(deg_lg, axis=0)
pm_pd_batch = np.concatenate([x + i * self._n_nodes for i, x in enumerate(pm_pd)], axis=0)
pm_pd_batch = np.concatenate(
[x + i * self._n_nodes for i, x in enumerate(pm_pd)], axis=0
)
return g_batch, lg_batch, degg_batch, deglg_batch, pm_pd_batch
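A usage sketch tying the dataset and its collate_fn together, completing the truncated docstring example above; a PyTorch DataLoader is assumed:

import torch
from dgl.data import SBMMixtureDataset

dataset = SBMMixtureDataset(n_graphs=16, n_nodes=1000, n_communities=2)
loader = torch.utils.data.DataLoader(
    dataset, batch_size=2, collate_fn=dataset.collate_fn
)
for g, lg, deg_g, deg_lg, pm_pd in loader:
    pass  # training code goes here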
......
"""Synthetic graph datasets."""
import math
import networkx as nx
import numpy as np
import os
import pickle
import random
from .dgl_dataset import DGLBuiltinDataset
from .utils import save_graphs, load_graphs, _get_dgl_url, download
import networkx as nx
import numpy as np
from .. import backend as F
from ..batch import batch
from ..convert import graph
from ..transforms import reorder_graph
from .dgl_dataset import DGLBuiltinDataset
from .utils import _get_dgl_url, download, load_graphs, save_graphs
class BAShapeDataset(DGLBuiltinDataset):
r"""BA-SHAPES dataset from `GNNExplainer: Generating Explanations for Graph Neural Networks
......@@ -70,30 +72,37 @@ class BAShapeDataset(DGLBuiltinDataset):
>>> label = g.ndata['label']
>>> feat = g.ndata['feat']
"""
def __init__(self,
num_base_nodes=300,
num_base_edges_per_node=5,
num_motifs=80,
perturb_ratio=0.01,
seed=None,
raw_dir=None,
force_reload=False,
verbose=True,
transform=None):
def __init__(
self,
num_base_nodes=300,
num_base_edges_per_node=5,
num_motifs=80,
perturb_ratio=0.01,
seed=None,
raw_dir=None,
force_reload=False,
verbose=True,
transform=None,
):
self.num_base_nodes = num_base_nodes
self.num_base_edges_per_node = num_base_edges_per_node
self.num_motifs = num_motifs
self.perturb_ratio = perturb_ratio
self.seed = seed
super(BAShapeDataset, self).__init__(name='BA-SHAPES',
url=None,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform)
super(BAShapeDataset, self).__init__(
name="BA-SHAPES",
url=None,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
def process(self):
g = nx.barabasi_albert_graph(self.num_base_nodes, self.num_base_edges_per_node, self.seed)
g = nx.barabasi_albert_graph(
self.num_base_nodes, self.num_base_edges_per_node, self.seed
)
edges = list(g.edges())
src, dst = map(list, zip(*edges))
n = self.num_base_nodes
......@@ -111,7 +120,7 @@ class BAShapeDataset(DGLBuiltinDataset):
(n + 2, n + 3),
(n + 3, n),
(n + 4, n),
(n + 4, n + 1)
(n + 4, n + 1),
]
motif_src, motif_dst = map(list, zip(*motif_edges))
src.extend(motif_src)
......@@ -132,8 +141,9 @@ class BAShapeDataset(DGLBuiltinDataset):
# Perturb the graph by adding non-self-loop random edges
num_real_edges = g.num_edges()
max_ratio = (n * (n - 1) - num_real_edges) / num_real_edges
assert self.perturb_ratio <= max_ratio, \
'perturb_ratio cannot exceed {:.4f}'.format(max_ratio)
assert (
self.perturb_ratio <= max_ratio
), "perturb_ratio cannot exceed {:.4f}".format(max_ratio)
num_random_edges = int(num_real_edges * self.perturb_ratio)
if self.seed is not None:
......@@ -146,14 +156,20 @@ class BAShapeDataset(DGLBuiltinDataset):
break
g.add_edges(u, v)
g.ndata['label'] = F.tensor(node_labels, F.int64)
g.ndata['feat'] = F.ones((n, 1), F.float32, F.cpu())
g.ndata["label"] = F.tensor(node_labels, F.int64)
g.ndata["feat"] = F.ones((n, 1), F.float32, F.cpu())
self._graph = reorder_graph(
g, node_permute_algo='rcmk', edge_permute_algo='dst', store_ids=False)
g,
node_permute_algo="rcmk",
edge_permute_algo="dst",
store_ids=False,
)
@property
def graph_path(self):
return os.path.join(self.save_path, '{}_dgl_graph.bin'.format(self.name))
return os.path.join(
self.save_path, "{}_dgl_graph.bin".format(self.name)
)
def save(self):
save_graphs(str(self.graph_path), self._graph)
......@@ -179,6 +195,7 @@ class BAShapeDataset(DGLBuiltinDataset):
def num_classes(self):
return 4
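A hedged usage sketch for this synthetic dataset, assuming BAShapeDataset is exported from dgl.data like the other built-in datasets:

from dgl.data import BAShapeDataset

dataset = BAShapeDataset(num_motifs=80, seed=0)
g = dataset[0]                   # a single node-classification graph
labels = g.ndata['label']        # 4 classes: base nodes plus 3 house-motif roles
feats = g.ndata['feat']          # constant features of shape (num_nodes, 1)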
class BACommunityDataset(DGLBuiltinDataset):
r"""BA-COMMUNITY dataset from `GNNExplainer: Generating Explanations for Graph Neural Networks
<https://arxiv.org/abs/1903.03894>`__
......@@ -241,29 +258,34 @@ class BACommunityDataset(DGLBuiltinDataset):
>>> label = g.ndata['label']
>>> feat = g.ndata['feat']
"""
def __init__(self,
num_base_nodes=300,
num_base_edges_per_node=4,
num_motifs=80,
perturb_ratio=0.01,
num_inter_edges=350,
seed=None,
raw_dir=None,
force_reload=False,
verbose=True,
transform=None):
def __init__(
self,
num_base_nodes=300,
num_base_edges_per_node=4,
num_motifs=80,
perturb_ratio=0.01,
num_inter_edges=350,
seed=None,
raw_dir=None,
force_reload=False,
verbose=True,
transform=None,
):
self.num_base_nodes = num_base_nodes
self.num_base_edges_per_node = num_base_edges_per_node
self.num_motifs = num_motifs
self.perturb_ratio = perturb_ratio
self.num_inter_edges = num_inter_edges
self.seed = seed
super(BACommunityDataset, self).__init__(name='BA-COMMUNITY',
url=None,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform)
super(BACommunityDataset, self).__init__(
name="BA-COMMUNITY",
url=None,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
def process(self):
if self.seed is not None:
......@@ -271,47 +293,65 @@ class BACommunityDataset(DGLBuiltinDataset):
np.random.seed(self.seed)
# Construct two BA-SHAPES graphs
g1 = BAShapeDataset(self.num_base_nodes,
self.num_base_edges_per_node,
self.num_motifs,
self.perturb_ratio,
force_reload=True,
verbose=False)[0]
g2 = BAShapeDataset(self.num_base_nodes,
self.num_base_edges_per_node,
self.num_motifs,
self.perturb_ratio,
force_reload=True,
verbose=False)[0]
g1 = BAShapeDataset(
self.num_base_nodes,
self.num_base_edges_per_node,
self.num_motifs,
self.perturb_ratio,
force_reload=True,
verbose=False,
)[0]
g2 = BAShapeDataset(
self.num_base_nodes,
self.num_base_edges_per_node,
self.num_motifs,
self.perturb_ratio,
force_reload=True,
verbose=False,
)[0]
# Join them and randomly add edges between them
g = batch([g1, g2])
num_nodes = g.num_nodes() // 2
src = np.random.randint(0, num_nodes, (self.num_inter_edges,))
dst = np.random.randint(num_nodes, 2 * num_nodes, (self.num_inter_edges,))
dst = np.random.randint(
num_nodes, 2 * num_nodes, (self.num_inter_edges,)
)
src = F.astype(F.zerocopy_from_numpy(src), g.idtype)
dst = F.astype(F.zerocopy_from_numpy(dst), g.idtype)
g.add_edges(src, dst)
g.ndata['label'] = F.cat([g1.ndata['label'], g2.ndata['label'] + 4], dim=0)
g.ndata["label"] = F.cat(
[g1.ndata["label"], g2.ndata["label"] + 4], dim=0
)
# feature generation
random_mu = [0.0] * 8
random_sigma = [1.0] * 8
mu_1, sigma_1 = np.array([-1.0] * 2 + random_mu), np.array([0.5] * 2 + random_sigma)
mu_1, sigma_1 = np.array([-1.0] * 2 + random_mu), np.array(
[0.5] * 2 + random_sigma
)
feat1 = np.random.multivariate_normal(mu_1, np.diag(sigma_1), num_nodes)
mu_2, sigma_2 = np.array([1.0] * 2 + random_mu), np.array([0.5] * 2 + random_sigma)
mu_2, sigma_2 = np.array([1.0] * 2 + random_mu), np.array(
[0.5] * 2 + random_sigma
)
feat2 = np.random.multivariate_normal(mu_2, np.diag(sigma_2), num_nodes)
feat = np.concatenate([feat1, feat2])
g.ndata['feat'] = F.zerocopy_from_numpy(feat)
g.ndata["feat"] = F.zerocopy_from_numpy(feat)
self._graph = reorder_graph(
g, node_permute_algo='rcmk', edge_permute_algo='dst', store_ids=False)
g,
node_permute_algo="rcmk",
edge_permute_algo="dst",
store_ids=False,
)
@property
def graph_path(self):
return os.path.join(self.save_path, '{}_dgl_graph.bin'.format(self.name))
return os.path.join(
self.save_path, "{}_dgl_graph.bin".format(self.name)
)
def save(self):
save_graphs(str(self.graph_path), self._graph)
......@@ -337,6 +377,7 @@ class BACommunityDataset(DGLBuiltinDataset):
def num_classes(self):
return 8
class TreeCycleDataset(DGLBuiltinDataset):
r"""TREE-CYCLES dataset from `GNNExplainer: Generating Explanations for Graph Neural Networks
<https://arxiv.org/abs/1903.03894>`__
......@@ -392,27 +433,32 @@ class TreeCycleDataset(DGLBuiltinDataset):
>>> label = g.ndata['label']
>>> feat = g.ndata['feat']
"""
def __init__(self,
tree_height=8,
num_motifs=60,
cycle_size=6,
perturb_ratio=0.01,
seed=None,
raw_dir=None,
force_reload=False,
verbose=True,
transform=None):
def __init__(
self,
tree_height=8,
num_motifs=60,
cycle_size=6,
perturb_ratio=0.01,
seed=None,
raw_dir=None,
force_reload=False,
verbose=True,
transform=None,
):
self.tree_height = tree_height
self.num_motifs = num_motifs
self.cycle_size = cycle_size
self.perturb_ratio = perturb_ratio
self.seed = seed
super(TreeCycleDataset, self).__init__(name='TREE-CYCLES',
url=None,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform)
super(TreeCycleDataset, self).__init__(
name="TREE-CYCLES",
url=None,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
def process(self):
if self.seed is not None:
......@@ -457,8 +503,9 @@ class TreeCycleDataset(DGLBuiltinDataset):
# Perturb the graph by adding non-self-loop random edges
num_real_edges = g.num_edges()
max_ratio = (n * (n - 1) - num_real_edges) / num_real_edges
assert self.perturb_ratio <= max_ratio, \
'perturb_ratio cannot exceed {:.4f}'.format(max_ratio)
assert (
self.perturb_ratio <= max_ratio
), "perturb_ratio cannot exceed {:.4f}".format(max_ratio)
num_random_edges = int(num_real_edges * self.perturb_ratio)
for _ in range(num_random_edges):
......@@ -469,14 +516,20 @@ class TreeCycleDataset(DGLBuiltinDataset):
break
g.add_edges(u, v)
g.ndata['label'] = F.tensor(node_labels, F.int64)
g.ndata['feat'] = F.ones((n, 1), F.float32, F.cpu())
g.ndata["label"] = F.tensor(node_labels, F.int64)
g.ndata["feat"] = F.ones((n, 1), F.float32, F.cpu())
self._graph = reorder_graph(
g, node_permute_algo='rcmk', edge_permute_algo='dst', store_ids=False)
g,
node_permute_algo="rcmk",
edge_permute_algo="dst",
store_ids=False,
)
@property
def graph_path(self):
return os.path.join(self.save_path, '{}_dgl_graph.bin'.format(self.name))
return os.path.join(
self.save_path, "{}_dgl_graph.bin".format(self.name)
)
def save(self):
save_graphs(str(self.graph_path), self._graph)
......@@ -502,6 +555,7 @@ class TreeCycleDataset(DGLBuiltinDataset):
def num_classes(self):
return 2
class TreeGridDataset(DGLBuiltinDataset):
r"""TREE-GRIDS dataset from `GNNExplainer: Generating Explanations for Graph Neural Networks
<https://arxiv.org/abs/1903.03894>`__
......@@ -557,27 +611,32 @@ class TreeGridDataset(DGLBuiltinDataset):
>>> label = g.ndata['label']
>>> feat = g.ndata['feat']
"""
def __init__(self,
tree_height=8,
num_motifs=80,
grid_size=3,
perturb_ratio=0.1,
seed=None,
raw_dir=None,
force_reload=False,
verbose=True,
transform=None):
def __init__(
self,
tree_height=8,
num_motifs=80,
grid_size=3,
perturb_ratio=0.1,
seed=None,
raw_dir=None,
force_reload=False,
verbose=True,
transform=None,
):
self.tree_height = tree_height
self.num_motifs = num_motifs
self.grid_size = grid_size
self.perturb_ratio = perturb_ratio
self.seed = seed
super(TreeGridDataset, self).__init__(name='TREE-GRIDS',
url=None,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform)
super(TreeGridDataset, self).__init__(
name="TREE-GRIDS",
url=None,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
def process(self):
if self.seed is not None:
......@@ -619,8 +678,9 @@ class TreeGridDataset(DGLBuiltinDataset):
# Perturb the graph by adding non-self-loop random edges
num_real_edges = g.num_edges()
max_ratio = (n * (n - 1) - num_real_edges) / num_real_edges
assert self.perturb_ratio <= max_ratio, \
'perturb_ratio cannot exceed {:.4f}'.format(max_ratio)
assert (
self.perturb_ratio <= max_ratio
), "perturb_ratio cannot exceed {:.4f}".format(max_ratio)
num_random_edges = int(num_real_edges * self.perturb_ratio)
for _ in range(num_random_edges):
......@@ -631,14 +691,20 @@ class TreeGridDataset(DGLBuiltinDataset):
break
g.add_edges(u, v)
g.ndata['label'] = F.tensor(node_labels, F.int64)
g.ndata['feat'] = F.ones((n, 1), F.float32, F.cpu())
g.ndata["label"] = F.tensor(node_labels, F.int64)
g.ndata["feat"] = F.ones((n, 1), F.float32, F.cpu())
self._graph = reorder_graph(
g, node_permute_algo='rcmk', edge_permute_algo='dst', store_ids=False)
g,
node_permute_algo="rcmk",
edge_permute_algo="dst",
store_ids=False,
)
@property
def graph_path(self):
return os.path.join(self.save_path, '{}_dgl_graph.bin'.format(self.name))
return os.path.join(
self.save_path, "{}_dgl_graph.bin".format(self.name)
)
def save(self):
save_graphs(str(self.graph_path), self._graph)
......@@ -664,6 +730,7 @@ class TreeGridDataset(DGLBuiltinDataset):
def num_classes(self):
return 2
class BA2MotifDataset(DGLBuiltinDataset):
r"""BA-2motifs dataset from `Parameterized Explainer for Graph Neural Network
<https://arxiv.org/abs/2011.04573>`__
......@@ -705,26 +772,27 @@ class BA2MotifDataset(DGLBuiltinDataset):
>>> g, label = dataset[0]
>>> feat = g.ndata['feat']
"""
def __init__(self,
raw_dir=None,
force_reload=False,
verbose=True,
transform=None):
super(BA2MotifDataset, self).__init__(name='BA-2motifs',
url=_get_dgl_url('dataset/BA-2motif.pkl'),
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform)
def __init__(
self, raw_dir=None, force_reload=False, verbose=True, transform=None
):
super(BA2MotifDataset, self).__init__(
name="BA-2motifs",
url=_get_dgl_url("dataset/BA-2motif.pkl"),
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
def download(self):
r""" Automatically download data."""
file_path = os.path.join(self.raw_dir, self.name + '.pkl')
r"""Automatically download data."""
file_path = os.path.join(self.raw_dir, self.name + ".pkl")
download(self.url, path=file_path)
def process(self):
file_path = os.path.join(self.raw_dir, self.name + '.pkl')
with open(file_path, 'rb') as f:
file_path = os.path.join(self.raw_dir, self.name + ".pkl")
with open(file_path, "rb") as f:
adjs, features, labels = pickle.load(f)
self.graphs = []
......@@ -732,15 +800,17 @@ class BA2MotifDataset(DGLBuiltinDataset):
for i in range(len(adjs)):
g = graph(adjs[i].nonzero())
g.ndata['feat'] = F.zerocopy_from_numpy(features[i])
g.ndata["feat"] = F.zerocopy_from_numpy(features[i])
self.graphs.append(g)
@property
def graph_path(self):
return os.path.join(self.save_path, '{}_dgl_graph.bin'.format(self.name))
return os.path.join(
self.save_path, "{}_dgl_graph.bin".format(self.name)
)
def save(self):
label_dict = {'labels': self.labels}
label_dict = {"labels": self.labels}
save_graphs(str(self.graph_path), self.graphs, label_dict)
def has_cache(self):
......@@ -748,7 +818,7 @@ class BA2MotifDataset(DGLBuiltinDataset):
def load(self):
self.graphs, label_dict = load_graphs(str(self.graph_path))
self.labels = label_dict['labels']
self.labels = label_dict["labels"]
def __getitem__(self, idx):
g = self.graphs[idx]
......
"""For Tensor Serialization"""
from __future__ import absolute_import
from ..ndarray import NDArray
from .._ffi.function import _init_api
from .. import backend as F
from .._ffi.function import _init_api
from ..ndarray import NDArray
__all__ = ['save_tensors', "load_tensors"]
__all__ = ["save_tensors", "load_tensors"]
_init_api("dgl.data.tensor_serialize")
......@@ -12,11 +13,11 @@ _init_api("dgl.data.tensor_serialize")
def save_tensors(filename, tensor_dict):
"""
Save dict of tensors to file
Parameters
----------
filename : str
File name to store dict of tensors.
File name to store dict of tensors.
tensor_dict: dict of dgl NDArray or backend tensor
Python dict using string as key and tensor as value
......@@ -36,19 +37,20 @@ def save_tensors(filename, tensor_dict):
nd_dict[key] = value
else:
raise Exception(
"Dict value has to be backend tensor or dgl ndarray")
"Dict value has to be backend tensor or dgl ndarray"
)
return _CAPI_SaveNDArrayDict(filename, nd_dict, is_empty_dict)
def load_tensors(filename, return_dgl_ndarray=False):
"""
load dict of tensors from file
Parameters
----------
filename : str
File name to load dict of tensors.
File name to load dict of tensors.
return_dgl_ndarray: bool
Whether return dict of dgl NDArrays or backend tensors
......
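For reference, a round-trip sketch of the two serialization helpers defined in this file; the filename and the PyTorch tensor are arbitrary stand-ins:

import torch
from dgl.data.utils import save_tensors, load_tensors

save_tensors("feats.dgl", {"feat": torch.randn(4, 3)})
tensors = load_tensors("feats.dgl")                              # backend tensors
nd_arrays = load_tensors("feats.dgl", return_dgl_ndarray=True)   # dgl NDArrays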
"""Dataset utilities."""
from __future__ import absolute_import
import errno
import hashlib
import os
import pickle
import sys
import hashlib
import warnings
import requests
import pickle
import errno
import numpy as np
import pickle
import errno
from .graph_serialize import save_graphs, load_graphs, load_labels
from .tensor_serialize import save_tensors, load_tensors
import numpy as np
import requests
from .. import backend as F
__all__ = ['loadtxt','download', 'check_sha1', 'extract_archive',
'get_download_dir', 'Subset', 'split_dataset', 'save_graphs',
'load_graphs', 'load_labels', 'save_tensors', 'load_tensors',
'add_nodepred_split',
from .graph_serialize import load_graphs, load_labels, save_graphs
from .tensor_serialize import load_tensors, save_tensors
__all__ = [
"loadtxt",
"download",
"check_sha1",
"extract_archive",
"get_download_dir",
"Subset",
"split_dataset",
"save_graphs",
"load_graphs",
"load_labels",
"save_tensors",
"load_tensors",
"add_nodepred_split",
]
def loadtxt(path, delimiter, dtype=None):
try:
import pandas as pd
df = pd.read_csv(path, delimiter=delimiter, header=None)
return df.values
except ImportError:
warnings.warn("Pandas is not installed, now using numpy.loadtxt to load data, "
"which could be extremely slow. Accelerate by installing pandas")
warnings.warn(
"Pandas is not installed, now using numpy.loadtxt to load data, "
"which could be extremely slow. Accelerate by installing pandas"
)
return np.loadtxt(path, delimiter=delimiter)
def _get_dgl_url(file_url):
"""Get DGL online url for download."""
dgl_repo_url = 'https://data.dgl.ai/'
repo_url = os.environ.get('DGL_REPO', dgl_repo_url)
if repo_url[-1] != '/':
repo_url = repo_url + '/'
dgl_repo_url = "https://data.dgl.ai/"
repo_url = os.environ.get("DGL_REPO", dgl_repo_url)
if repo_url[-1] != "/":
repo_url = repo_url + "/"
return repo_url + file_url
......@@ -70,23 +82,35 @@ def split_dataset(dataset, frac_list=None, shuffle=False, random_state=None):
Subsets for training, validation and test.
"""
from itertools import accumulate
if frac_list is None:
frac_list = [0.8, 0.1, 0.1]
frac_list = np.asarray(frac_list)
assert np.allclose(np.sum(frac_list), 1.), \
'Expect frac_list sum to 1, got {:.4f}'.format(np.sum(frac_list))
assert np.allclose(
np.sum(frac_list), 1.0
), "Expect frac_list sum to 1, got {:.4f}".format(np.sum(frac_list))
num_data = len(dataset)
lengths = (num_data * frac_list).astype(int)
lengths[-1] = num_data - np.sum(lengths[:-1])
if shuffle:
indices = np.random.RandomState(
seed=random_state).permutation(num_data)
indices = np.random.RandomState(seed=random_state).permutation(num_data)
else:
indices = np.arange(num_data)
return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(accumulate(lengths), lengths)]
def download(url, path=None, overwrite=True, sha1_hash=None, retries=5, verify_ssl=True, log=True):
return [
Subset(dataset, indices[offset - length : offset])
for offset, length in zip(accumulate(lengths), lengths)
]
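As a quick illustration of the helper above, any indexable collection with __len__ can stand in for a dataset:

from dgl.data.utils import split_dataset

data = list(range(100))
train_set, val_set, test_set = split_dataset(
    data, frac_list=[0.8, 0.1, 0.1], shuffle=True, random_state=42
)
print(len(train_set), len(val_set), len(test_set))  # 80 10 10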
def download(
url,
path=None,
overwrite=True,
sha1_hash=None,
retries=5,
verify_ssl=True,
log=True,
):
"""Download a given URL.
Codes borrowed from mxnet/gluon/utils.py
......@@ -117,45 +141,54 @@ def download(url, path=None, overwrite=True, sha1_hash=None, retries=5, verify_s
The file path of the downloaded file.
"""
if path is None:
fname = url.split('/')[-1]
fname = url.split("/")[-1]
# Empty filenames are invalid
assert fname, 'Can\'t construct file-name from this URL. ' \
'Please set the `path` option manually.'
assert fname, (
"Can't construct file-name from this URL. "
"Please set the `path` option manually."
)
else:
path = os.path.expanduser(path)
if os.path.isdir(path):
fname = os.path.join(path, url.split('/')[-1])
fname = os.path.join(path, url.split("/")[-1])
else:
fname = path
assert retries >= 0, "Number of retries should be at least 0"
if not verify_ssl:
warnings.warn(
'Unverified HTTPS request is being made (verify_ssl=False). '
'Adding certificate verification is strongly advised.')
if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
"Unverified HTTPS request is being made (verify_ssl=False). "
"Adding certificate verification is strongly advised."
)
if (
overwrite
or not os.path.exists(fname)
or (sha1_hash and not check_sha1(fname, sha1_hash))
):
dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
if not os.path.exists(dirname):
os.makedirs(dirname)
while retries+1 > 0:
while retries + 1 > 0:
# Disable pylint's too-broad Exception warning
# pylint: disable=W0703
try:
if log:
print('Downloading %s from %s...' % (fname, url))
print("Downloading %s from %s..." % (fname, url))
r = requests.get(url, stream=True, verify=verify_ssl)
if r.status_code != 200:
raise RuntimeError("Failed downloading url %s" % url)
with open(fname, 'wb') as f:
with open(fname, "wb") as f:
for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
f.write(chunk)
if sha1_hash and not check_sha1(fname, sha1_hash):
raise UserWarning('File {} is downloaded but the content hash does not match.'
' The repo may be outdated or download may be incomplete. '
'If the "repo_url" is overridden, consider switching to '
'the default repo.'.format(fname))
raise UserWarning(
"File {} is downloaded but the content hash does not match."
" The repo may be outdated or download may be incomplete. "
'If the "repo_url" is overridden, consider switching to '
"the default repo.".format(fname)
)
break
except Exception as e:
retries -= 1
......@@ -163,8 +196,11 @@ def download(url, path=None, overwrite=True, sha1_hash=None, retries=5, verify_s
raise e
else:
if log:
print("download failed, retrying, {} attempt{} left"
.format(retries, 's' if retries > 1 else ''))
print(
"download failed, retrying, {} attempt{} left".format(
retries, "s" if retries > 1 else ""
)
)
return fname
......@@ -187,7 +223,7 @@ def check_sha1(filename, sha1_hash):
Whether the file content matches the expected hash.
"""
sha1 = hashlib.sha1()
with open(filename, 'rb') as f:
with open(filename, "rb") as f:
while True:
data = f.read(1048576)
if not data:
......@@ -212,24 +248,31 @@ def extract_archive(file, target_dir, overwrite=False):
"""
if os.path.exists(target_dir) and not overwrite:
return
print('Extracting file to {}'.format(target_dir))
if file.endswith('.tar.gz') or file.endswith('.tar') or file.endswith('.tgz'):
print("Extracting file to {}".format(target_dir))
if (
file.endswith(".tar.gz")
or file.endswith(".tar")
or file.endswith(".tgz")
):
import tarfile
with tarfile.open(file, 'r') as archive:
with tarfile.open(file, "r") as archive:
archive.extractall(path=target_dir)
elif file.endswith('.gz'):
elif file.endswith(".gz"):
import gzip
import shutil
with gzip.open(file, 'rb') as f_in:
with gzip.open(file, "rb") as f_in:
target_file = os.path.join(target_dir, os.path.basename(file)[:-3])
with open(target_file, 'wb') as f_out:
with open(target_file, "wb") as f_out:
shutil.copyfileobj(f_in, f_out)
elif file.endswith('.zip'):
elif file.endswith(".zip"):
import zipfile
with zipfile.ZipFile(file, 'r') as archive:
with zipfile.ZipFile(file, "r") as archive:
archive.extractall(path=target_dir)
else:
raise Exception('Unrecognized file type: ' + file)
raise Exception("Unrecognized file type: " + file)
def get_download_dir():
......@@ -240,12 +283,13 @@ def get_download_dir():
dirname : str
Path to the download directory
"""
default_dir = os.path.join(os.path.expanduser('~'), '.dgl')
dirname = os.environ.get('DGL_DOWNLOAD_DIR', default_dir)
default_dir = os.path.join(os.path.expanduser("~"), ".dgl")
dirname = os.environ.get("DGL_DOWNLOAD_DIR", default_dir)
if not os.path.exists(dirname):
os.makedirs(dirname)
return dirname
def makedirs(path):
try:
os.makedirs(os.path.expanduser(os.path.normpath(path)))
......@@ -253,8 +297,9 @@ def makedirs(path):
if e.errno != errno.EEXIST and os.path.isdir(path):
raise e
def save_info(path, info):
""" Save dataset related information into disk.
"""Save dataset related information into disk.
Parameters
----------
......@@ -263,12 +308,12 @@ def save_info(path, info):
info : dict
A python dict storing information to save on disk.
"""
with open(path, "wb" ) as pf:
with open(path, "wb") as pf:
pickle.dump(info, pf)
def load_info(path):
""" Load dataset related information from disk.
"""Load dataset related information from disk.
Parameters
----------
......@@ -284,16 +329,28 @@ def load_info(path):
info = pickle.load(pf)
return info
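A minimal round trip for the two pickle helpers above; the filename and contents are arbitrary:

from dgl.data.utils import save_info, load_info

save_info("meta.pkl", {"num_classes": 4, "split": "train"})
info = load_info("meta.pkl")
assert info["num_classes"] == 4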
def deprecate_property(old, new):
warnings.warn('Property {} will be deprecated, please use {} instead.'.format(old, new))
warnings.warn(
"Property {} will be deprecated, please use {} instead.".format(
old, new
)
)
def deprecate_function(old, new):
warnings.warn('Function {} will be deprecated, please use {} instead.'.format(old, new))
warnings.warn(
"Function {} will be deprecated, please use {} instead.".format(
old, new
)
)
def deprecate_class(old, new):
warnings.warn('Class {} will be deprecated, please use {} instead.'.format(old, new))
warnings.warn(
"Class {} will be deprecated, please use {} instead.".format(old, new)
)
def idx2mask(idx, len):
"""Create mask."""
......@@ -301,6 +358,7 @@ def idx2mask(idx, len):
mask[idx] = 1
return mask
def generate_mask_tensor(mask):
"""Generate mask tensor according to different backend
For torch and tensorflow, it will create a bool tensor
......@@ -310,12 +368,14 @@ def generate_mask_tensor(mask):
mask: numpy ndarray
input mask tensor
"""
assert isinstance(mask, np.ndarray), "input for generate_mask_tensor" \
"should be an numpy ndarray"
if F.backend_name == 'mxnet':
return F.tensor(mask, dtype=F.data_type_dict['float32'])
assert isinstance(mask, np.ndarray), (
"input for generate_mask_tensor" "should be an numpy ndarray"
)
if F.backend_name == "mxnet":
return F.tensor(mask, dtype=F.data_type_dict["float32"])
else:
return F.tensor(mask, dtype=F.data_type_dict['bool'])
return F.tensor(mask, dtype=F.data_type_dict["bool"])
class Subset(object):
"""Subset of a dataset at specified indices
......@@ -354,6 +414,7 @@ class Subset(object):
"""
return len(self.indices)
def add_nodepred_split(dataset, ratio, ntype=None):
"""Split the given dataset into training, validation and test sets for
transductive node prediction task.
......@@ -384,16 +445,24 @@ def add_nodepred_split(dataset, ratio, ntype=None):
True
"""
if len(ratio) != 3:
raise ValueError(f'Split ratio must be a float triplet but got {ratio}.')
raise ValueError(
f"Split ratio must be a float triplet but got {ratio}."
)
for i in range(len(dataset)):
g = dataset[i]
n = g.num_nodes(ntype)
idx = np.arange(0, n)
np.random.shuffle(idx)
n_train, n_val, n_test = int(n * ratio[0]), int(n * ratio[1]), int(n * ratio[2])
n_train, n_val, n_test = (
int(n * ratio[0]),
int(n * ratio[1]),
int(n * ratio[2]),
)
train_mask = generate_mask_tensor(idx2mask(idx[:n_train], n))
val_mask = generate_mask_tensor(idx2mask(idx[n_train:n_train + n_val], n))
test_mask = generate_mask_tensor(idx2mask(idx[n_train + n_val:], n))
g.nodes[ntype].data['train_mask'] = train_mask
g.nodes[ntype].data['val_mask'] = val_mask
g.nodes[ntype].data['test_mask'] = test_mask
val_mask = generate_mask_tensor(
idx2mask(idx[n_train : n_train + n_val], n)
)
test_mask = generate_mask_tensor(idx2mask(idx[n_train + n_val :], n))
g.nodes[ntype].data["train_mask"] = train_mask
g.nodes[ntype].data["val_mask"] = val_mask
g.nodes[ntype].data["test_mask"] = test_mask
"""Wiki-CS Dataset"""
import itertools
import os
import json
import os
import numpy as np
from .. import backend as F
from ..convert import graph
from ..transforms import reorder_graph, to_bidirected
from .dgl_dataset import DGLBuiltinDataset
from ..transforms import to_bidirected, reorder_graph
from .utils import generate_mask_tensor, load_graphs, save_graphs, _get_dgl_url
from .utils import _get_dgl_url, generate_mask_tensor, load_graphs, save_graphs
class WikiCSDataset(DGLBuiltinDataset):
......@@ -73,55 +75,64 @@ class WikiCSDataset(DGLBuiltinDataset):
(11701,)
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
_url = _get_dgl_url('dataset/wiki_cs.zip')
super(WikiCSDataset, self).__init__(name='wiki_cs',
raw_dir=raw_dir,
url=_url,
force_reload=force_reload,
verbose=verbose,
transform=transform)
def __init__(
self, raw_dir=None, force_reload=False, verbose=False, transform=None
):
_url = _get_dgl_url("dataset/wiki_cs.zip")
super(WikiCSDataset, self).__init__(
name="wiki_cs",
raw_dir=raw_dir,
url=_url,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
def process(self):
"""process raw data to graph, labels and masks"""
with open(os.path.join(self.raw_path, 'data.json')) as f:
with open(os.path.join(self.raw_path, "data.json")) as f:
data = json.load(f)
features = F.tensor(np.array(data['features']), dtype=F.float32)
labels = F.tensor(np.array(data['labels']), dtype=F.int64)
features = F.tensor(np.array(data["features"]), dtype=F.float32)
labels = F.tensor(np.array(data["labels"]), dtype=F.int64)
train_masks = np.array(data['train_masks'], dtype=bool).T
val_masks = np.array(data['val_masks'], dtype=bool).T
stopping_masks = np.array(data['stopping_masks'], dtype=bool).T
test_mask = np.array(data['test_mask'], dtype=bool)
train_masks = np.array(data["train_masks"], dtype=bool).T
val_masks = np.array(data["val_masks"], dtype=bool).T
stopping_masks = np.array(data["stopping_masks"], dtype=bool).T
test_mask = np.array(data["test_mask"], dtype=bool)
edges = [[(i, j) for j in js] for i, js in enumerate(data['links'])]
edges = [[(i, j) for j in js] for i, js in enumerate(data["links"])]
edges = np.array(list(itertools.chain(*edges)))
src, dst = edges[:, 0], edges[:, 1]
g = graph((src, dst))
g = to_bidirected(g)
g.ndata['feat'] = features
g.ndata['label'] = labels
g.ndata['train_mask'] = generate_mask_tensor(train_masks)
g.ndata['val_mask'] = generate_mask_tensor(val_masks)
g.ndata['stopping_mask'] = generate_mask_tensor(stopping_masks)
g.ndata['test_mask'] = generate_mask_tensor(test_mask)
g.ndata["feat"] = features
g.ndata["label"] = labels
g.ndata["train_mask"] = generate_mask_tensor(train_masks)
g.ndata["val_mask"] = generate_mask_tensor(val_masks)
g.ndata["stopping_mask"] = generate_mask_tensor(stopping_masks)
g.ndata["test_mask"] = generate_mask_tensor(test_mask)
g = reorder_graph(g, node_permute_algo='rcmk', edge_permute_algo='dst', store_ids=False)
g = reorder_graph(
g,
node_permute_algo="rcmk",
edge_permute_algo="dst",
store_ids=False,
)
self._graph = g
def has_cache(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin')
graph_path = os.path.join(self.save_path, "dgl_graph.bin")
return os.path.exists(graph_path)
def save(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin')
graph_path = os.path.join(self.save_path, "dgl_graph.bin")
save_graphs(graph_path, self._graph)
def load(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin')
graph_path = os.path.join(self.save_path, "dgl_graph.bin")
g, _ = load_graphs(graph_path)
self._graph = g[0]
......@@ -134,7 +145,7 @@ class WikiCSDataset(DGLBuiltinDataset):
return 1
def __getitem__(self, idx):
r""" Get graph object
r"""Get graph object
Parameters
----------
......
"""Yelp Dataset"""
import os
import json
import os
import numpy as np
import scipy.sparse as sp
from .. import backend as F
from ..convert import from_scipy
from ..transforms import reorder_graph
from .dgl_dataset import DGLBuiltinDataset
from .utils import generate_mask_tensor, load_graphs, save_graphs, _get_dgl_url
from .utils import _get_dgl_url, generate_mask_tensor, load_graphs, save_graphs
class YelpDataset(DGLBuiltinDataset):
r"""Yelp dataset for node classification from `GraphSAINT: Graph Sampling Based Inductive
Learning Method <https://arxiv.org/abs/1907.04931>`_
The task of this dataset is categorizing types of businesses based on customer reviews and
friendship.
Yelp dataset statistics:
- Nodes: 716,847
- Edges: 13,954,819
- Number of classes: 100 (Multi-class)
- Node feature size: 300
Parameters
----------
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset.
Default: False
verbose : bool
Whether to print out progress information.
Default: False
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
reorder : bool
Whether to reorder the graph using :func:`~dgl.reorder_graph`.
Default: False.
Attributes
----------
num_classes : int
Number of node classes
Examples
--------
>>> dataset = YelpDataset()
>>> dataset.num_classes
100
>>> g = dataset[0]
>>> # get node feature
>>> feat = g.ndata['feat']
>>> # get node labels
>>> labels = g.ndata['label']
>>> # get data split
>>> train_mask = g.ndata['train_mask']
>>> val_mask = g.ndata['val_mask']
>>> test_mask = g.ndata['test_mask']
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None,
reorder=False):
_url = _get_dgl_url('dataset/yelp.zip')
Learning Method <https://arxiv.org/abs/1907.04931>`_
The task of this dataset is categorizing types of businesses based on customer reviews and
friendship.
Yelp dataset statistics:
- Nodes: 716,847
- Edges: 13,954,819
- Number of classes: 100 (Multi-class)
- Node feature size: 300
Parameters
----------
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset.
Default: False
verbose : bool
Whether to print out progress information.
Default: False
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
reorder : bool
Whether to reorder the graph using :func:`~dgl.reorder_graph`.
Default: False.
Attributes
----------
num_classes : int
Number of node classes
Examples
--------
>>> dataset = YelpDataset()
>>> dataset.num_classes
100
>>> g = dataset[0]
>>> # get node feature
>>> feat = g.ndata['feat']
>>> # get node labels
>>> labels = g.ndata['label']
>>> # get data split
>>> train_mask = g.ndata['train_mask']
>>> val_mask = g.ndata['val_mask']
>>> test_mask = g.ndata['test_mask']
"""
def __init__(
self,
raw_dir=None,
force_reload=False,
verbose=False,
transform=None,
reorder=False,
):
_url = _get_dgl_url("dataset/yelp.zip")
self._reorder = reorder
super(YelpDataset, self).__init__(name='yelp',
raw_dir=raw_dir,
url=_url,
force_reload=force_reload,
verbose=verbose,
transform=transform)
super(YelpDataset, self).__init__(
name="yelp",
raw_dir=raw_dir,
url=_url,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
def process(self):
"""process raw data to graph, labels and masks"""
coo_adj = sp.load_npz(os.path.join(self.raw_path, "adj_full.npz"))
g = from_scipy(coo_adj)
features = np.load(os.path.join(self.raw_path, 'feats.npy'))
features = np.load(os.path.join(self.raw_path, "feats.npy"))
features = F.tensor(features, dtype=F.float32)
y = [-1] * features.shape[0]
with open(os.path.join(self.raw_path, 'class_map.json')) as f:
with open(os.path.join(self.raw_path, "class_map.json")) as f:
class_map = json.load(f)
for key, item in class_map.items():
y[int(key)] = item
labels = F.tensor(np.array(y), dtype=F.int64)
with open(os.path.join(self.raw_path, 'role.json')) as f:
with open(os.path.join(self.raw_path, "role.json")) as f:
role = json.load(f)
train_mask = np.zeros(features.shape[0], dtype=bool)
train_mask[role['tr']] = True
train_mask[role["tr"]] = True
val_mask = np.zeros(features.shape[0], dtype=bool)
val_mask[role['va']] = True
val_mask[role["va"]] = True
test_mask = np.zeros(features.shape[0], dtype=bool)
test_mask[role['te']] = True
test_mask[role["te"]] = True
g.ndata['feat'] = features
g.ndata['label'] = labels
g.ndata['train_mask'] = generate_mask_tensor(train_mask)
g.ndata['val_mask'] = generate_mask_tensor(val_mask)
g.ndata['test_mask'] = generate_mask_tensor(test_mask)
g.ndata["feat"] = features
g.ndata["label"] = labels
g.ndata["train_mask"] = generate_mask_tensor(train_mask)
g.ndata["val_mask"] = generate_mask_tensor(val_mask)
g.ndata["test_mask"] = generate_mask_tensor(test_mask)
if self._reorder:
self._graph = reorder_graph(
g, node_permute_algo='rcmk', edge_permute_algo='dst', store_ids=False)
g,
node_permute_algo="rcmk",
edge_permute_algo="dst",
store_ids=False,
)
else:
self._graph = g
def has_cache(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin')
graph_path = os.path.join(self.save_path, "dgl_graph.bin")
return os.path.exists(graph_path)
def save(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin')
graph_path = os.path.join(self.save_path, "dgl_graph.bin")
save_graphs(graph_path, self._graph)
def load(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin')
graph_path = os.path.join(self.save_path, "dgl_graph.bin")
g, _ = load_graphs(graph_path)
self._graph = g[0]
......@@ -136,7 +150,7 @@ class YelpDataset(DGLBuiltinDataset):
return 1
def __getitem__(self, idx):
r""" Get graph object
r"""Get graph object
Parameters
----------
......
"""Package for dataloaders and samplers."""
from .. import backend as F
from .neighbor_sampler import *
from . import negative_sampler
from .base import *
from .cluster_gcn import *
from .graphsaint import *
from .neighbor_sampler import *
from .shadow import *
from .base import *
from . import negative_sampler
if F.get_preferred_backend() == 'pytorch':
if F.get_preferred_backend() == "pytorch":
from .dataloader import *
from .dist_dataloader import *
"""Cluster-GCN samplers."""
import os
import pickle
import numpy as np
from .. import backend as F
from ..base import DGLError
from ..partition import metis_partition_assignment
from .base import set_node_lazy_features, set_edge_lazy_features, Sampler
from .base import Sampler, set_edge_lazy_features, set_node_lazy_features
class ClusterGCNSampler(Sampler):
"""Cluster sampler from `Cluster-GCN: An Efficient Algorithm for Training
......@@ -59,35 +61,60 @@ class ClusterGCNSampler(Sampler):
>>> for subg in dataloader:
... train_on(subg)
"""
def __init__(self, g, k, cache_path='cluster_gcn.pkl', balance_ntypes=None,
balance_edges=False, mode='k-way', prefetch_ndata=None,
prefetch_edata=None, output_device=None):
def __init__(
self,
g,
k,
cache_path="cluster_gcn.pkl",
balance_ntypes=None,
balance_edges=False,
mode="k-way",
prefetch_ndata=None,
prefetch_edata=None,
output_device=None,
):
super().__init__()
if os.path.exists(cache_path):
try:
with open(cache_path, 'rb') as f:
self.partition_offset, self.partition_node_ids = pickle.load(f)
with open(cache_path, "rb") as f:
(
self.partition_offset,
self.partition_node_ids,
) = pickle.load(f)
except (EOFError, TypeError, ValueError):
raise DGLError(
f'The contents in the cache file {cache_path} is invalid. '
f'Please remove the cache file {cache_path} or specify another path.')
f"The contents in the cache file {cache_path} is invalid. "
f"Please remove the cache file {cache_path} or specify another path."
)
if len(self.partition_offset) != k + 1:
raise DGLError(
f'Number of partitions in the cache does not match the value of k. '
f'Please remove the cache file {cache_path} or specify another path.')
f"Number of partitions in the cache does not match the value of k. "
f"Please remove the cache file {cache_path} or specify another path."
)
if len(self.partition_node_ids) != g.num_nodes():
raise DGLError(
f'Number of nodes in the cache does not match the given graph. '
f'Please remove the cache file {cache_path} or specify another path.')
f"Number of nodes in the cache does not match the given graph. "
f"Please remove the cache file {cache_path} or specify another path."
)
else:
partition_ids = metis_partition_assignment(
g, k, balance_ntypes=balance_ntypes, balance_edges=balance_edges, mode=mode)
g,
k,
balance_ntypes=balance_ntypes,
balance_edges=balance_edges,
mode=mode,
)
partition_ids = F.asnumpy(partition_ids)
partition_node_ids = np.argsort(partition_ids)
partition_size = F.zerocopy_from_numpy(np.bincount(partition_ids, minlength=k))
partition_offset = F.zerocopy_from_numpy(np.insert(np.cumsum(partition_size), 0, 0))
partition_size = F.zerocopy_from_numpy(
np.bincount(partition_ids, minlength=k)
)
partition_offset = F.zerocopy_from_numpy(
np.insert(np.cumsum(partition_size), 0, 0)
)
partition_node_ids = F.zerocopy_from_numpy(partition_node_ids)
with open(cache_path, 'wb') as f:
with open(cache_path, "wb") as f:
pickle.dump((partition_offset, partition_node_ids), f)
self.partition_offset = partition_offset
self.partition_node_ids = partition_node_ids
......@@ -96,7 +123,7 @@ class ClusterGCNSampler(Sampler):
self.prefetch_edata = prefetch_edata or []
self.output_device = output_device
def sample(self, g, partition_ids): # pylint: disable=arguments-differ
def sample(self, g, partition_ids): # pylint: disable=arguments-differ
"""Sampling function.
Parameters
......@@ -111,10 +138,18 @@ class ClusterGCNSampler(Sampler):
DGLGraph
The sampled subgraph.
"""
node_ids = F.cat([
self.partition_node_ids[self.partition_offset[i]:self.partition_offset[i+1]]
for i in F.asnumpy(partition_ids)], 0)
sg = g.subgraph(node_ids, relabel_nodes=True, output_device=self.output_device)
node_ids = F.cat(
[
self.partition_node_ids[
self.partition_offset[i] : self.partition_offset[i + 1]
]
for i in F.asnumpy(partition_ids)
],
0,
)
sg = g.subgraph(
node_ids, relabel_nodes=True, output_device=self.output_device
)
set_node_lazy_features(sg, self.prefetch_ndata)
set_edge_lazy_features(sg, self.prefetch_edata)
return sg
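A usage sketch completing the truncated docstring example above; the DataLoader iterates over partition IDs, which is what sample() consumes, and g is assumed to be an existing DGLGraph:

import torch
import dgl

num_parts = 100
sampler = dgl.dataloading.ClusterGCNSampler(g, num_parts)
dataloader = dgl.dataloading.DataLoader(
    g, torch.arange(num_parts), sampler,
    batch_size=20, shuffle=True, drop_last=False, num_workers=0,
)
for subg in dataloader:
    pass  # train_on(subg)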
"""GraphSAINT samplers."""
from ..base import DGLError
from ..random import choice
from ..sampling import random_walk, pack_traces
from .base import set_node_lazy_features, set_edge_lazy_features, Sampler
from ..sampling import pack_traces, random_walk
from .base import Sampler, set_edge_lazy_features, set_node_lazy_features
try:
import torch
except ImportError:
pass
class SAINTSampler(Sampler):
"""Random node/edge/walk sampler from
`GraphSAINT: Graph Sampling Based Inductive Learning Method
......@@ -65,18 +66,28 @@ class SAINTSampler(Sampler):
>>> for subg in dataloader:
... train_on(subg)
"""
def __init__(self, mode, budget, cache=True, prefetch_ndata=None,
prefetch_edata=None, output_device='cpu'):
def __init__(
self,
mode,
budget,
cache=True,
prefetch_ndata=None,
prefetch_edata=None,
output_device="cpu",
):
super().__init__()
self.budget = budget
if mode == 'node':
if mode == "node":
self.sampler = self.node_sampler
elif mode == 'edge':
elif mode == "edge":
self.sampler = self.edge_sampler
elif mode == 'walk':
elif mode == "walk":
self.sampler = self.walk_sampler
else:
raise DGLError(f"Expect mode to be 'node', 'edge' or 'walk', got {mode}.")
raise DGLError(
f"Expect mode to be 'node', 'edge' or 'walk', got {mode}."
)
self.cache = cache
self.prob = None
......@@ -95,8 +106,11 @@ class SAINTSampler(Sampler):
prob = g.out_degrees().float().clamp(min=1)
if self.cache:
self.prob = prob
return torch.multinomial(prob, num_samples=self.budget,
replacement=True).unique().type(g.idtype)
return (
torch.multinomial(prob, num_samples=self.budget, replacement=True)
.unique()
.type(g.idtype)
)
def edge_sampler(self, g):
"""Node ID sampler for random edge sampler"""
......@@ -107,11 +121,13 @@ class SAINTSampler(Sampler):
in_deg = g.in_degrees().float().clamp(min=1)
out_deg = g.out_degrees().float().clamp(min=1)
# We can reduce the sample space by half if graphs are always symmetric.
prob = 1. / in_deg[dst.long()] + 1. / out_deg[src.long()]
prob = 1.0 / in_deg[dst.long()] + 1.0 / out_deg[src.long()]
prob /= prob.sum()
if self.cache:
self.prob = prob
sampled_edges = torch.unique(choice(len(prob), size=self.budget, prob=prob))
sampled_edges = torch.unique(
choice(len(prob), size=self.budget, prob=prob)
)
sampled_nodes = torch.cat([src[sampled_edges], dst[sampled_edges]])
return sampled_nodes.unique().type(g.idtype)
......@@ -139,7 +155,9 @@ class SAINTSampler(Sampler):
The sampled subgraph.
"""
node_ids = self.sampler(g)
sg = g.subgraph(node_ids, relabel_nodes=True, output_device=self.output_device)
sg = g.subgraph(
node_ids, relabel_nodes=True, output_device=self.output_device
)
set_node_lazy_features(sg, self.prefetch_ndata)
set_edge_lazy_features(sg, self.prefetch_edata)
return sg
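A parallel sketch for SAINTSampler; num_iters controls how many subgraphs one epoch yields, and g is again assumed to be an existing DGLGraph:

import torch
import dgl

num_iters = 1000
sampler = dgl.dataloading.SAINTSampler(mode="node", budget=6000)
dataloader = dgl.dataloading.DataLoader(
    g, torch.arange(num_iters), sampler, num_workers=0
)
for subg in dataloader:
    pass  # train_on(subg)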
"""Negative samplers"""
from collections.abc import Mapping
from .. import backend as F
class _BaseNegativeSampler(object):
def _generate(self, g, eids, canonical_etype):
raise NotImplementedError
......@@ -25,12 +27,14 @@ class _BaseNegativeSampler(object):
eids = {g.to_canonical_etype(k): v for k, v in eids.items()}
neg_pair = {k: self._generate(g, v, k) for k, v in eids.items()}
else:
assert len(g.canonical_etypes) == 1, \
'please specify a dict of etypes and ids for graphs with multiple edge types'
assert (
len(g.canonical_etypes) == 1
), "please specify a dict of etypes and ids for graphs with multiple edge types"
neg_pair = self._generate(g, eids, g.canonical_etypes[0])
return neg_pair
class PerSourceUniform(_BaseNegativeSampler):
"""Negative sampler that randomly chooses negative destination nodes
for each source node according to a uniform distribution.
......@@ -52,6 +56,7 @@ class PerSourceUniform(_BaseNegativeSampler):
>>> neg_sampler(g, torch.tensor([0, 1]))
(tensor([0, 0, 1, 1]), tensor([1, 0, 2, 3]))
"""
def __init__(self, k):
self.k = k
......@@ -66,9 +71,11 @@ class PerSourceUniform(_BaseNegativeSampler):
dst = F.randint(shape, dtype, ctx, 0, g.num_nodes(vtype))
return src, dst
# Alias
Uniform = PerSourceUniform
class GlobalUniform(_BaseNegativeSampler):
"""Negative sampler that randomly chooses negative source-destination pairs according
to a uniform distribution.
......@@ -104,6 +111,7 @@ class GlobalUniform(_BaseNegativeSampler):
>>> neg_sampler(g, torch.LongTensor([0, 1]))
(tensor([0, 1, 3, 2]), tensor([2, 0, 2, 1]))
"""
def __init__(self, k, exclude_self_loops=True, replace=False):
self.k = k
self.exclude_self_loops = exclude_self_loops
......@@ -111,4 +119,8 @@ class GlobalUniform(_BaseNegativeSampler):
def _generate(self, g, eids, canonical_etype):
return g.global_uniform_negative_sampling(
len(eids) * self.k, self.exclude_self_loops, self.replace, canonical_etype)
len(eids) * self.k,
self.exclude_self_loops,
self.replace,
canonical_etype,
)
"""
This package contains DistGNN and Libra based graph partitioning tools.
"""
from . import partition
from . import tools
from . import partition, tools
......@@ -18,17 +18,21 @@ from Xie et al.
# Nesreen K. Ahmed <nesreen.k.ahmed@intel.com>
# \cite Distributed Power-law Graph Computing: Theoretical and Empirical Analysis
import json
import os
import time
import json
import torch as th
from dgl import DGLGraph
from dgl.sparse import libra_vertex_cut
from dgl.sparse import libra2dgl_build_dict
from dgl.sparse import libra2dgl_set_lr
from dgl.sparse import libra2dgl_build_adjlist
from dgl.data.utils import save_graphs, save_tensors
from dgl.base import DGLError
from dgl.data.utils import save_graphs, save_tensors
from dgl.sparse import (
libra2dgl_build_adjlist,
libra2dgl_build_dict,
libra2dgl_set_lr,
libra_vertex_cut,
)
def libra_partition(num_community, G, resultdir):
......@@ -55,8 +59,8 @@ def libra_partition(num_community, G, resultdir):
3. The folder also contains a json file which contains partitions' information.
"""
num_nodes = G.number_of_nodes() # number of nodes
num_edges = G.number_of_edges() # number of edges
num_nodes = G.number_of_nodes() # number of nodes
num_edges = G.number_of_edges() # number of edges
print("Number of nodes in the graph: ", num_nodes)
print("Number of edges in the graph: ", num_edges)
......@@ -79,12 +83,23 @@ def libra_partition(num_community, G, resultdir):
## call to C/C++ code
out = th.zeros(u_t.shape[0], dtype=th.int32)
libra_vertex_cut(num_community, node_degree, edgenum_unassigned, community_weights,
u_t, v_t, weight_, out, num_nodes, num_edges, resultdir)
libra_vertex_cut(
num_community,
node_degree,
edgenum_unassigned,
community_weights,
u_t,
v_t,
weight_,
out,
num_nodes,
num_edges,
resultdir,
)
print("Max partition size: ", int(community_weights.max()))
print(" ** Converting libra partitions to dgl graphs **")
fsize = int(community_weights.max()) + 1024 ## max edges in partition
fsize = int(community_weights.max()) + 1024 ## max edges in partition
# print("fsize: ", fsize, flush=True)
node_map = th.zeros(num_community, dtype=th.int64)
......@@ -110,8 +125,20 @@ def libra_partition(num_community, G, resultdir):
## building node, partition dictionary
## Assign local node ids and mapping to global node ids
ret = libra2dgl_build_dict(a_t, b_t, indices, ldt_key, gdt_key, gdt_value,
node_map, offset, num_community, i, fsize, resultdir)
ret = libra2dgl_build_dict(
a_t,
b_t,
indices,
ldt_key,
gdt_key,
gdt_value,
node_map,
offset,
num_community,
i,
fsize,
resultdir,
)
num_nodes_partition = int(ret[0])
num_edges_partition = int(ret[1])
......@@ -123,23 +150,25 @@ def libra_partition(num_community, G, resultdir):
## fixing lr - 1-level tree for the split-nodes
libra2dgl_set_lr(gdt_key, gdt_value, lrtensor, num_community, num_nodes)
########################################################
#graph_name = dataset
# graph_name = dataset
graph_name = resultdir.split("_")[-1].split("/")[0]
part_method = 'Libra'
num_parts = num_community ## number of partitions/communities
part_method = "Libra"
num_parts = num_community ## number of partitions/communities
num_hops = 0
node_map_val = node_map.tolist()
edge_map_val = 0
out_path = resultdir
part_metadata = {'graph_name': graph_name,
'num_nodes': G.number_of_nodes(),
'num_edges': G.number_of_edges(),
'part_method': part_method,
'num_parts': num_parts,
'halo_hops': num_hops,
'node_map': node_map_val,
'edge_map': edge_map_val}
part_metadata = {
"graph_name": graph_name,
"num_nodes": G.number_of_nodes(),
"num_edges": G.number_of_edges(),
"part_method": part_method,
"num_parts": num_parts,
"halo_hops": num_hops,
"node_map": node_map_val,
"edge_map": edge_map_val,
}
############################################################
for i in range(num_community):
......@@ -151,18 +180,18 @@ def libra_partition(num_community, G, resultdir):
ldt = ldt_ar[0]
try:
feat = G.ndata['feat']
feat = G.ndata["feat"]
except KeyError:
feat = G.ndata['features']
feat = G.ndata["features"]
try:
labels = G.ndata['label']
labels = G.ndata["label"]
except KeyError:
labels = G.ndata['labels']
labels = G.ndata["labels"]
trainm = G.ndata['train_mask'].int()
testm = G.ndata['test_mask'].int()
valm = G.ndata['val_mask'].int()
trainm = G.ndata["train_mask"].int()
testm = G.ndata["test_mask"].int()
valm = G.ndata["val_mask"].int()
feat_size = feat.shape[1]
gfeat = th.zeros([num_nodes_partition, feat_size], dtype=feat.dtype)
......@@ -174,21 +203,41 @@ def libra_partition(num_community, G, resultdir):
## build remote node database per local node
## gather feats, train, test, val, and labels for each partition
libra2dgl_build_adjlist(feat, gfeat, adj, inner_node, ldt, gdt_key,
gdt_value, node_map, lr_t, lrtensor, num_nodes_partition,
num_community, i, feat_size, labels, trainm, testm, valm,
glabels, gtrainm, gtestm, gvalm, feat.shape[0])
g.ndata['adj'] = adj ## database of remote clones
g.ndata['inner_node'] = inner_node ## split node '0' else '1'
g.ndata['feat'] = gfeat ## gathered features
g.ndata['lf'] = lr_t ## 1-level tree among split nodes
g.ndata['label'] = glabels
g.ndata['train_mask'] = gtrainm
g.ndata['test_mask'] = gtestm
g.ndata['val_mask'] = gvalm
libra2dgl_build_adjlist(
feat,
gfeat,
adj,
inner_node,
ldt,
gdt_key,
gdt_value,
node_map,
lr_t,
lrtensor,
num_nodes_partition,
num_community,
i,
feat_size,
labels,
trainm,
testm,
valm,
glabels,
gtrainm,
gtestm,
gvalm,
feat.shape[0],
)
g.ndata["adj"] = adj ## database of remote clones
g.ndata["inner_node"] = inner_node ## split node '0' else '1'
g.ndata["feat"] = gfeat ## gathered features
g.ndata["lf"] = lr_t ## 1-level tree among split nodes
g.ndata["label"] = glabels
g.ndata["train_mask"] = gtrainm
g.ndata["test_mask"] = gtestm
g.ndata["val_mask"] = gvalm
# Validation code, run only on small graphs
# for l in range(num_nodes_partition):
......@@ -207,9 +256,11 @@ def libra_partition(num_community, G, resultdir):
node_feat_file = os.path.join(part_dir, "node_feat.dgl")
edge_feat_file = os.path.join(part_dir, "edge_feat.dgl")
part_graph_file = os.path.join(part_dir, "graph.dgl")
part_metadata['part-{}'.format(part_id)] = {'node_feats': node_feat_file,
'edge_feats': edge_feat_file,
'part_graph': part_graph_file}
part_metadata["part-{}".format(part_id)] = {
"node_feats": node_feat_file,
"edge_feats": edge_feat_file,
"part_graph": part_graph_file,
}
os.makedirs(part_dir, mode=0o775, exist_ok=True)
save_tensors(node_feat_file, part.ndata)
save_graphs(part_graph_file, [part])
......@@ -219,7 +270,7 @@ def libra_partition(num_community, G, resultdir):
del ldt
del ldt_ar[0]
with open('{}/{}.json'.format(out_path, graph_name), 'w') as outfile:
with open("{}/{}.json".format(out_path, graph_name), "w") as outfile:
json.dump(part_metadata, outfile, sort_keys=True, indent=4)
print("Conversion libra2dgl completed !!!")
......@@ -270,7 +321,9 @@ def partition_graph(num_community, G, resultdir):
raise DGLError("Error: Could not create directory: ", resultdir)
tic = time.time()
print("####################################################################")
print(
"####################################################################"
)
print("Executing parititons: ", num_community)
ltic = time.time()
try:
......@@ -283,9 +336,18 @@ def partition_graph(num_community, G, resultdir):
libra_partition(num_community, G, resultdir)
ltoc = time.time()
print("Time taken by {} partitions {:0.4f} sec".format(num_community, ltoc - ltic))
print(
"Time taken by {} partitions {:0.4f} sec".format(
num_community, ltoc - ltic
)
)
print()
toc = time.time()
print("Generated ", num_community, " partitions in {:0.4f} sec".format(toc - tic), flush=True)
print(
"Generated ",
num_community,
" partitions in {:0.4f} sec".format(toc - tic),
flush=True,
)
print("Partitioning completed successfully !!!")
......@@ -7,13 +7,16 @@ Copyright (c) 2021 Intel Corporation
import os
import random
import requests
from scipy.io import mmread
import torch as th
from scipy.io import mmread
import dgl
from dgl.base import DGLError
from dgl.data.utils import load_graphs, save_graphs, save_tensors
def rep_per_node(prefix, num_community):
"""
Used on Libra partitioned data.
......@@ -24,17 +27,17 @@ def rep_per_node(prefix, num_community):
prefix: Partition folder location (contains replicationlist.csv)
num_community: number of partitions or communities
"""
ifile = os.path.join(prefix, 'replicationlist.csv')
ifile = os.path.join(prefix, "replicationlist.csv")
fhandle = open(ifile, "r")
r_dt = {}
fline = fhandle.readline() ## reading first line, contains the comment.
fline = fhandle.readline() ## reading first line, contains the comment.
print(fline)
for line in fhandle:
if line[0] == '#':
if line[0] == "#":
raise DGLError("[Bug] Read Hash char in rep_per_node func.")
node = line.strip('\n')
node = line.strip("\n")
if r_dt.get(node, -100) == -100:
r_dt[node] = 1
else:
......@@ -44,7 +47,9 @@ def rep_per_node(prefix, num_community):
## sanity checks
for v in r_dt.values():
if v >= num_community:
raise DGLError("[Bug] Unexpected event in rep_per_node() in tools.py.")
raise DGLError(
"[Bug] Unexpected event in rep_per_node() in tools.py."
)
return r_dt
......@@ -61,7 +66,9 @@ def download_proteins():
try:
req = requests.get(url)
except:
raise DGLError("Error: Failed to download Proteins dataset!! Aborting..")
raise DGLError(
"Error: Failed to download Proteins dataset!! Aborting.."
)
with open("proteins.mtx", "wb") as handle:
handle.write(req.content)
......@@ -73,7 +80,7 @@ def proteins_mtx2dgl():
"""
print("Converting mtx2dgl..")
print("This might a take while..")
a_mtx = mmread('proteins.mtx')
a_mtx = mmread("proteins.mtx")
coo = a_mtx.tocoo()
u = th.tensor(coo.row, dtype=th.int64)
v = th.tensor(coo.col, dtype=th.int64)
......@@ -82,7 +89,7 @@ def proteins_mtx2dgl():
g.add_edges(u, v)
n = g.number_of_nodes()
feat_size = 128 ## arbitrary number
feat_size = 128 ## arbitrary number
feats = th.empty([n, feat_size], dtype=th.float32)
## arbitrary numbers
......@@ -108,11 +115,11 @@ def proteins_mtx2dgl():
for i in range(n):
label[i] = random.choice(range(nlabels))
g.ndata['feat'] = feats
g.ndata['train_mask'] = train_mask
g.ndata['test_mask'] = test_mask
g.ndata['val_mask'] = val_mask
g.ndata['label'] = label
g.ndata["feat"] = feats
g.ndata["train_mask"] = train_mask
g.ndata["test_mask"] = test_mask
g.ndata["val_mask"] = val_mask
g.ndata["label"] = label
return g
......