"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "81125d8499b82da80e997c45c72ea54ebd8b8abb"
Unverified commit 51ba6621, authored by Jinjing Zhou and committed by GitHub
Browse files

Fix GINDT and #2087 (#2103)

* fix gindt

* ff

* fix

* minor fix

* fix
parent 628d9fc5
...@@ -15,6 +15,7 @@ from .utils import loadtxt, save_graphs, load_graphs, save_info, load_info, down ...@@ -15,6 +15,7 @@ from .utils import loadtxt, save_graphs, load_graphs, save_info, load_info, down
from ..utils import retry_method_with_fix from ..utils import retry_method_with_fix
from ..convert import graph as dgl_graph from ..convert import graph as dgl_graph
class GINDataset(DGLBuiltinDataset): class GINDataset(DGLBuiltinDataset):
"""Datasets for Graph Isomorphism Network (GIN) """Datasets for Graph Isomorphism Network (GIN)
Adapted from `<https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip>`_. Adapted from `<https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip>`_.
...@@ -232,32 +233,36 @@ class GINDataset(DGLBuiltinDataset): ...@@ -232,32 +233,36 @@ class GINDataset(DGLBuiltinDataset):
if self.degree_as_nlabel: if self.degree_as_nlabel:
if self.verbose: if self.verbose:
print('generate node features by node degree...') print('generate node features by node degree...')
nlabel_set = set([])
for g in self.graphs: for g in self.graphs:
# actually this label shouldn't be updated # actually this label shouldn't be updated
# in case users want to keep it # in case users want to keep it
# but usually no features means no labels, fine. # but usually no features means no labels, fine.
g.ndata['label'] = g.in_degrees() g.ndata['label'] = g.in_degrees()
# extracting unique node labels # extracting unique node labels
nlabel_set = nlabel_set.union(set([F.as_scalar(nl) for nl in g.ndata['label']]))
nlabel_set = list(nlabel_set) # in case the labels/degrees are not continuous number
# in case the labels/degrees are not continuous number nlabel_set = set([])
self.ndegree_dict = { for g in self.graphs:
nlabel_set = nlabel_set.union(
set([F.as_scalar(nl) for nl in g.ndata['label']]))
nlabel_set = list(nlabel_set)
if len(nlabel_set) == np.max(nlabel_set) + 1 and np.min(nlabel_set) == 0:
# Note this is different from the author's implementation. In weihua916's implementation,
# the labels are relabeled anyway. But here we didn't relabel it if the labels are contiguous
# to make it consistent with the original dataset
label2idx = self.nlabel_dict
else:
label2idx = {
nlabel_set[i]: i nlabel_set[i]: i
for i in range(len(nlabel_set)) for i in range(len(nlabel_set))
} }
label2idx = self.ndegree_dict
# generate node attr by node label # generate node attr by node label
else:
if self.verbose:
print('generate node features by node label...')
label2idx = self.nlabel_dict
for g in self.graphs: for g in self.graphs:
g.ndata['attr'] = F.tensor(np.zeros(( attr = np.zeros((
g.number_of_nodes(), len(label2idx)))) g.number_of_nodes(), len(label2idx)))
g.ndata['attr'][range(g.number_of_nodes()), [label2idx[F.as_scalar(F.reshape(nl, (1,)))] for nl in g.ndata['label']]] = 1 attr[range(g.number_of_nodes()), [label2idx[nl]
for nl in F.asnumpy(g.ndata['label']).tolist()]] = 1
g.ndata['attr'] = F.tensor(attr)
# after load, get the #classes and #dim # after load, get the #classes and #dim
self.gclasses = len(self.glabel_dict) self.gclasses = len(self.glabel_dict)
...@@ -288,8 +293,10 @@ class GINDataset(DGLBuiltinDataset): ...@@ -288,8 +293,10 @@ class GINDataset(DGLBuiltinDataset):
self.nlabel_dict, self.ndegree_dict)) self.nlabel_dict, self.ndegree_dict))
def save(self): def save(self):
graph_path = os.path.join(self.save_path, 'gin_{}_{}.bin'.format(self.name, self.hash)) graph_path = os.path.join(
info_path = os.path.join(self.save_path, 'gin_{}_{}.pkl'.format(self.name, self.hash)) self.save_path, 'gin_{}_{}.bin'.format(self.name, self.hash))
info_path = os.path.join(
self.save_path, 'gin_{}_{}.pkl'.format(self.name, self.hash))
label_dict = {'labels': self.labels} label_dict = {'labels': self.labels}
info_dict = {'N': self.N, info_dict = {'N': self.N,
'n': self.n, 'n': self.n,
...@@ -308,8 +315,10 @@ class GINDataset(DGLBuiltinDataset): ...@@ -308,8 +315,10 @@ class GINDataset(DGLBuiltinDataset):
save_info(str(info_path), info_dict) save_info(str(info_path), info_dict)
def load(self): def load(self):
graph_path = os.path.join(self.save_path, 'gin_{}_{}.bin'.format(self.name, self.hash)) graph_path = os.path.join(
info_path = os.path.join(self.save_path, 'gin_{}_{}.pkl'.format(self.name, self.hash)) self.save_path, 'gin_{}_{}.bin'.format(self.name, self.hash))
info_path = os.path.join(
self.save_path, 'gin_{}_{}.pkl'.format(self.name, self.hash))
graphs, label_dict = load_graphs(str(graph_path)) graphs, label_dict = load_graphs(str(graph_path))
info_dict = load_info(str(info_path)) info_dict = load_info(str(info_path))
...@@ -331,8 +340,10 @@ class GINDataset(DGLBuiltinDataset): ...@@ -331,8 +340,10 @@ class GINDataset(DGLBuiltinDataset):
self.degree_as_nlabel = info_dict['degree_as_nlabel'] self.degree_as_nlabel = info_dict['degree_as_nlabel']
def has_cache(self): def has_cache(self):
graph_path = os.path.join(self.save_path, 'gin_{}_{}.bin'.format(self.name, self.hash)) graph_path = os.path.join(
info_path = os.path.join(self.save_path, 'gin_{}_{}.pkl'.format(self.name, self.hash)) self.save_path, 'gin_{}_{}.bin'.format(self.name, self.hash))
info_path = os.path.join(
self.save_path, 'gin_{}_{}.pkl'.format(self.name, self.hash))
if os.path.exists(graph_path) and os.path.exists(info_path): if os.path.exists(graph_path) and os.path.exists(info_path):
return True return True
return False return False
...@@ -108,11 +108,11 @@ def save_graphs(filename, g_list, labels=None): ...@@ -108,11 +108,11 @@ def save_graphs(filename, g_list, labels=None):
load_graphs load_graphs
""" """
# if it is local file, do some sanity check # if it is local file, do some sanity check
if filename.startswith('s3://') is False: if not filename.startswith('s3://'):
if os.path.isdir(filename): if os.path.isdir(filename):
raise DGLError("Filename {} is an existing directory.".format(filename)) raise DGLError("Filename {} is an existing directory.".format(filename))
f_path, _ = os.path.split(filename) f_path = os.path.dirname(filename)
if not os.path.exists(f_path): if f_path and not os.path.exists(f_path):
os.makedirs(f_path) os.makedirs(f_path)
g_sample = g_list[0] if isinstance(g_list, list) else g_list g_sample = g_list[0] if isinstance(g_list, list) else g_list
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment