test_dataloader.py 1.34 KB
Newer Older
1
2
3
4
5
import unittest
import backend as F
import dgl
from dgl.dataloading import NeighborSampler, negative_sampler, \
    as_edge_prediction_sampler
nv-dlasalle's avatar
nv-dlasalle committed
6
from test_utils import parametrize_idtype
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27

def create_test_graph(idtype):
    """Build the small heterograph shared by the dataloader tests.

    The graph is the heterograph from the docstring example plus an extra
    user -- wishes -- game relation.  It has 3 users, 2 games and
    2 developers, connected by these canonical edge types::

        ('user', 'follows', 'user')
        ('user', 'plays', 'game')
        ('user', 'wishes', 'game')
        ('developer', 'develops', 'game')

    Parameters
    ----------
    idtype
        ID type for the graph, forwarded to ``dgl.heterograph``.

    Returns
    -------
    DGLGraph
        The graph, placed on the test backend's context ``F.ctx()``.
    """
    device = F.ctx()
    edges_by_relation = {
        ('user', 'follows', 'user'): ([0, 1], [1, 2]),
        ('user', 'plays', 'game'): ([0, 1, 2, 1], [0, 0, 1, 1]),
        ('user', 'wishes', 'game'): ([0, 2], [1, 0]),
        ('developer', 'develops', 'game'): ([0, 1], [0, 1]),
    }
    graph = dgl.heterograph(edges_by_relation, idtype=idtype, device=device)
    # Sanity-check that the requested ID type and device were honoured.
    assert graph.idtype == idtype
    assert graph.device == device
    return graph


nv-dlasalle's avatar
nv-dlasalle committed
28
@parametrize_idtype
def test_edge_prediction_sampler(idtype):
    """Smoke-test ``as_edge_prediction_sampler`` on a heterograph.

    Wraps a two-layer ``NeighborSampler`` (fan-out 10 per layer) with a
    uniform negative sampler and samples the two ``'follows'`` edges of
    the test graph.  The test only checks that sampling completes without
    tripping internal assertions; the sampled output is not inspected.

    Parameters
    ----------
    idtype
        ID type supplied by ``parametrize_idtype``.
    """
    g = create_test_graph(idtype)
    sampler = as_edge_prediction_sampler(
        NeighborSampler([10, 10]),
        negative_sampler=negative_sampler.Uniform(1))

    # Edge IDs 0 and 1 are the only 'follows' edges in the test graph.
    seeds = F.copy_to(F.arange(0, 2, dtype=idtype), ctx=F.ctx())
    # Just a smoke test to make sure we don't fail internal assertions,
    # so the return value is deliberately discarded.
    sampler.sample(g, {'follows': seeds})


if __name__ == '__main__':
    # Under pytest, ``parametrize_idtype`` supplies ``idtype``.  When the
    # file is run as a script, nothing supplies it, so calling the test with
    # no arguments would raise TypeError.  Exercise both ID types explicitly.
    for _idtype in (F.int32, F.int64):
        test_edge_prediction_sampler(_idtype)