Unverified Commit 15d05be3 authored by Zhen Liu's avatar Zhen Liu Committed by GitHub
Browse files

Fix num_labels to num_classes in dataset files (#6666)

parent 5e64481b
...@@ -59,7 +59,7 @@ class PPIDataset(DGLBuiltinDataset): ...@@ -59,7 +59,7 @@ class PPIDataset(DGLBuiltinDataset):
Examples Examples
-------- --------
>>> dataset = PPIDataset(mode='valid') >>> dataset = PPIDataset(mode='valid')
>>> num_labels = dataset.num_labels >>> num_classes = dataset.num_classes
>>> for g in dataset: >>> for g in dataset:
.... feat = g.ndata['feat'] .... feat = g.ndata['feat']
.... label = g.ndata['label'] .... label = g.ndata['label']
...@@ -173,6 +173,10 @@ class PPIDataset(DGLBuiltinDataset): ...@@ -173,6 +173,10 @@ class PPIDataset(DGLBuiltinDataset):
def num_labels(self): def num_labels(self):
return 121 return 121
@property
def num_classes(self):
return 121
def __len__(self): def __len__(self):
"""Return number of samples in this dataset.""" """Return number of samples in this dataset."""
return len(self.graphs) return len(self.graphs)
......
...@@ -141,6 +141,11 @@ class QM7bDataset(DGLDataset): ...@@ -141,6 +141,11 @@ class QM7bDataset(DGLDataset):
"""Number of prediction tasks.""" """Number of prediction tasks."""
return 14 return 14
@property
def num_classes(self):
"""Number of prediction tasks."""
return 14
def __getitem__(self, idx): def __getitem__(self, idx):
r"""Get graph and label by index r"""Get graph and label by index
......
...@@ -157,6 +157,16 @@ class QM9Dataset(DGLDataset): ...@@ -157,6 +157,16 @@ class QM9Dataset(DGLDataset):
""" """
return self.label.shape[1] return self.label.shape[1]
@property
def num_classes(self):
r"""
Returns
--------
int
Number of prediction tasks.
"""
return self.label.shape[1]
@property @property
def num_tasks(self): def num_tasks(self):
r""" r"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment