# test_heterograph-apply-edges.py
1
import itertools
2
3
import unittest
from collections import Counter
4
from itertools import product
5

6
import backend as F
7
8
9

import dgl
import dgl.function as fn
10
import networkx as nx
11
12
import numpy as np
import pytest
13
import pytests_utils
14
import scipy.sparse as ssp
15
from dgl import DGLError
16
from pytests_utils import get_cases, parametrize_idtype
17
from scipy.sparse import rand
18
19
20

# Reducer name -> DGL builtin reduce function. (Not referenced by the tests
# shown in this module; presumably kept for related tests.)
rfuncs = dict(sum=fn.sum, max=fn.max, min=fn.min, mean=fn.mean)

# Per-reducer fill values; only "sum" and "max" are defined here.
fill_value = dict(sum=0, max=float("-inf"))

# Width of every random node/edge feature created by the tests below.
feat_size = 2
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch",
    reason="Only support PyTorch for now",
)
def create_test_heterograph(idtype):
    """Build the small heterograph shared by the tests in this module.

    Node counts are inferred from the largest node id per type:
    3 ``user``, 2 ``game`` and 2 ``developer`` nodes.

    Canonical edge types (edge count in parentheses):
        ('user', 'follows', 'user')        (4)
        ('user', 'plays', 'game')          (4)
        ('user', 'wishes', 'game')         (3)
        ('developer', 'develops', 'game')  (3)

    Because of the ``unittest.skipIf`` decorator, calling this helper on a
    non-PyTorch backend raises ``unittest.SkipTest``, which skips the caller.

    Parameters
    ----------
    idtype
        Integer dtype for the graph's node/edge ids.

    Returns
    -------
    The heterograph, placed on ``F.ctx()``.
    """
    relations = {}
    relations[("user", "follows", "user")] = ([0, 1, 2, 1], [0, 0, 1, 1])
    relations[("user", "plays", "game")] = ([0, 1, 2, 1], [0, 0, 1, 1])
    relations[("user", "wishes", "game")] = ([0, 1, 1], [0, 0, 1])
    relations[("developer", "develops", "game")] = ([0, 1, 0], [0, 1, 1])

    graph = dgl.heterograph(relations, idtype=idtype, device=F.ctx())

    # Make sure the requested id dtype and device were honoured.
    assert graph.idtype == idtype
    assert graph.device == F.ctx()
    return graph


# (source-viewer blame annotation: last committed by nv-dlasalle)
@parametrize_idtype
def test_unary_copy_u(idtype):
    """Per-relation vs. all-relation ``apply_edges`` with a copy-from-source op.

    Random, gradient-tracked ``"h"`` features are put on the ``user`` and
    ``developer`` nodes and copied onto every edge as ``"m"``. The ``plays``
    edge output and the ``user`` feature gradient must match whether
    ``apply_edges`` is called once per relation type or once for all of them.
    """

    def _print_error(a, b):
        # Print each element position where ``a`` and ``b`` disagree.
        for i, (x, y) in enumerate(
            zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())
        ):
            if not np.allclose(x, y):
                print("@{} {} v.s. {}".format(i, x, y))

    def _test(mfunc):
        g = create_test_heterograph(idtype)

        x1 = F.randn((g.num_nodes("user"), feat_size))
        x2 = F.randn((g.num_nodes("developer"), feat_size))
        leaves = (x1, x2)
        for x in leaves:
            F.attach_grad(x)

        g.nodes["user"].data["h"] = x1
        g.nodes["developer"].data["h"] = x2

        #################################################################
        #  apply_edges() is called on each relation type separately
        #################################################################

        with F.record_grad():
            for rel in g.canonical_etypes:
                g.apply_edges(mfunc("h", "m"), etype=rel)
            r1 = g["plays"].edata["m"]
            F.backward(r1, F.ones(r1.shape))
            # F.grad() is assumed to return the live gradient buffer (as
            # ``x.grad`` does in PyTorch), which the next backward pass would
            # update in place; snapshot it so the comparison is meaningful.
            n_grad1 = F.clone(F.grad(x1))

        # Remove the first pass's outputs so the second pass cannot succeed
        # by reading stale results. Calling ``.clear()`` on the dict returned
        # by ``g.edata["m"]`` does not remove the stored features, so delete
        # them relation by relation.
        for rel in g.canonical_etypes:
            del g.edges[rel].data["m"]
        # Reset accumulated gradients before the second backward pass.
        for x in leaves:
            F.attach_grad(x)

        #################################################################
        #  apply_edges() is called on all relation types
        #################################################################

        with F.record_grad():
            g.apply_edges(mfunc("h", "m"))
            r2 = g["plays"].edata["m"]
            F.backward(r2, F.ones(r2.shape))
            n_grad2 = F.clone(F.grad(x1))

        # correctness check
        if not F.allclose(r1, r2):
            _print_error(r1, r2)
        assert F.allclose(r1, r2)
        if not F.allclose(n_grad1, n_grad2):
            print("node grad")
            _print_error(n_grad1, n_grad2)
        assert F.allclose(n_grad1, n_grad2)

    _test(fn.copy_u)


