"src/array/vscode:/vscode.git/clone" did not exist on "314cedc1b1c3c5ffd2ee0a980010b62faf120f1f"
Unverified commit a303f078, authored by Jinjing Zhou and committed by GitHub
Browse files

fix #2952 (#3010)

parent 17141dd3
...@@ -93,8 +93,17 @@ class LegacyTUDataset(DGLBuiltinDataset): ...@@ -93,8 +93,17 @@ class LegacyTUDataset(DGLBuiltinDataset):
DS_indicator = self._idx_from_zero( DS_indicator = self._idx_from_zero(
np.genfromtxt(self._file_path("graph_indicator"), dtype=int)) np.genfromtxt(self._file_path("graph_indicator"), dtype=int))
if os.path.exists(self._file_path("graph_labels")):
DS_graph_labels = self._idx_from_zero( DS_graph_labels = self._idx_from_zero(
np.genfromtxt(self._file_path("graph_labels"), dtype=int)) np.genfromtxt(self._file_path("graph_labels"), dtype=int))
self.num_labels = max(DS_graph_labels) + 1
self.graph_labels = DS_graph_labels
elif os.path.exists(self._file_path("graph_attributes")):
DS_graph_labels = np.genfromtxt(self._file_path("graph_attributes"), dtype=float)
self.num_labels = None
self.graph_labels = DS_graph_labels
else:
raise Exception("Unknown graph label or graph attributes")
g = dgl_graph(([], [])) g = dgl_graph(([], []))
g.add_nodes(int(DS_edge_list.max()) + 1) g.add_nodes(int(DS_edge_list.max()) + 1)
...@@ -109,8 +118,6 @@ class LegacyTUDataset(DGLBuiltinDataset): ...@@ -109,8 +118,6 @@ class LegacyTUDataset(DGLBuiltinDataset):
self.max_num_node = len(node_idx[0]) self.max_num_node = len(node_idx[0])
self.graph_lists = [g.subgraph(node_idx) for node_idx in node_idx_list] self.graph_lists = [g.subgraph(node_idx) for node_idx in node_idx_list]
self.num_labels = max(DS_graph_labels) + 1
self.graph_labels = DS_graph_labels
try: try:
DS_node_labels = self._idx_from_zero( DS_node_labels = self._idx_from_zero(
...@@ -296,8 +303,18 @@ class TUDataset(DGLBuiltinDataset): ...@@ -296,8 +303,18 @@ class TUDataset(DGLBuiltinDataset):
loadtxt(self._file_path("A"), delimiter=",").astype(int)) loadtxt(self._file_path("A"), delimiter=",").astype(int))
DS_indicator = self._idx_from_zero( DS_indicator = self._idx_from_zero(
loadtxt(self._file_path("graph_indicator"), delimiter=",").astype(int)) loadtxt(self._file_path("graph_indicator"), delimiter=",").astype(int))
if os.path.exists(self._file_path("graph_labels")):
DS_graph_labels = self._idx_reset( DS_graph_labels = self._idx_reset(
loadtxt(self._file_path("graph_labels"), delimiter=",").astype(int)) loadtxt(self._file_path("graph_labels"), delimiter=",").astype(int))
self.num_labels = max(DS_graph_labels) + 1
self.graph_labels = F.tensor(DS_graph_labels)
elif os.path.exists(self._file_path("graph_attributes")):
DS_graph_labels = loadtxt(self._file_path("graph_attributes"), delimiter=",").astype(float)
self.num_labels = None
self.graph_labels = F.tensor(DS_graph_labels)
else:
raise Exception("Unknown graph label or graph attributes")
g = dgl_graph(([], [])) g = dgl_graph(([], []))
g.add_nodes(int(DS_edge_list.max()) + 1) g.add_nodes(int(DS_edge_list.max()) + 1)
...@@ -311,8 +328,6 @@ class TUDataset(DGLBuiltinDataset): ...@@ -311,8 +328,6 @@ class TUDataset(DGLBuiltinDataset):
if len(node_idx[0]) > self.max_num_node: if len(node_idx[0]) > self.max_num_node:
self.max_num_node = len(node_idx[0]) self.max_num_node = len(node_idx[0])
self.num_labels = max(DS_graph_labels) + 1
self.graph_labels = F.tensor(DS_graph_labels)
self.attr_dict = { self.attr_dict = {
'node_labels': ('ndata', 'node_labels'), 'node_labels': ('ndata', 'node_labels'),
......
...@@ -24,6 +24,12 @@ def test_gin(): ...@@ -24,6 +24,12 @@ def test_gin():
assert len(ds) == n_graphs, (len(ds), name) assert len(ds) == n_graphs, (len(ds), name)
@unittest.skipIf(
    F._default_context_str == 'gpu',
    reason="Datasets don't need to be tested on GPU.",
)
def test_tudataset_regression():
    """Check that the 'ZINC_test' TU dataset loads and contains 5000 graphs.

    The dataset is rebuilt from its raw files (``force_reload=True``) so a
    previously cached copy cannot mask a loading failure.

    NOTE(review): judging by the test name, 'ZINC_test' presumably carries
    real-valued graph targets rather than integer class labels -- confirm
    against the dataset files.
    """
    expected_num_graphs = 5000
    dataset = data.TUDataset('ZINC_test', force_reload=True)
    assert len(dataset) == expected_num_graphs
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.") @unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_data_hash(): def test_data_hash():
class HashTestDataset(data.DGLDataset): class HashTestDataset(data.DGLDataset):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment