checks.py 4.52 KB
Newer Older
1
import backend as F
2

3
import dgl
4
import pytest
5
from dgl.base import is_internal_column
6

7
8
9
10
11
12
13
# Public helpers of this module: the names a star-import
# (``from <module> import *``) makes available.
__all__ = [
    "check_fail",
    "assert_is_identical",
    "assert_is_identical_hetero",
    "check_graph_equal",
]

14

15
16
17
18
19
20
def check_fail(fn, *args, **kwargs):
    """Report whether calling ``fn(*args, **kwargs)`` raises an exception.

    Parameters
    ----------
    fn : callable
        The callable to invoke.
    *args, **kwargs
        Forwarded unchanged to ``fn``.

    Returns
    -------
    bool
        ``True`` if the call raised an ``Exception`` (the exception is
        swallowed), ``False`` if the call returned normally.

    Notes
    -----
    Only ``Exception`` subclasses count as a failure. Exceptions that derive
    solely from ``BaseException`` (``KeyboardInterrupt``, ``SystemExit``,
    ``GeneratorExit``) are deliberately *not* caught, so interrupting a test
    run or exiting the interpreter is never misreported as an expected
    failure of ``fn``.
    """
    try:
        fn(*args, **kwargs)
    except Exception:
        return True
    return False
21

22

23
24
def assert_is_identical(g, g2):
    """Assert that ``g`` and ``g2`` agree on node count, edges and features.

    The checks run in the order below; the first one that does not hold
    raises ``AssertionError``:

    1. both graphs report the same ``number_of_nodes()``;
    2. the arrays returned by ``all_edges(order="eid")`` are pairwise equal
       (``F.array_equal``);
    3. ``ndata`` and ``edata`` contain the same number of entries;
    4. every entry of ``g.ndata`` / ``g.edata`` is close (``F.allclose``) to
       the entry stored under the same key in ``g2``.

    Parameters
    ----------
    g, g2
        The two graphs to compare. Only ``number_of_nodes``, ``all_edges``,
        ``ndata`` and ``edata`` are used (presumably DGL graphs -- confirm
        against callers).
    """
    assert g.number_of_nodes() == g2.number_of_nodes()

    # Both edge lists are fetched in edge-ID order before any comparison.
    endpoint_pairs = zip(g.all_edges(order="eid"), g2.all_edges(order="eid"))
    for endpoints, other_endpoints in endpoint_pairs:
        assert F.array_equal(endpoints, other_endpoints)

    # Equal entry counts plus the per-key lookups below mean neither graph
    # can carry a feature the other one lacks.
    assert len(g.ndata) == len(g2.ndata)
    assert len(g.edata) == len(g2.edata)
    for name in g.ndata:
        assert F.allclose(g.ndata[name], g2.ndata[name])
    for name in g.edata:
        assert F.allclose(g.edata[name], g2.edata[name])

37

38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def _drop_internal_columns(frame):
    """Delete every key of ``frame`` for which ``is_internal_column`` is true."""
    # Iterate over a snapshot of the keys because entries are removed
    # while looping.
    for key in list(frame.keys()):
        if is_internal_column(key):
            del frame[key]


def assert_is_identical_hetero(g, g2, ignore_internal_data=False):
    """Assert that ``g`` and ``g2`` match type by type.

    Compared, in order: the node type lists, the canonical edge types, the
    keyed edges of the two metagraphs, then for every node type the node
    count and node features, and for every canonical edge type the edge
    endpoints (in edge-ID order) and edge features. Features are compared
    with ``F.allclose``, endpoints with ``F.array_equal``. The first check
    that does not hold raises ``AssertionError``.

    Parameters
    ----------
    g, g2
        The two graphs to compare.
    ignore_internal_data : bool, optional
        When true, every feature whose key satisfies ``is_internal_column``
        is removed from the node and edge data before the features are
        compared. Note that the removal is done in place on **both**
        ``g`` and ``g2``; the columns are not restored afterwards.
    """
    assert g.ntypes == g2.ntypes
    assert g.canonical_etypes == g2.canonical_etypes

    # Every keyed metagraph edge of ``g`` must map to the same value in the
    # metagraph of ``g2``.
    for meta_edge, meta_value in g.metagraph().edges(keys=True).items():
        assert g2.metagraph().edges(keys=True)[meta_edge] == meta_value

    # Node ID spaces and node feature spaces, one node type at a time.
    for ntype in g.ntypes:
        assert g.number_of_nodes(ntype) == g2.number_of_nodes(ntype)
        if ignore_internal_data:
            _drop_internal_columns(g.nodes[ntype].data)
            _drop_internal_columns(g2.nodes[ntype].data)
        assert len(g.nodes[ntype].data) == len(g2.nodes[ntype].data)
        for name in g.nodes[ntype].data:
            assert F.allclose(
                g.nodes[ntype].data[name], g2.nodes[ntype].data[name]
            )

    # Edge ID spaces and edge feature spaces, one canonical edge type at a
    # time.
    for etype in g.canonical_etypes:
        endpoint_pairs = zip(
            g.all_edges(etype=etype, order="eid"),
            g2.all_edges(etype=etype, order="eid"),
        )
        for endpoints, other_endpoints in endpoint_pairs:
            assert F.array_equal(endpoints, other_endpoints)
        if ignore_internal_data:
            _drop_internal_columns(g.edges[etype].data)
            _drop_internal_columns(g2.edges[etype].data)
        assert len(g.edges[etype].data) == len(g2.edges[etype].data)
        for name in g.edges[etype].data:
            assert F.allclose(
                g.edges[etype].data[name], g2.edges[etype].data[name]
            )
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124


