Unverified Commit c45f6eb5 authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

[Bugfix] Fix duplicate edges in Coauthor Dataset #2553 (#2569)

* fix

* address comment
parent 6c23fba8
...@@ -7,6 +7,7 @@ from .dgl_dataset import DGLBuiltinDataset ...@@ -7,6 +7,7 @@ from .dgl_dataset import DGLBuiltinDataset
from .utils import save_graphs, load_graphs, _get_dgl_url, deprecate_property, deprecate_class from .utils import save_graphs, load_graphs, _get_dgl_url, deprecate_property, deprecate_class
from ..convert import graph as dgl_graph from ..convert import graph as dgl_graph
from .. import backend as F from .. import backend as F
from .. import transform
__all__ = ["AmazonCoBuyComputerDataset", "AmazonCoBuyPhotoDataset", "CoauthorPhysicsDataset", "CoauthorCSDataset", __all__ = ["AmazonCoBuyComputerDataset", "AmazonCoBuyPhotoDataset", "CoauthorPhysicsDataset", "CoauthorCSDataset",
"CoraFullDataset", "AmazonCoBuy", "Coauthor", "CoraFull"] "CoraFullDataset", "AmazonCoBuy", "Coauthor", "CoraFull"]
...@@ -42,17 +43,17 @@ class GNNBenchmarkDataset(DGLBuiltinDataset): ...@@ -42,17 +43,17 @@ class GNNBenchmarkDataset(DGLBuiltinDataset):
self._print_info() self._print_info()
def has_cache(self): def has_cache(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin') graph_path = os.path.join(self.save_path, 'dgl_graph_v1.bin')
if os.path.exists(graph_path): if os.path.exists(graph_path):
return True return True
return False return False
def save(self): def save(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin') graph_path = os.path.join(self.save_path, 'dgl_graph_v1.bin')
save_graphs(graph_path, self._graph) save_graphs(graph_path, self._graph)
def load(self): def load(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin') graph_path = os.path.join(self.save_path, 'dgl_graph_v1.bin')
graphs, _ = load_graphs(graph_path) graphs, _ = load_graphs(graph_path)
self._graph = graphs[0] self._graph = graphs[0]
self._data = [graphs[0]] self._data = [graphs[0]]
...@@ -91,9 +92,8 @@ class GNNBenchmarkDataset(DGLBuiltinDataset): ...@@ -91,9 +92,8 @@ class GNNBenchmarkDataset(DGLBuiltinDataset):
labels = loader['labels'] labels = loader['labels']
else: else:
labels = None labels = None
row = np.hstack([adj_matrix.row, adj_matrix.col]) g = dgl_graph((adj_matrix.row, adj_matrix.col))
col = np.hstack([adj_matrix.col, adj_matrix.row]) g = transform.to_bidirected(g)
g = dgl_graph((row, col))
g.ndata['feat'] = F.tensor(attr_matrix, F.data_type_dict['float32']) g.ndata['feat'] = F.tensor(attr_matrix, F.data_type_dict['float32'])
g.ndata['label'] = F.tensor(labels, F.data_type_dict['int64']) g.ndata['label'] = F.tensor(labels, F.data_type_dict['int64'])
return g return g
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment