"""Tests that heterograph ``update_all`` matches per-etype ``multi_update_all``.

Each test runs the same builtin message/reduce functions through both APIs
and compares the resulting node outputs (and, where enabled, the gradients
with respect to the input features).
"""
import itertools
import unittest
from collections import Counter
from itertools import product

import backend as F
import networkx as nx
import numpy as np
import pytest
import scipy.sparse as ssp
import test_utils
from scipy.sparse import rand
from test_utils import get_cases, parametrize_idtype

import dgl
import dgl.function as fn
from dgl import DGLError

# Lookup table from reducer name to the corresponding DGL builtin reducer.
rfuncs = {"sum": fn.sum, "max": fn.max, "min": fn.min, "mean": fn.mean}
# Width of every random node/edge feature tensor used by these tests.
feat_size = 2
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
def create_test_heterograph(idtype):
    """Build a small, fixed heterograph on ``F.ctx()``.

    This is the heterograph from the docstring plus a user -- wishes -- game
    relation: 3 users, 2 games and 2 developers.

    Metagraph::

        ('user', 'follows', 'user')
        ('user', 'plays', 'game')
        ('user', 'wishes', 'game')
        ('developer', 'develops', 'game')

    Note: ``unittest.skipIf`` decorates this helper rather than a test. On a
    non-PyTorch backend, calling it raises ``unittest.SkipTest``. Presumably
    this is meant to skip every test that builds the graph; confirm that this
    is intended.

    Args:
        idtype: Integer dtype for node/edge ids.

    Returns:
        The constructed ``dgl.DGLGraph``.
    """
    g = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 1, 2, 1], [0, 0, 1, 1]),
            ("user", "plays", "game"): ([0, 1, 2, 1], [0, 0, 1, 1]),
            ("user", "wishes", "game"): ([0, 1, 1], [0, 0, 1]),
            ("developer", "develops", "game"): ([0, 1, 0], [0, 1, 1]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    assert g.idtype == idtype
    assert g.device == F.ctx()
    return g
def create_test_heterograph_2(idtype):
    """Build a random five-relation heterograph on ``F.ctx()``.

    Edge endpoints are drawn with ``np.random`` and no seed is set here, so
    the graph differs from call to call.

    - ``becomes``, ``follows`` and ``plays`` share the same 25 (src, dst)
      pairs, with ids in [0, 50).
    - ``wishes`` has 10 edges with ids in [0, 25).
    - ``develops`` has 1000 edges with ids in [0, 100).

    Args:
        idtype: Integer dtype for node/edge ids.

    Returns:
        The constructed ``dgl.DGLGraph``.
    """
    src = np.random.randint(0, 50, 25)
    dst = np.random.randint(0, 50, 25)
    src1 = np.random.randint(0, 25, 10)
    dst1 = np.random.randint(0, 25, 10)
    src2 = np.random.randint(0, 100, 1000)
    dst2 = np.random.randint(0, 100, 1000)

    g = dgl.heterograph(
        {
            ("user", "becomes", "player"): (src, dst),
            ("user", "follows", "user"): (src, dst),
            ("user", "plays", "game"): (src, dst),
            ("user", "wishes", "game"): (src1, dst1),
            ("developer", "develops", "game"): (src2, dst2),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    assert g.idtype == idtype
    assert g.device == F.ctx()
    return g
def create_test_heterograph_large(idtype):
    """Build a random four-relation heterograph with 2500 edges per relation.

    All four relations reuse the same (src, dst) pairs. Ids are drawn from
    [0, 50) with ``np.random``, unseeded here.

    Args:
        idtype: Integer dtype for node/edge ids.

    Returns:
        The constructed ``dgl.DGLGraph`` on ``F.ctx()``.
    """
    src = np.random.randint(0, 50, 2500)
    dst = np.random.randint(0, 50, 2500)

    g = dgl.heterograph(
        {
            ("user", "follows", "user"): (src, dst),
            ("user", "plays", "game"): (src, dst),
            ("user", "wishes", "game"): (src, dst),
            ("developer", "develops", "game"): (src, dst),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    assert g.idtype == idtype
    assert g.device == F.ctx()
    return g
@parametrize_idtype
def test_unary_copy_u(idtype):
    """Check ``copy_u`` message passing: ``multi_update_all`` vs ``update_all``.

    For each reducer, the node outputs ``y`` on ``game``, ``user`` and
    ``player`` must agree between the two APIs. So must the gradients with
    respect to the ``user`` and ``developer`` input features ``h``.
    """

    def _test(mfunc, rfunc):
        g = create_test_heterograph_2(idtype)
        # The cross-type reducer is named after the builtin reducer
        # (e.g. "sum" for fn.sum).
        cross_reducer = rfunc.__name__

        x1 = F.randn((g.num_nodes("user"), feat_size))
        x2 = F.randn((g.num_nodes("developer"), feat_size))

        def _bind_inputs():
            # Mark the inputs as requiring grad and attach them as node data.
            F.attach_grad(x1)
            F.attach_grad(x2)
            g.nodes["user"].data["h"] = x1
            g.nodes["developer"].data["h"] = x2

        #################################################################
        #  multi_update_all(): call msg_passing separately for each etype
        #################################################################
        _bind_inputs()
        with F.record_grad():
            g.multi_update_all(
                {
                    etype: (mfunc("h", "m"), rfunc("m", "y"))
                    for etype in g.canonical_etypes
                },
                cross_reducer,
            )
            # Clone so the second pass below (which rebinds the same input
            # tensors) cannot affect the saved values.
            r1 = g.nodes["game"].data["y"].clone()
            r2 = g.nodes["user"].data["y"].clone()
            r3 = g.nodes["player"].data["y"].clone()
            loss = r1.sum() + r2.sum() + r3.sum()
            F.backward(loss)
            n_grad1 = F.grad(g.nodes["user"].data["h"]).clone()
            n_grad2 = F.grad(g.nodes["developer"].data["h"]).clone()

        # Start the second pass from a graph with no node data.
        for ntype in g.ntypes:
            g.nodes[ntype].data.clear()

        #################################################################
        #  update_all(): call msg_passing for all etypes
        #################################################################
        _bind_inputs()
        with F.record_grad():
            g.update_all(mfunc("h", "m"), rfunc("m", "y"))
            r4 = g.nodes["game"].data["y"]
            r5 = g.nodes["user"].data["y"]
            r6 = g.nodes["player"].data["y"]
            loss = r4.sum() + r5.sum() + r6.sum()
            F.backward(loss)
            n_grad3 = F.grad(g.nodes["user"].data["h"])
            n_grad4 = F.grad(g.nodes["developer"].data["h"])

        assert F.allclose(r1, r4)
        assert F.allclose(r2, r5)
        assert F.allclose(r3, r6)
        assert F.allclose(n_grad1, n_grad3)
        assert F.allclose(n_grad2, n_grad4)

    for rfunc in (fn.sum, fn.max, fn.min):
        _test(fn.copy_u, rfunc)
    # TODO: also cover the mean reducer: _test(fn.copy_u, fn.mean)
@parametrize_idtype
def test_unary_copy_e(idtype):
    """Check ``copy_e`` message passing: ``multi_update_all`` vs ``update_all``.

    For each reducer, the ``game`` and ``user`` node outputs ``y`` must agree
    between the two APIs. So must the gradients with respect to every
    relation's edge feature ``eid``.
    """
    etypes = ("plays", "follows", "develops", "wishes")

    def _test(mfunc, rfunc):
        g = create_test_heterograph_large(idtype)
        # The cross-type reducer is named after the builtin reducer
        # (e.g. "sum" for fn.sum).
        cross_reducer = rfunc.__name__

        # One random feature tensor per relation, sized to its edge count.
        feats = {
            etype: F.randn((g.num_edges(etype), feat_size)) for etype in etypes
        }

        def _bind_edge_feats():
            # Mark the inputs as requiring grad, then attach them as edge data.
            for x in feats.values():
                F.attach_grad(x)
            for etype, x in feats.items():
                g[etype].edata["eid"] = x

        #################################################################
        #  multi_update_all(): call msg_passing separately for each etype
        #################################################################
        _bind_edge_feats()
        with F.record_grad():
            g.multi_update_all(
                {
                    etype: (mfunc("eid", "m"), rfunc("m", "y"))
                    for etype in etypes
                },
                cross_reducer,
            )
            # Clone so the second pass cannot affect the saved values.
            r1 = g.nodes["game"].data["y"].clone()
            r2 = g.nodes["user"].data["y"].clone()
            loss = r1.sum() + r2.sum()
            F.backward(loss)
            multi_grads = {
                etype: F.grad(g[etype].edata["eid"]).clone() for etype in etypes
            }

        # Drop all edge data before rebinding the inputs for the second pass.
        for _, etype, _ in g.canonical_etypes:
            g[etype].edata.clear()

        #################################################################
        #  update_all(): call msg_passing for all etypes
        #################################################################
        # TODO(Israt): output type can be None in multi_update and empty
        _bind_edge_feats()
        with F.record_grad():
            g.update_all(mfunc("eid", "m"), rfunc("m", "y"))
            r3 = g.nodes["game"].data["y"]
            r4 = g.nodes["user"].data["y"]
            loss = r3.sum() + r4.sum()
            F.backward(loss)
            all_grads = {
                etype: F.grad(g[etype].edata["eid"]) for etype in etypes
            }

        assert F.allclose(r1, r3)
        assert F.allclose(r2, r4)
        for etype in etypes:
            assert F.allclose(multi_grads[etype], all_grads[etype])

    for rfunc in (fn.sum, fn.max, fn.min):
        _test(fn.copy_e, rfunc)
    # TODO: also cover the mean reducer: _test(fn.copy_e, fn.mean)
@parametrize_idtype
def test_binary_op(idtype):
    """Check binary builtin messages: ``multi_update_all`` vs ``update_all``.

    Runs every ``{lhs}_{op}_{rhs}`` builtin, where lhs != rhs are drawn from
    u/v/e and op from add/sub/mul/div, with a ``sum`` reducer on the small
    test heterograph. The ``game`` node outputs of both APIs must agree.
    """

    def _print_error(a, b):
        # Print every element position where a and b differ (debug aid).
        for i, (x, y) in enumerate(
            zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())
        ):
            if not np.allclose(x, y):
                print("@{} {} v.s. {}".format(i, x, y))

    def _test(lhs, rhs, binary_op, reducer):
        g = create_test_heterograph(idtype)

        # Node and edge inputs both use the feature name "h", so any lhs/rhs
        # combination of u, v and e finds its operand.
        for ntype in ("user", "developer", "game"):
            x = F.randn((g.num_nodes(ntype), feat_size))
            F.attach_grad(x)
            g.nodes[ntype].data["h"] = x
        for etype in ("plays", "follows", "develops", "wishes"):
            x = F.randn((g.num_edges(etype), feat_size))
            F.attach_grad(x)
            g[etype].edata["h"] = x

        builtin_msg_name = "{}_{}_{}".format(lhs, binary_op, rhs)
        builtin_msg = getattr(fn, builtin_msg_name)
        builtin_red = getattr(fn, reducer)

        #################################################################
        #  multi_update_all(): call msg_passing separately for each etype
        #################################################################
        with F.record_grad():
            g.multi_update_all(
                {
                    etype: (builtin_msg("h", "h", "m"), builtin_red("m", "y"))
                    for etype in g.canonical_etypes
                },
                "sum",
            )
            r1 = g.nodes["game"].data["y"]
            F.backward(r1, F.ones(r1.shape))

        #################################################################
        #  update_all(): call msg_passing for all etypes
        #################################################################
        # Recorded like the first pass so the backward has a graph to follow
        # on backends where recording is explicit.
        with F.record_grad():
            g.update_all(builtin_msg("h", "h", "m"), builtin_red("m", "y"))
            r2 = g.nodes["game"].data["y"]
            F.backward(r2, F.ones(r2.shape))

        # correctness check
        if not F.allclose(r1, r2):
            _print_error(r1, r2)
        assert F.allclose(r1, r2)
        # TODO(Israt): r1 and r2 have different grad funcs associated with
        # them, so the gradient comparison between the two passes is
        # currently disabled.

    target = ["u", "v", "e"]
    # Ordered pairs with lhs != rhs, in the same order as product + skip.
    for lhs, rhs in itertools.permutations(target, 2):
        for binary_op in ["add", "sub", "mul", "div"]:
            # TODO(Israt): Add support for reduce func "max", "min", "mean"
            for reducer in ["sum"]:
                print(lhs, rhs, binary_op, reducer)
                _test(lhs, rhs, binary_op, reducer)
if __name__ == "__main__":
    # Under pytest, parametrize_idtype supplies ``idtype``. When this file is
    # run as a script, the idtype must be passed explicitly. int32/int64 are
    # assumed to be the idtypes parametrize_idtype covers; confirm this in
    # test_utils.
    for idtype in (F.int32, F.int64):
        test_unary_copy_u(idtype)
        test_unary_copy_e(idtype)
        test_binary_op(idtype)