def check_graph_equal(g1, g2, *, check_idtype=True, check_feature=True):
    """Assert that ``g1`` and ``g2`` are equal in schema, structure and data.

    Checked, in order: device, (optionally) idtype, node/edge/source/
    destination type lists, canonical edge types, batch size, the keyed
    edges of the metagraphs, per-node-type node counts and batch node
    counts, per-edge-type edge counts, batch edge counts and the
    ``(src, dst, eid)`` arrays, and finally (optionally) the features.
    The first check that does not hold raises ``AssertionError``.

    Parameters
    ----------
    g1, g2
        The two graphs to compare.
    check_idtype : bool, optional
        When true, ``g1.idtype`` must equal ``g2.idtype``. When false, the
        edge arrays of ``g2`` are cast to ``g1.idtype`` with ``F.astype``
        before they are compared.
    check_feature : bool, optional
        When true, node and edge features are compared with ``F.allclose``.
        Types that have zero nodes (respectively zero edges) in ``g1`` are
        skipped.
    """
    assert g1.device == g2.device
    if check_idtype:
        assert g1.idtype == g2.idtype
    assert g1.ntypes == g2.ntypes
    assert g1.etypes == g2.etypes
    assert g1.srctypes == g2.srctypes
    assert g1.dsttypes == g2.dsttypes
    assert g1.canonical_etypes == g2.canonical_etypes
    assert g1.batch_size == g2.batch_size

    # Every keyed metagraph edge of ``g1`` must map to the same value in the
    # metagraph of ``g2``.
    for meta_edge, meta_value in g1.metagraph().edges(keys=True).items():
        assert g2.metagraph().edges(keys=True)[meta_edge] == meta_value

    for ntype in g1.ntypes:
        assert g1.number_of_nodes(ntype) == g2.number_of_nodes(ntype)
        assert F.allclose(g1.batch_num_nodes(ntype), g2.batch_num_nodes(ntype))

    for etype in g1.canonical_etypes:
        assert g1.number_of_edges(etype) == g2.number_of_edges(etype)
        assert F.allclose(g1.batch_num_edges(etype), g2.batch_num_edges(etype))
        triple = g1.edges(etype=etype, form="all")
        other_triple = g2.edges(etype=etype, form="all")
        # Walk (src, dst, eid) in that order; when idtypes are allowed to
        # differ, bring the second graph's IDs to g1's idtype first.
        for ids, other_ids in zip(triple, other_triple):
            if not check_idtype:
                other_ids = F.astype(other_ids, g1.idtype)
            assert F.allclose(ids, other_ids)

    if not check_feature:
        return

    for ntype in g1.ntypes:
        if g1.number_of_nodes(ntype) == 0:
            continue
        for name in g1.nodes[ntype].data.keys():
            assert F.allclose(
                g1.nodes[ntype].data[name], g2.nodes[ntype].data[name]
            )

    for etype in g1.canonical_etypes:
        if g1.number_of_edges(etype) == 0:
            continue
        # NOTE(review): the node loop above walks g1's feature names while
        # this loop walks g2's -- confirm whether the asymmetry is intended.
        for name in g2.edges[etype].data.keys():
            assert F.allclose(
                g1.edges[etype].data[name], g2.edges[etype].data[name]
            )