Unverified Commit 5f198763 authored by Xiangkun Hu's avatar Xiangkun Hu Committed by GitHub
Browse files

[Bugfix] GINDataset check whether labels are all valid (#2319)

* PPIDataset

* Revert "PPIDataset"

This reverts commit 264bd0c960cfa698a7bb946dad132bf52c2d0c8a.

* check whether the original labels are valid

* add unit test for gin dataset

* bug of asscalar in mxnet<=1.6

* mxnet<=1.6 asscalar requires ndarray to be shaped (1,)

* skip gpu while testing datasets

* Update test_data.py

* test of gin dataset takes too much time
parent b8cc26e3
......@@ -56,6 +56,10 @@ def tensor(data, dtype=None):
return nd.array(data, dtype=dtype)
def as_scalar(data):
if data.size != 1:
raise ValueError("The current array is not a scalar")
if data.shape != (1,):
data = data.expand_dims(axis=0)
return data.asscalar()
def get_preferred_sparse_format():
......
......@@ -248,7 +248,8 @@ class GINDataset(DGLBuiltinDataset):
nlabel_set = nlabel_set.union(
set([F.as_scalar(nl) for nl in g.ndata['label']]))
nlabel_set = list(nlabel_set)
if len(nlabel_set) == np.max(nlabel_set) + 1 and np.min(nlabel_set) == 0:
is_label_valid = all([label in self.nlabel_dict for label in nlabel_set])
if is_label_valid and len(nlabel_set) == np.max(nlabel_set) + 1 and np.min(nlabel_set) == 0:
# Note this is different from the author's implementation. In weihua916's implementation,
# the labels are relabeled anyway. But here we didn't relabel it if the labels are contiguous
# to make it consistent with the original dataset
......
import dgl.data as data
import unittest
import backend as F
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_minigc():
ds = data.MiniGCDataset(16, 10, 20)
g, l = list(zip(*ds))
print(g, l)
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_gin():
ds_n_graphs = {
'MUTAG': 188,
'IMDBBINARY': 1000,
'IMDBMULTI': 1500,
'PROTEINS': 1113,
'PTC': 344,
}
for name, n_graphs in ds_n_graphs.items():
ds = data.GINDataset(name, self_loop=False, degree_as_nlabel=False)
assert len(ds) == n_graphs, (len(ds), name)
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_data_hash():
class HashTestDataset(data.DGLDataset):
def __init__(self, hash_key=()):
......@@ -20,4 +40,5 @@ def test_data_hash():
if __name__ == '__main__':
test_minigc()
test_gin()
test_data_hash()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment