# test_cached_graph.py: tests for dgl.cached_graph.
import torch as th
import numpy as np
import networkx as nx
from dgl import DGLGraph
from dgl.cached_graph import *
from dgl.utils import Index

def check_eq(a, b):
    """Assert that tensors ``a`` and ``b`` match in shape and in every element.

    Elements are compared with ``==``. Raises AssertionError on any mismatch.
    """
    assert a.shape == b.shape
    # With shapes known to agree, every element-wise comparison must hold.
    assert th.all(a == b)

def test_basics():
    """Check edge-id lookup and in/out-edge queries on a small cached graph.

    Edges are added in this order: 0->1, 1->2, 1->3, 2->4, 2->5, 0->2.
    The expected edge ids below match that insertion order (0..5).
    """
    g = DGLGraph()
    g.add_edge(0, 1)
    g.add_edge(1, 2)
    g.add_edge(1, 3)
    g.add_edge(2, 4)
    g.add_edge(2, 5)
    # Added last, so the lookup below expects id 5 for (0, 2) even though
    # it is queried second.
    g.add_edge(0, 2)
    cg = create_cached_graph(g)

    # Edge-id lookup for (u[i], v[i]) pairs; ids follow insertion order.
    u = Index(th.tensor([0, 0, 1, 1, 2, 2]))
    v = Index(th.tensor([1, 2, 2, 3, 4, 5]))
    check_eq(cg.get_edge_id(u, v).totensor(), th.tensor([0, 5, 1, 2, 3, 4]))

    query = Index(th.tensor([0, 1, 2, 5]))

    # In-edges of the query nodes. Node 0 has no incoming edge and is
    # expected in ``orphan`` (presumably: query nodes with no matching
    # edge; confirm against cached_graph).
    s, d, orphan = cg.in_edges(query)
    check_eq(s.totensor(), th.tensor([0, 0, 1, 2]))
    check_eq(d.totensor(), th.tensor([1, 2, 2, 5]))
    assert orphan.tolist() == [0]

    # Out-edges of the same query nodes; node 5 has no outgoing edge and
    # is expected in ``orphan``.
    s, d, orphan = cg.out_edges(query)
    check_eq(s.totensor(), th.tensor([0, 0, 1, 1, 2, 2]))
    check_eq(d.totensor(), th.tensor([1, 2, 2, 3, 4, 5]))
    assert orphan.tolist() == [5]


if __name__ == '__main__':
    # Allow running this test module directly as a script, without a test runner.
    test_basics()