Unverified Commit 8a227bfa authored by Henry Kenlay's avatar Henry Kenlay Committed by GitHub
Browse files

[Bugfix] fix TUDataset labelling issue (#2165) (#2173)



* [Bugfix] fix TUDataset labelling issue (#2165)

* [Bugfix] fix TUDataset labelling issue (dmlc#2165)

* update docstring according to discussion
Co-authored-by: default avatarQuan (Andy) Gan <coin2028@hotmail.com>
Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
parent a102db23
...@@ -276,7 +276,13 @@ class TUDataset(DGLBuiltinDataset): ...@@ -276,7 +276,13 @@ class TUDataset(DGLBuiltinDataset):
Notes Notes
----- -----
Graphs may have node labels, node attributes, edge labels, and edge attributes, Graphs may have node labels, node attributes, edge labels, and edge attributes,
varing from different dataset. This class does not perform additional process. varing from different dataset.
Labels are mapped to :math:`\lbrace 0,\cdots,n-1 \rbrace` where :math:`n` is the
number of labels (some datasets have raw labels :math:`\lbrace -1, 1 \rbrace` which
will be mapped to :math:`\lbrace 0, 1 \rbrace`). In previous versions, the minimum
label was added so that :math:`\lbrace -1, 1 \rbrace` was mapped to
:math:`\lbrace 0, 2 \rbrace`.
""" """
_url = r"https://www.chrsmrrs.com/graphkerneldatasets/{}.zip" _url = r"https://www.chrsmrrs.com/graphkerneldatasets/{}.zip"
...@@ -292,7 +298,7 @@ class TUDataset(DGLBuiltinDataset): ...@@ -292,7 +298,7 @@ class TUDataset(DGLBuiltinDataset):
loadtxt(self._file_path("A"), delimiter=",").astype(int)) loadtxt(self._file_path("A"), delimiter=",").astype(int))
DS_indicator = self._idx_from_zero( DS_indicator = self._idx_from_zero(
loadtxt(self._file_path("graph_indicator"), delimiter=",").astype(int)) loadtxt(self._file_path("graph_indicator"), delimiter=",").astype(int))
DS_graph_labels = self._idx_from_zero( DS_graph_labels = self._idx_reset(
loadtxt(self._file_path("graph_labels"), delimiter=",").astype(int)) loadtxt(self._file_path("graph_labels"), delimiter=",").astype(int))
g = dgl_graph(([], [])) g = dgl_graph(([], []))
...@@ -387,6 +393,14 @@ class TUDataset(DGLBuiltinDataset): ...@@ -387,6 +393,14 @@ class TUDataset(DGLBuiltinDataset):
def _idx_from_zero(idx_tensor): def _idx_from_zero(idx_tensor):
return idx_tensor - np.min(idx_tensor) return idx_tensor - np.min(idx_tensor)
@staticmethod
def _idx_reset(idx_tensor):
"""Maps n unique labels to {0, ..., n-1} in an ordered fashion."""
labels = np.unique(idx_tensor)
relabel_map = {x: i for i, x in enumerate(labels)}
new_idx_tensor = np.vectorize(relabel_map.get)(idx_tensor)
return new_idx_tensor
def statistics(self): def statistics(self):
return self.graph_lists[0].ndata['feat'].shape[1], \ return self.graph_lists[0].ndata['feat'].shape[1], \
self.num_labels, \ self.num_labels, \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment