Unverified Commit c45f6eb5 authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

[Bugfix] Fix duplicate edges in Coauthor Dataset #2553 (#2569)

* fix

* address comment
parent 6c23fba8
......@@ -7,6 +7,7 @@ from .dgl_dataset import DGLBuiltinDataset
from .utils import save_graphs, load_graphs, _get_dgl_url, deprecate_property, deprecate_class
from ..convert import graph as dgl_graph
from .. import backend as F
from .. import transform
__all__ = ["AmazonCoBuyComputerDataset", "AmazonCoBuyPhotoDataset", "CoauthorPhysicsDataset", "CoauthorCSDataset",
"CoraFullDataset", "AmazonCoBuy", "Coauthor", "CoraFull"]
......@@ -42,17 +43,17 @@ class GNNBenchmarkDataset(DGLBuiltinDataset):
self._print_info()
def has_cache(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin')
graph_path = os.path.join(self.save_path, 'dgl_graph_v1.bin')
if os.path.exists(graph_path):
return True
return False
def save(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin')
graph_path = os.path.join(self.save_path, 'dgl_graph_v1.bin')
save_graphs(graph_path, self._graph)
def load(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin')
graph_path = os.path.join(self.save_path, 'dgl_graph_v1.bin')
graphs, _ = load_graphs(graph_path)
self._graph = graphs[0]
self._data = [graphs[0]]
......@@ -91,9 +92,8 @@ class GNNBenchmarkDataset(DGLBuiltinDataset):
labels = loader['labels']
else:
labels = None
row = np.hstack([adj_matrix.row, adj_matrix.col])
col = np.hstack([adj_matrix.col, adj_matrix.row])
g = dgl_graph((row, col))
g = dgl_graph((adj_matrix.row, adj_matrix.col))
g = transform.to_bidirected(g)
g.ndata['feat'] = F.tensor(attr_matrix, F.data_type_dict['float32'])
g.ndata['label'] = F.tensor(labels, F.data_type_dict['int64'])
return g
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment