test_heterograph-update-all.py 11.6 KB
Newer Older
1
import itertools
2
3
import unittest
from collections import Counter
4
from itertools import product
5

6
import backend as F
7
8
9

import dgl
import dgl.function as fn
10
import networkx as nx
11
12
13
import numpy as np
import pytest
import scipy.sparse as ssp
14
from dgl import DGLError
15
from scipy.sparse import rand
16
from utils import get_cases, parametrize_idtype
17
18

rfuncs = {"sum": fn.sum, "max": fn.max, "min": fn.min, "mean": fn.mean}
19
20
21
feat_size = 2


22
23
24
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
def create_test_heterograph(idtype):
    """Return a small, deterministic heterograph with four relations.

    Metagraph (number of edges per relation in parentheses)::

        ('user', 'follows', 'user')         (4)
        ('user', 'plays', 'game')           (4)
        ('user', 'wishes', 'game')          (3)
        ('developer', 'develops', 'game')   (3)

    The edge lists reference 3 users, 2 games and 2 developers.

    Because of the ``unittest.skipIf`` decorator, calling this helper on a
    backend other than PyTorch raises ``unittest.SkipTest`` instead of building
    the graph, which skips whichever test called it.

    Args:
        idtype: index dtype forwarded to ``dgl.heterograph``.

    Returns:
        The heterograph, created on ``F.ctx()``.
    """
    edges_by_relation = {
        ("user", "follows", "user"): ([0, 1, 2, 1], [0, 0, 1, 1]),
        ("user", "plays", "game"): ([0, 1, 2, 1], [0, 0, 1, 1]),
        ("user", "wishes", "game"): ([0, 1, 1], [0, 0, 1]),
        ("developer", "develops", "game"): ([0, 1, 0], [0, 1, 1]),
    }
    graph = dgl.heterograph(edges_by_relation, idtype=idtype, device=F.ctx())

    # Sanity-check that the requested dtype/device were honoured.
    assert graph.idtype == idtype
    assert graph.device == F.ctx()
    return graph

48

49
def create_test_heterograph_2(idtype):
    """Return a randomly wired heterograph with five relations.

    Edge endpoints are drawn with ``np.random.randint`` (no seed is set here):

    * ``becomes`` (user->player), ``follows`` (user->user) and ``plays``
      (user->game) all use one shared list of 25 edges, ids in ``[0, 50)``;
    * ``wishes`` (user->game) has 10 edges with ids in ``[0, 25)``;
    * ``develops`` (developer->game) has 1000 edges with ids in ``[0, 100)``.

    Args:
        idtype: index dtype forwarded to ``dgl.heterograph``.

    Returns:
        The heterograph, created on ``F.ctx()``.
    """
    draw = np.random.randint
    # Keep the draw order (src, dst, src1, dst1, src2, dst2) unchanged so a
    # seeded RNG produces the same edges.
    src, dst = draw(0, 50, 25), draw(0, 50, 25)
    src1, dst1 = draw(0, 25, 10), draw(0, 25, 10)
    src2, dst2 = draw(0, 100, 1000), draw(0, 100, 1000)

    shared = (src, dst)
    graph = dgl.heterograph(
        {
            ("user", "becomes", "player"): shared,
            ("user", "follows", "user"): shared,
            ("user", "plays", "game"): shared,
            ("user", "wishes", "game"): (src1, dst1),
            ("developer", "develops", "game"): (src2, dst2),
        },
        idtype=idtype,
        device=F.ctx(),
    )

    assert graph.idtype == idtype
    assert graph.device == F.ctx()
    return graph

71

72
73
74
def create_test_heterograph_large(idtype):
    """Return a random heterograph whose four relations share one edge list.

    2500 ``(src, dst)`` pairs with ids in ``[0, 50)`` are drawn once with
    ``np.random.randint`` and reused for ``follows`` (user->user), ``plays``
    (user->game), ``wishes`` (user->game) and ``develops`` (developer->game).

    Args:
        idtype: index dtype forwarded to ``dgl.heterograph``.

    Returns:
        The heterograph, created on ``F.ctx()``.
    """
    num_edges, id_bound = 2500, 50
    src = np.random.randint(0, id_bound, num_edges)
    dst = np.random.randint(0, id_bound, num_edges)

    relations = [
        ("user", "follows", "user"),
        ("user", "plays", "game"),
        ("user", "wishes", "game"),
        ("developer", "develops", "game"),
    ]
    graph = dgl.heterograph(
        {relation: (src, dst) for relation in relations},
        idtype=idtype,
        device=F.ctx(),
    )

    assert graph.idtype == idtype
    assert graph.device == F.ctx()
    return graph

89

nv-dlasalle's avatar
nv-dlasalle committed
90
@parametrize_idtype
def test_unary_copy_u(idtype):
    """``multi_update_all`` and ``update_all`` must agree for ``fn.copy_u``.

    For each reducer (sum, max, min) the ``h`` feature of user and developer
    nodes is sent along every canonical edge type of
    ``create_test_heterograph_2``:

    * once with ``multi_update_all``: one (message, reduce) pair per edge type,
      combined across types with a cross-type reducer named like the reducer;
    * once with a single ``update_all`` call on the whole heterograph.

    The aggregated ``y`` on game, user and player nodes, and the gradients of
    ``sum(y)`` w.r.t. the input features, must match between the two runs.
    """
    if dgl.backend.backend_name != "pytorch":
        # Previously this skip was triggered implicitly by an otherwise unused
        # call to the skipIf-decorated create_test_heterograph(). The body
        # calls tensor ``.clone()`` directly rather than through ``F``.
        pytest.skip("Only support PyTorch for now")

    def _test(mfunc, rfunc):
        g = create_test_heterograph_2(idtype)
        # Reuse the reducer's name as the cross-type reducer.
        cross_reducer = rfunc.__name__

        x1 = F.randn((g.num_nodes("user"), feat_size))
        x2 = F.randn((g.num_nodes("developer"), feat_size))
        F.attach_grad(x1)
        F.attach_grad(x2)
        g.nodes["user"].data["h"] = x1
        g.nodes["developer"].data["h"] = x2

        #################################################################
        #  multi_update_all(): call msg_passing separately for each etype
        #################################################################

        with F.record_grad():
            g.multi_update_all(
                {
                    etype: (mfunc("h", "m"), rfunc("m", "y"))
                    for etype in g.canonical_etypes
                },
                cross_reducer,
            )
            # Clone results and grads: node data is cleared and the inputs
            # are re-attached before the second pass.
            r1 = g.nodes["game"].data["y"].clone()
            r2 = g.nodes["user"].data["y"].clone()
            r3 = g.nodes["player"].data["y"].clone()

            loss = r1.sum() + r2.sum() + r3.sum()
            F.backward(loss)

            n_grad1 = F.grad(g.nodes["user"].data["h"]).clone()
            n_grad2 = F.grad(g.nodes["developer"].data["h"]).clone()

        for ntype in ("user", "developer", "game", "player"):
            g.nodes[ntype].data.clear()

        #################################################################
        #  update_all(): call msg_passing for all etypes
        #################################################################

        F.attach_grad(x1)
        F.attach_grad(x2)
        g.nodes["user"].data["h"] = x1
        g.nodes["developer"].data["h"] = x2

        with F.record_grad():
            g.update_all(mfunc("h", "m"), rfunc("m", "y"))
            r4 = g.nodes["game"].data["y"]
            r5 = g.nodes["user"].data["y"]
            r6 = g.nodes["player"].data["y"]

            loss = r4.sum() + r5.sum() + r6.sum()
            F.backward(loss)

            n_grad3 = F.grad(g.nodes["user"].data["h"])
            n_grad4 = F.grad(g.nodes["developer"].data["h"])

        assert F.allclose(r1, r4)
        assert F.allclose(r2, r5)
        assert F.allclose(r3, r6)
        assert F.allclose(n_grad1, n_grad3)
        assert F.allclose(n_grad2, n_grad4)

    for rfunc in (fn.sum, fn.max, fn.min):
        _test(fn.copy_u, rfunc)
    # TODO: fn.mean is not exercised here yet.

159

nv-dlasalle's avatar
nv-dlasalle committed
160
@parametrize_idtype
def test_unary_copy_e(idtype):
    """``multi_update_all`` and ``update_all`` must agree for ``fn.copy_e``.

    For each reducer (sum, max, min) a random edge feature ``eid`` on every
    relation of ``create_test_heterograph_large`` is copied to destination
    nodes and reduced into ``y``:

    * once with ``multi_update_all`` (per-etype, cross-type reducer named like
      the reducer);
    * once with a single ``update_all`` call.

    The aggregated ``y`` on game and user nodes, and the gradients of
    ``sum(y)`` w.r.t. each edge feature, must match between the two runs.
    """
    if dgl.backend.backend_name != "pytorch":
        # Previously this skip was triggered implicitly by an otherwise unused
        # call to the skipIf-decorated create_test_heterograph(). The body
        # calls tensor ``.clone()`` directly rather than through ``F``.
        pytest.skip("Only support PyTorch for now")

    etypes = ("plays", "follows", "develops", "wishes")

    def _attach_edge_feats(g, feats):
        """Mark every feature as requiring grad, then store it as edata "eid"."""
        for x in feats.values():
            F.attach_grad(x)
        for etype, x in feats.items():
            g[etype].edata["eid"] = x

    def _test(mfunc, rfunc):
        g = create_test_heterograph_large(idtype)
        # Reuse the reducer's name as the cross-type reducer.
        cross_reducer = rfunc.__name__

        feats = {
            etype: F.randn((g.num_edges(etype), feat_size)) for etype in etypes
        }
        _attach_edge_feats(g, feats)

        #################################################################
        #  multi_update_all(): call msg_passing separately for each etype
        #################################################################

        with F.record_grad():
            g.multi_update_all(
                {etype: (mfunc("eid", "m"), rfunc("m", "y")) for etype in etypes},
                cross_reducer,
            )
            # Clone results and grads: edge data is cleared and the inputs
            # are re-attached before the second pass.
            r1 = g.nodes["game"].data["y"].clone()
            r2 = g.nodes["user"].data["y"].clone()

            loss = r1.sum() + r2.sum()
            F.backward(loss)

            grads_multi = {
                etype: F.grad(g[etype].edata["eid"]).clone() for etype in etypes
            }

        # Plain loop: clearing edge data is a side effect, not a value to
        # collect (previously a discarded dict comprehension).
        for etype in etypes:
            g[etype].edata.clear()

        #################################################################
        #  update_all(): call msg_passing for all etypes
        #################################################################

        # TODO(Israt): output type can be None in multi_update and empty
        _attach_edge_feats(g, feats)

        with F.record_grad():
            g.update_all(mfunc("eid", "m"), rfunc("m", "y"))
            r3 = g.nodes["game"].data["y"]
            r4 = g.nodes["user"].data["y"]

            loss = r3.sum() + r4.sum()
            F.backward(loss)

            grads_all = {etype: F.grad(g[etype].edata["eid"]) for etype in etypes}

        assert F.allclose(r1, r3)
        assert F.allclose(r2, r4)
        for etype in etypes:
            assert F.allclose(grads_multi[etype], grads_all[etype]), etype

    for rfunc in (fn.sum, fn.max, fn.min):
        _test(fn.copy_e, rfunc)
    # TODO: fn.mean is not exercised here yet.

250