# (source-viewer blame annotation: last committed by nv-dlasalle)
@parametrize_idtype
def test_unary_copy_e(idtype):
    """Per-relation vs. all-relation ``apply_edges`` with a copy-edge op.

    Every relation type gets a random, gradient-tracked edge feature
    ``"eid"``, which is copied into ``"m"``. The ``develops`` output and the
    gradient of its ``"eid"`` input must match whether ``apply_edges`` is
    called once per relation type or once for all of them.
    """

    def _print_error(a, b):
        # Print each element position where ``a`` and ``b`` disagree.
        for i, (x, y) in enumerate(
            zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())
        ):
            if not np.allclose(x, y):
                print("@{} {} v.s. {}".format(i, x, y))

    def _test(mfunc):
        g = create_test_heterograph(idtype)

        # Size each feature from the graph rather than hard-coding edge counts.
        x1 = F.randn((g.num_edges("plays"), feat_size))
        x2 = F.randn((g.num_edges("follows"), feat_size))
        x3 = F.randn((g.num_edges("develops"), feat_size))
        x4 = F.randn((g.num_edges("wishes"), feat_size))
        leaves = (x1, x2, x3, x4)
        for x in leaves:
            F.attach_grad(x)

        g["plays"].edata["eid"] = x1
        g["follows"].edata["eid"] = x2
        g["develops"].edata["eid"] = x3
        g["wishes"].edata["eid"] = x4

        #################################################################
        #  apply_edges() is called on each relation type separately
        #################################################################
        with F.record_grad():
            for rel in g.canonical_etypes:
                g.apply_edges(mfunc("eid", "m"), etype=rel)
            r1 = g["develops"].edata["m"]
            F.backward(r1, F.ones(r1.shape))
            # F.grad() is assumed to return the live gradient buffer (as
            # ``x.grad`` does in PyTorch); snapshot it before the next pass.
            e_grad1 = F.clone(F.grad(x3))

        # Remove the first pass's outputs so the second pass cannot succeed
        # by reading stale results, then reset accumulated gradients.
        for rel in g.canonical_etypes:
            del g.edges[rel].data["m"]
        for x in leaves:
            F.attach_grad(x)

        #################################################################
        #  apply_edges() is called on all relation types
        #################################################################
        with F.record_grad():
            g.apply_edges(mfunc("eid", "m"))
            r2 = g["develops"].edata["m"]
            F.backward(r2, F.ones(r2.shape))
            e_grad2 = F.clone(F.grad(x3))

        # correctness check
        if not F.allclose(r1, r2):
            _print_error(r1, r2)
        assert F.allclose(r1, r2)
        if not F.allclose(e_grad1, e_grad2):
            print("edge grad")
            _print_error(e_grad1, e_grad2)
        assert F.allclose(e_grad1, e_grad2)

    _test(fn.copy_e)


