# test_cached_graph.py
import torch as th
import numpy as np
import networkx as nx
from dgl import DGLGraph
from dgl.cached_graph import *
from dgl.utils import Index

def check_eq(a, b):
    """Assert that tensors ``a`` and ``b`` have the same shape and elements.

    Equality is checked by counting the positions where ``a == b`` holds and
    comparing that count against the total number of elements implied by
    ``a.shape``.

    Raises
    ------
    AssertionError
        If the shapes differ or any element differs.
    """
    assert a.shape == b.shape
    # Total element count is the product of all dimensions of ``a``.
    expected_matches = int(np.prod(list(a.shape)))
    n_matches = th.sum(a == b)
    assert n_matches == expected_matches

def test_basics():
    """Check edge-ID lookup and in/out edge queries on a cached graph.

    Builds a six-edge ``DGLGraph``, wraps it with ``create_cached_graph``,
    and verifies ``get_edge_id``, ``in_edges`` and ``out_edges`` against
    hand-computed expectations.
    """
    # Edges are listed in insertion order; the test expects the edge at
    # position i in this list to be reported with edge ID i.
    edges = [
        (0, 1),  # id 0
        (1, 2),  # id 1
        (1, 3),  # id 2
        (2, 4),  # id 3
        (2, 5),  # id 4
        (0, 2),  # id 5
    ]
    g = DGLGraph()
    for src, dst in edges:
        g.add_edge(src, dst)
    cg = create_cached_graph(g)

    # Look up all six edges by their (source, destination) endpoints.
    u = Index(th.tensor([0, 0, 1, 1, 2, 2]))
    v = Index(th.tensor([1, 2, 2, 3, 4, 5]))
    check_eq(cg.get_edge_id(u, v).totensor(), th.tensor([0, 5, 1, 2, 3, 4]))

    query = Index(th.tensor([1, 2]))

    # Edges whose destination is node 1 or 2: (0, 1), (0, 2), (1, 2).
    s, d = cg.in_edges(query)
    check_eq(s.totensor(), th.tensor([0, 0, 1]))
    check_eq(d.totensor(), th.tensor([1, 2, 2]))

    # Edges whose source is node 1 or 2: (1, 2), (1, 3), (2, 4), (2, 5).
    s, d = cg.out_edges(query)
    check_eq(s.totensor(), th.tensor([1, 1, 2, 2]))
    check_eq(d.totensor(), th.tensor([2, 3, 4, 5]))


# Allow running this test file directly as a script, without a test runner.
if __name__ == '__main__':
    test_basics()