nv-dlasalle's avatar
nv-dlasalle committed
251
@parametrize_idtype
def test_binary_op(idtype):
    """``multi_update_all`` and ``update_all`` must agree for binary messages.

    The test iterates over every ordered pair of distinct operands taken from
    ``u``, ``v`` and ``e`` and every op in add/sub/mul/div. For each case, the
    builtin message ``fn.<lhs>_<op>_<rhs>("h", "h", "m")`` is reduced with
    ``fn.sum`` into ``y`` on the ``game`` nodes of ``create_test_heterograph``.
    This is done twice:

    * once per edge type via ``multi_update_all`` with cross-type reducer
      "sum";
    * once via a single ``update_all``.

    The forward results must match. Backward is run on both as a smoke test,
    but gradients are not compared yet (see TODO below).
    """

    def _print_error(a, b):
        """Print each flattened position where ``a`` and ``b`` differ."""
        for i, (x, y) in enumerate(
            zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())
        ):
            if not np.allclose(x, y):
                print("@{} {} v.s. {}".format(i, x, y))

    def _test(lhs, rhs, binary_op, reducer):
        g = create_test_heterograph(idtype)

        # Input feature "h" on every node type and every edge type. Sizes come
        # from the graph rather than hard-coded edge counts.
        for ntype in ("user", "developer", "game"):
            x = F.randn((g.num_nodes(ntype), feat_size))
            F.attach_grad(x)
            g.nodes[ntype].data["h"] = x
        for etype in ("plays", "follows", "develops", "wishes"):
            x = F.randn((g.num_edges(etype), feat_size))
            F.attach_grad(x)
            g[etype].edata["h"] = x

        builtin_msg_name = "{}_{}_{}".format(lhs, binary_op, rhs)
        builtin_msg = getattr(fn, builtin_msg_name)
        builtin_red = getattr(fn, reducer)

        #################################################################
        #  multi_update_all(): call msg_passing separately for each etype
        #################################################################

        with F.record_grad():
            g.multi_update_all(
                {
                    etype: (builtin_msg("h", "h", "m"), builtin_red("m", "y"))
                    for etype in g.canonical_etypes
                },
                "sum",
            )
            r1 = g.nodes["game"].data["y"]
            F.backward(r1, F.ones(r1.shape))

        #################################################################
        #  update_all(): call msg_passing for all etypes
        #################################################################

        # Recorded like the first pass, so that backward is valid on backends
        # where record_grad() actually enables gradient recording.
        with F.record_grad():
            g.update_all(builtin_msg("h", "h", "m"), builtin_red("m", "y"))
            r2 = g.nodes["game"].data["y"]
            F.backward(r2, F.ones(r2.shape))

        if not F.allclose(r1, r2):
            _print_error(r1, r2)
        assert F.allclose(r1, r2), builtin_msg_name
        # TODO(Israt): r1 and r2 have different grad functions attached, so
        # gradients are not compared yet. When enabling this, compare the
        # gradients of the *input* "h" tensors (zeroing them between passes),
        # not F.grad(r1)/F.grad(r2): r1 and r2 are computed outputs, not leaves.

    target = ["u", "v", "e"]
    # Ordered pairs of distinct operands: (u, v), (u, e), (v, u), ...
    for lhs, rhs in itertools.permutations(target, 2):
        for binary_op in ["add", "sub", "mul", "div"]:
            # TODO(Israt): Add support for reduce func "max", "min", "mean"
            for reducer in ["sum"]:
                print(lhs, rhs, binary_op, reducer)
                _test(lhs, rhs, binary_op, reducer)


337
if __name__ == "__main__":
    # Every test takes a required ``idtype`` argument (supplied by
    # ``@parametrize_idtype``, presumably a pytest parametrization), so calling
    # them directly with no arguments raises TypeError. Run this file through
    # pytest instead, which collects the same tests and supplies ``idtype``.
    raise SystemExit(pytest.main([__file__]))