Unverified commit a303f078, authored by Jinjing Zhou and committed by GitHub
Browse files

fix #2952 (#3010)

parent 17141dd3
......@@ -93,8 +93,17 @@ class LegacyTUDataset(DGLBuiltinDataset):
DS_indicator = self._idx_from_zero(
np.genfromtxt(self._file_path("graph_indicator"), dtype=int))
if os.path.exists(self._file_path("graph_labels")):
DS_graph_labels = self._idx_from_zero(
np.genfromtxt(self._file_path("graph_labels"), dtype=int))
self.num_labels = max(DS_graph_labels) + 1
self.graph_labels = DS_graph_labels
elif os.path.exists(self._file_path("graph_attributes")):
DS_graph_labels = np.genfromtxt(self._file_path("graph_attributes"), dtype=float)
self.num_labels = None
self.graph_labels = DS_graph_labels
else:
raise Exception("Unknown graph label or graph attributes")
g = dgl_graph(([], []))
g.add_nodes(int(DS_edge_list.max()) + 1)
......@@ -109,8 +118,6 @@ class LegacyTUDataset(DGLBuiltinDataset):
self.max_num_node = len(node_idx[0])
self.graph_lists = [g.subgraph(node_idx) for node_idx in node_idx_list]
self.num_labels = max(DS_graph_labels) + 1
self.graph_labels = DS_graph_labels
try:
DS_node_labels = self._idx_from_zero(
......@@ -296,8 +303,18 @@ class TUDataset(DGLBuiltinDataset):
loadtxt(self._file_path("A"), delimiter=",").astype(int))
DS_indicator = self._idx_from_zero(
loadtxt(self._file_path("graph_indicator"), delimiter=",").astype(int))
if os.path.exists(self._file_path("graph_labels")):
DS_graph_labels = self._idx_reset(
loadtxt(self._file_path("graph_labels"), delimiter=",").astype(int))
self.num_labels = max(DS_graph_labels) + 1
self.graph_labels = F.tensor(DS_graph_labels)
elif os.path.exists(self._file_path("graph_attributes")):
DS_graph_labels = loadtxt(self._file_path("graph_attributes"), delimiter=",").astype(float)
self.num_labels = None
self.graph_labels = F.tensor(DS_graph_labels)
else:
raise Exception("Unknown graph label or graph attributes")
g = dgl_graph(([], []))
g.add_nodes(int(DS_edge_list.max()) + 1)
......@@ -311,8 +328,6 @@ class TUDataset(DGLBuiltinDataset):
if len(node_idx[0]) > self.max_num_node:
self.max_num_node = len(node_idx[0])
self.num_labels = max(DS_graph_labels) + 1
self.graph_labels = F.tensor(DS_graph_labels)
self.attr_dict = {
'node_labels': ('ndata', 'node_labels'),
......
......@@ -24,6 +24,12 @@ def test_gin():
assert len(ds) == n_graphs, (len(ds), name)
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_tudataset_regression():
    """Check that the 'ZINC_test' TU dataset loads and reports 5000 entries.

    The dataset is constructed with ``force_reload=True`` and the test is
    skipped when the default backend context is 'gpu'.
    """
    expected_size = 5000
    zinc_test = data.TUDataset('ZINC_test', force_reload=True)
    actual_size = len(zinc_test)
    assert actual_size == expected_size
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_data_hash():
class HashTestDataset(data.DGLDataset):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment