# test_udf.py -- tests for dgl.udf.NodeBatch and dgl.udf.EdgeBatch
import backend as F
import dgl
import networkx as nx
import dgl.utils as utils
from dgl import DGLGraph, ALL
from dgl.udf import NodeBatch, EdgeBatch

def test_node_batch():
    """Check that NodeBatch exposes node features, ids and sizes correctly.

    Builds a 20-node path graph with random 10-dim node features stored
    under key 'x', then wraps (a) every node and (b) a non-contiguous subset
    of nodes in a NodeBatch and verifies ``data``, ``mailbox``, ``nodes()``,
    ``batch_size()`` and ``len()`` against the expected values.
    """
    g = dgl.DGLGraph(nx.path_graph(20))
    feat = F.randn((g.number_of_nodes(), 10))
    g.ndata['x'] = feat

    # test all: a slice index covering every node in the graph
    v = utils.toindex(slice(0, g.number_of_nodes()))
    n_repr = g.get_n_repr(v)
    nbatch = NodeBatch(v, n_repr)
    assert F.allclose(nbatch.data['x'], feat)
    # NodeBatch is built without a mailbox argument, so it should be None
    assert nbatch.mailbox is None
    assert F.allclose(nbatch.nodes(), g.nodes())
    assert nbatch.batch_size() == g.number_of_nodes()
    assert len(nbatch) == g.number_of_nodes()

    # test partial: an explicit, non-contiguous subset of node ids
    ids = F.tensor([0, 3, 5, 7, 9])
    v = utils.toindex(ids)
    n_repr = g.get_n_repr(v)
    nbatch = NodeBatch(v, n_repr)
    # features must be exactly the selected rows, in the requested order
    assert F.allclose(nbatch.data['x'], F.gather_row(feat, ids))
    assert nbatch.mailbox is None
    assert F.allclose(nbatch.nodes(), ids)
    assert nbatch.batch_size() == 5
    assert len(nbatch) == 5

def _check_edge_batch(ebatch, u, v, eids, nfeat, efeat, d):
    """Assert that ``ebatch`` carries the expected edges and features.

    Args:
        ebatch: the EdgeBatch under test.
        u, v: index objects for the source/destination nodes used to build it.
        eids: tensor of the expected edge ids, in batch order.
        nfeat: full node-feature tensor; source/destination rows are
            gathered from it using ``u``/``v``.
        efeat: expected edge-feature rows, already in batch order.
        d: expected feature dimension.
    """
    num = F.shape(eids)[0]
    src_ids = u.tousertensor()
    dst_ids = v.tousertensor()

    # shape checks: one row per edge, d columns each
    for feat in (ebatch.src['x'], ebatch.dst['x'], ebatch.data['x']):
        assert F.shape(feat)[0] == num and F.shape(feat)[1] == d

    # value checks: src/dst rows must be the endpoint features, not just
    # tensors of the right shape
    assert F.allclose(ebatch.src['x'], F.gather_row(nfeat, src_ids))
    assert F.allclose(ebatch.dst['x'], F.gather_row(nfeat, dst_ids))
    assert F.allclose(ebatch.data['x'], efeat)

    batch_src, batch_dst, batch_eid = ebatch.edges()
    assert F.allclose(batch_src, src_ids)
    assert F.allclose(batch_dst, dst_ids)
    assert F.allclose(batch_eid, eids)
    assert ebatch.batch_size() == num
    assert len(ebatch) == num


def test_edge_batch():
    """Check that EdgeBatch exposes endpoint/edge features and ids correctly.

    Builds a 20-node path graph with random ``d``-dim features under key 'x'
    on both nodes and edges, then wraps (a) every edge and (b) a subset of
    edges in an EdgeBatch and verifies shapes, values, ``edges()``,
    ``batch_size()`` and ``len()``.
    """
    d = 10
    g = dgl.DGLGraph(nx.path_graph(20))
    nfeat = F.randn((g.number_of_nodes(), d))
    efeat = F.randn((g.number_of_edges(), d))
    g.ndata['x'] = nfeat
    g.edata['x'] = efeat

    # test all: a slice index over every edge, endpoints listed in eid order
    eid = utils.toindex(slice(0, g.number_of_edges()))
    u, v, _ = g._graph.edges('eid')
    src_data = g.get_n_repr(u)
    edge_data = g.get_e_repr(eid)
    dst_data = g.get_n_repr(v)
    ebatch = EdgeBatch((u, v, eid), src_data, edge_data, dst_data)
    _check_edge_batch(ebatch, u, v, F.arange(0, g.number_of_edges()),
                      nfeat, efeat, d)

    # test partial: an explicit subset of 8 edge ids
    eid = utils.toindex(F.tensor([0, 3, 5, 7, 11, 13, 15, 27]))
    u, v, _ = g._graph.find_edges(eid)
    src_data = g.get_n_repr(u)
    edge_data = g.get_e_repr(eid)
    dst_data = g.get_n_repr(v)
    ebatch = EdgeBatch((u, v, eid), src_data, edge_data, dst_data)
    eids = eid.tousertensor()
    _check_edge_batch(ebatch, u, v, eids, nfeat,
                      F.gather_row(efeat, eids), d)

if __name__ == '__main__':
    # Run the tests directly (without pytest), in definition order.
    for _run_test in (test_node_batch, test_edge_batch):
        _run_test()