# "git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "72780ff5b154f37194903078ef6caa5d65c653e3"
# test_dataloader.py 1.41 KB
# Newer Older
1
import unittest
2

3
4
import backend as F

5
6
7
8
import dgl
from dgl.dataloading import (
    as_edge_prediction_sampler,
    negative_sampler,
9
    NeighborSampler,
10
)
11
from test_utils import parametrize_idtype
12
13


14
15
16
17
18
19
20
21
22
def create_test_graph(idtype):
    """Build the small heterograph shared by the dataloader tests.

    The graph is the test heterograph from the docstring, extended with a
    ``user -- wishes -- game`` relation. It has 3 users, 2 games and
    2 developers.

    Metagraph::

        ('user', 'follows', 'user')
        ('user', 'plays', 'game')
        ('user', 'wishes', 'game')
        ('developer', 'develops', 'game')

    Parameters
    ----------
    idtype
        ID dtype for the graph's node/edge IDs (the value supplied by
        ``parametrize_idtype`` in the tests).

    Returns
    -------
    DGLGraph
        The heterograph, created on the test device ``F.ctx()``.
    """
    g = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 1], [1, 2]),
            ("user", "plays", "game"): ([0, 1, 2, 1], [0, 0, 1, 1]),
            ("user", "wishes", "game"): ([0, 2], [1, 0]),
            ("developer", "develops", "game"): ([0, 1], [0, 1]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    # Sanity-check that the requested dtype and device were honoured.
    assert g.idtype == idtype
    assert g.device == F.ctx()
    return g


@parametrize_idtype
def test_edge_prediction_sampler(idtype):
    """Smoke-test ``as_edge_prediction_sampler`` on a heterograph.

    The test wraps a ``NeighborSampler`` with fanouts ``[10, 10]`` in an
    edge-prediction sampler that uses ``negative_sampler.Uniform(1)``. It
    then samples around the ``follows`` edges of ``create_test_graph``.
    No outputs are checked; the test only verifies that sampling completes
    without tripping internal assertions.

    Parameters
    ----------
    idtype
        ID dtype for the test graph, supplied by ``parametrize_idtype``.
    """
    g = create_test_graph(idtype)
    sampler = NeighborSampler([10, 10])
    sampler = as_edge_prediction_sampler(
        sampler, negative_sampler=negative_sampler.Uniform(1)
    )

    # IDs 0 and 1: "follows" has exactly two edges in create_test_graph.
    seeds = F.copy_to(F.arange(0, 2, dtype=idtype), ctx=F.ctx())
    # Just a smoke test to make sure we don't fail internal assertions;
    # the sampled result itself is intentionally not inspected.
    sampler.sample(g, {"follows": seeds})


if __name__ == "__main__":
    # test_edge_prediction_sampler requires an ``idtype`` argument that is
    # normally injected by the ``parametrize_idtype`` decorator (presumably
    # pytest parametrization). Calling it with no argument raises TypeError,
    # so pass each ID dtype explicitly when running this file directly.
    # NOTE(review): assumes the test backend exposes F.int32 / F.int64.
    for _idtype in (F.int32, F.int64):
        test_edge_prediction_sampler(_idtype)