# (source-viewer blame annotation: last committed by nv-dlasalle)
@parametrize_idtype
def test_binary_op(idtype):
    """Per-relation vs. all-relation ``apply_edges`` for binary message ops.

    For every ordered pair of distinct operands from ``u`` (source node),
    ``v`` (destination node) and ``e`` (edge), and every op in
    add/sub/mul/div/dot, the builtin ``fn.<lhs>_<op>_<rhs>`` is applied.
    The ``plays`` output and the gradient w.r.t. the ``game`` node features
    must match between the two calling styles.
    """

    def _print_error(a, b):
        # Print each element position where ``a`` and ``b`` disagree.
        for i, (x, y) in enumerate(
            zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())
        ):
            if not np.allclose(x, y):
                print("@{} {} v.s. {}".format(i, x, y))

    def _grad_snapshot(x):
        # F.grad() is assumed to return the live gradient buffer (as
        # ``x.grad`` does in PyTorch), which a later ``attach_grad``/backward
        # modifies in place; copy it. None means "no gradient reached x".
        grad = F.grad(x)
        return None if grad is None else F.clone(grad)

    def _test(lhs, rhs, binary_op):
        g = create_test_heterograph(idtype)

        n1 = F.randn((g.num_nodes("user"), feat_size))
        n2 = F.randn((g.num_nodes("developer"), feat_size))
        n3 = F.randn((g.num_nodes("game"), feat_size))

        x1 = F.randn((g.num_edges("plays"), feat_size))
        x2 = F.randn((g.num_edges("follows"), feat_size))
        x3 = F.randn((g.num_edges("develops"), feat_size))
        x4 = F.randn((g.num_edges("wishes"), feat_size))

        builtin_msg_name = "{}_{}_{}".format(lhs, binary_op, rhs)
        builtin_msg = getattr(fn, builtin_msg_name)

        def _attach_features():
            # (Re)enable/reset gradients and register all "h" features.
            for x in (n1, n2, n3, x1, x2, x3, x4):
                F.attach_grad(x)
            g.nodes["user"].data["h"] = n1
            g.nodes["developer"].data["h"] = n2
            g.nodes["game"].data["h"] = n3
            g["plays"].edata["h"] = x1
            g["follows"].edata["h"] = x2
            g["develops"].edata["h"] = x3
            g["wishes"].edata["h"] = x4

        #################################################################
        #  apply_edges() is called on each relation type separately
        #################################################################

        _attach_features()
        with F.record_grad():
            for rel in g.canonical_etypes:
                g.apply_edges(builtin_msg("h", "h", "m"), etype=rel)
            r1 = g["plays"].edata["m"]
            loss = F.sum(r1.view(-1), 0)
            F.backward(loss)
            n_grad1 = _grad_snapshot(n3)

        # Remove the first pass's outputs so the second pass cannot succeed
        # by reading stale results.
        for rel in g.canonical_etypes:
            del g.edges[rel].data["m"]

        #################################################################
        #  apply_edges() is called on all relation types
        #################################################################

        _attach_features()
        with F.record_grad():
            g.apply_edges(builtin_msg("h", "h", "m"))
            r2 = g["plays"].edata["m"]
            loss = F.sum(r2.view(-1), 0)
            F.backward(loss)
            n_grad2 = _grad_snapshot(n3)

        # correctness check
        if not F.allclose(r1, r2):
            _print_error(r1, r2)
        assert F.allclose(r1, r2)
        # Both calling styles must agree on whether "game" got a gradient at
        # all (it does not when neither operand reads game features).
        assert (n_grad1 is None) == (n_grad2 is None), "node grad presence"
        if n_grad1 is not None:
            if not F.allclose(n_grad1, n_grad2):
                print("node grad")
                _print_error(n_grad1, n_grad2)
            assert F.allclose(n_grad1, n_grad2)

    target = ["u", "v", "e"]
    for lhs, rhs in product(target, target):
        if lhs == rhs:
            continue
        for binary_op in ["add", "sub", "mul", "div", "dot"]:
            print(lhs, rhs, binary_op)
            _test(lhs, rhs, binary_op)
if __name__ == "__main__":
    # The tests take an ``idtype`` argument that pytest normally supplies via
    # ``parametrize_idtype``; when run as a script, pass each integer id dtype
    # explicitly (assumed to match what parametrize_idtype covers — confirm).
    for idtype in (F.int32, F.int64):
        test_unary_copy_u(idtype)
        test_unary_copy_e(idtype)
        test_binary_op(idtype)