import backend as F
import dgl.graphbolt as gb
import pytest
import torch


def test_find_reverse_edges_homo():
    """Check ``gb.add_reverse_edges`` on a homogeneous edge tensor.

    The input holds one ``(src, dst)`` pair per row. The expected result is
    the original rows followed by the same rows with the two columns swapped.
    """
    edges = torch.tensor([[1, 3, 5], [2, 4, 5]]).T
    edges = gb.add_reverse_edges(edges)
    expected_edges = torch.tensor([[1, 3, 5, 2, 4, 5], [2, 4, 5, 1, 3, 5]]).T
    assert torch.equal(edges, expected_edges)


def test_find_reverse_edges_hetero():
    """Check ``gb.add_reverse_edges`` with a one-way edge-type mapping.

    Only ``"A:r:B" -> "B:rr:A"`` is mapped, so ``"A:r:B"`` is expected to stay
    unchanged while ``"B:rr:A"`` keeps its own edge and gains the swapped
    ``"A:r:B"`` edges after it.
    """
    edges = {
        "A:r:B": torch.tensor([[1, 5], [2, 5]]).T,
        "B:rr:A": torch.tensor([[3], [3]]).T,
    }
    edges = gb.add_reverse_edges(edges, {"A:r:B": "B:rr:A"})
    expected_edges = {
        "A:r:B": torch.tensor([[1, 5], [2, 5]]).T,
        "B:rr:A": torch.tensor([[3, 2, 5], [3, 1, 5]]).T,
    }
    assert torch.equal(edges["A:r:B"], expected_edges["A:r:B"])
    assert torch.equal(edges["B:rr:A"], expected_edges["B:rr:A"])


def test_find_reverse_edges_bi_reverse_types():
    """Check ``gb.add_reverse_edges`` when two edge types map to each other.

    With ``"A:r:B"`` and ``"B:rr:A"`` mapped in both directions, each type is
    expected to keep its own edges and gain the swapped edges of the other
    type after them.
    """
    edges = {
        "A:r:B": torch.tensor([[1, 5], [2, 5]]).T,
        "B:rr:A": torch.tensor([[3], [3]]).T,
    }
    edges = gb.add_reverse_edges(edges, {"A:r:B": "B:rr:A", "B:rr:A": "A:r:B"})
    expected_edges = {
        "A:r:B": torch.tensor([[1, 5, 3], [2, 5, 3]]).T,
        "B:rr:A": torch.tensor([[3, 2, 5], [3, 1, 5]]).T,
    }
    assert torch.equal(edges["A:r:B"], expected_edges["A:r:B"])
    assert torch.equal(edges["B:rr:A"], expected_edges["B:rr:A"])


def test_find_reverse_edges_circual_reverse_types():
    """Check ``gb.add_reverse_edges`` with a circular edge-type mapping.

    The mapping forms the cycle ``A:r1:B -> B:r2:C -> C:r3:A -> A:r1:B``.
    Each type is expected to keep its own single edge and gain, after it, the
    edge of the type that maps onto it.
    """
    edges = {
        "A:r1:B": torch.tensor([[1, 1]]),
        "B:r2:C": torch.tensor([[2, 2]]),
        "C:r3:A": torch.tensor([[3, 3]]),
    }
    edges = gb.add_reverse_edges(
        edges, {"A:r1:B": "B:r2:C", "B:r2:C": "C:r3:A", "C:r3:A": "A:r1:B"}
    )
    expected_edges = {
        "A:r1:B": torch.tensor([[1, 3], [1, 3]]).T,
        "B:r2:C": torch.tensor([[2, 1], [2, 1]]).T,
        "C:r3:A": torch.tensor([[3, 2], [3, 2]]).T,
    }
    # One assertion per edge type in the cycle.
    assert torch.equal(edges["A:r1:B"], expected_edges["A:r1:B"])
    assert torch.equal(edges["B:r2:C"], expected_edges["B:r2:C"])
    assert torch.equal(edges["C:r3:A"], expected_edges["C:r3:A"])