# test_data.py
1
import dgl.data as data
2
3
import unittest
import backend as F
4

5
6

@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_minigc():
    """Build a small MiniGCDataset and check it yields one (graph, label) pair per requested graph.

    The dataset is created with 16 graphs and node-count bounds 10 and 20.
    """
    num_graphs = 16
    ds = data.MiniGCDataset(num_graphs, 10, 20)
    # Iterating the dataset yields (graph, label) pairs; transpose them into two sequences.
    graphs, labels = list(zip(*ds))
    # The previous version only printed the samples and could never fail;
    # assert on the counts so a generation regression is actually caught.
    assert len(graphs) == num_graphs
    assert len(labels) == num_graphs

12
13
14
15
16
17
18
19
20
21
22
23
24
25
26

@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_gin():
    """Check the number of graphs in each GIN benchmark dataset.

    Every dataset is constructed with ``self_loop=False`` and
    ``degree_as_nlabel=False``.
    """
    expected_sizes = [
        ('MUTAG', 188),
        ('IMDBBINARY', 1000),
        ('IMDBMULTI', 1500),
        ('PROTEINS', 1113),
        ('PTC', 344),
    ]
    for dataset_name, expected in expected_sizes:
        dataset = data.GINDataset(dataset_name, self_loop=False, degree_as_nlabel=False)
        actual = len(dataset)
        # On failure, report the observed size together with the dataset it came from.
        assert actual == expected, (actual, dataset_name)


27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_fraud():
    """Check node counts of the fraud-detection graphs.

    The Amazon graph is loaded twice: once through the generic
    ``FraudDataset('amazon')`` entry point and once through
    ``FraudAmazonDataset``. Both must report the same node count.
    """
    # Constructors are wrapped in lambdas so each dataset is built only when
    # its turn comes, in the same order as a straight-line version would.
    cases = (
        (lambda: data.FraudDataset('amazon'), 11944),
        (lambda: data.FraudAmazonDataset(), 11944),
        (lambda: data.FraudYelpDataset(), 45954),
    )
    for build, expected_nodes in cases:
        graph = build()[0]
        assert graph.num_nodes() == expected_nodes


@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_fakenews():
    """Check FakeNewsDataset lengths for two (name, feature) combinations."""
    cases = (
        ('politifact', 'bert', 314),
        ('gossipcop', 'profile', 5464),
    )
    for name, feature_name, expected_len in cases:
        dataset = data.FakeNewsDataset(name, feature_name)
        assert len(dataset) == expected_len

# Jinjing Zhou committed
47
48
49
50
51
52

@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_tudataset_regression():
    """Load ``ZINC_test`` through TUDataset with ``force_reload=True`` and check it has 5000 entries."""
    expected_len = 5000
    zinc_test = data.TUDataset('ZINC_test', force_reload=True)
    assert len(zinc_test) == expected_len

53

54
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_data_hash():
    """Check that ``DGLDataset.hash`` is equal for equal ``hash_key``s and differs otherwise.

    Three datasets are built whose hash keys are identical except for the
    last element of a nested tuple in the third one.
    """
    class HashTestDataset(data.DGLDataset):
        # Minimal DGLDataset subclass: only the hash_key varies between instances.
        def __init__(self, hash_key=()):
            super().__init__('hashtest', hash_key=hash_key)

        def _load(self):
            # Overridden as a no-op so constructing the dataset does no real
            # loading work. NOTE(review): presumably DGLDataset.__init__ calls
            # _load — confirm against the DGLDataset implementation.
            pass

    a = HashTestDataset((True, 0, '1', (1, 2, 3)))
    b = HashTestDataset((True, 0, '1', (1, 2, 3)))
    c = HashTestDataset((True, 0, '1', (1, 2, 4)))
    # Equal keys must hash equally; a single differing nested element must not.
    assert a.hash == b.hash
    assert a.hash != c.hash

68

69
if __name__ == '__main__':
    # Allow running the dataset tests directly as a script, without pytest.
    test_minigc()
    test_gin()
    test_data_hash()
    test_tudataset_regression()
    test_fraud()
    test_fakenews()