"""Check that ``multi_update_all`` agrees with ``update_all`` on heterographs."""
import itertools
import unittest
from collections import Counter
from itertools import product

import backend as F

import dgl
import dgl.function as fn
import networkx as nx
import numpy as np
import pytest
import scipy.sparse as ssp
import test_utils
from dgl import DGLError
from scipy.sparse import rand
from test_utils import get_cases, parametrize_idtype

# Reducer name -> DGL built-in reduce function.
rfuncs = {"sum": fn.sum, "max": fn.max, "min": fn.min, "mean": fn.mean}

# Feature width of every random node/edge tensor created by these tests.
feat_size = 2


@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
def create_test_heterograph(idtype):
    """Build the small, fixed heterograph used throughout this file.

    NOTE: the ``skipIf`` decorator is placed on this helper rather than on a
    test. ``unittest.skipIf`` makes the decorated function raise
    ``unittest.SkipTest`` when called, so on non-PyTorch backends any test
    that calls this helper is reported as skipped.

    The graph has 3 users, 2 games and 2 developers. Its metagraph is:
        ('user', 'follows', 'user')        -- 4 edges
        ('user', 'plays', 'game')          -- 4 edges
        ('user', 'wishes', 'game')         -- 3 edges
        ('developer', 'develops', 'game')  -- 3 edges

    Args:
        idtype: dtype used for node/edge IDs.

    Returns:
        The heterograph, created on ``F.ctx()`` with the requested ``idtype``.
    """
    g = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 1, 2, 1], [0, 0, 1, 1]),
            ("user", "plays", "game"): ([0, 1, 2, 1], [0, 0, 1, 1]),
            ("user", "wishes", "game"): ([0, 1, 1], [0, 0, 1]),
            ("developer", "develops", "game"): ([0, 1, 0], [0, 1, 1]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    assert g.idtype == idtype
    assert g.device == F.ctx()
    return g


def create_test_heterograph_2(idtype):
    """Build a random heterograph with five relations, including a 'player' type.

    The edge endpoints come from the unseeded global ``np.random`` state, so
    each call produces a different graph.

    Relations and their edge endpoints:
        ('user', 'becomes', 'player'),
        ('user', 'follows', 'user'),
        ('user', 'plays', 'game')          -- the same 25 (src, dst) pairs, IDs in [0, 50)
        ('user', 'wishes', 'game')         -- 10 pairs, IDs in [0, 25)
        ('developer', 'develops', 'game')  -- 1000 pairs, IDs in [0, 100)

    Args:
        idtype: dtype used for node/edge IDs.

    Returns:
        The heterograph, created on ``F.ctx()`` with the requested ``idtype``.
    """
    src = np.random.randint(0, 50, 25)
    dst = np.random.randint(0, 50, 25)
    src1 = np.random.randint(0, 25, 10)
    dst1 = np.random.randint(0, 25, 10)
    src2 = np.random.randint(0, 100, 1000)
    dst2 = np.random.randint(0, 100, 1000)

    g = dgl.heterograph(
        {
            ("user", "becomes", "player"): (src, dst),
            ("user", "follows", "user"): (src, dst),
            ("user", "plays", "game"): (src, dst),
            ("user", "wishes", "game"): (src1, dst1),
            ("developer", "develops", "game"): (src2, dst2),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    assert g.idtype == idtype
    assert g.device == F.ctx()
    return g


def create_test_heterograph_large(idtype):
    """Build a random heterograph whose four relations share one edge list.

    All relations ('follows', 'plays', 'wishes', 'develops') use the same 2500
    (src, dst) pairs, drawn with IDs in [0, 50) from the unseeded global
    ``np.random`` state.

    Args:
        idtype: dtype used for node/edge IDs.

    Returns:
        The heterograph, created on ``F.ctx()`` with the requested ``idtype``.
    """
    src = np.random.randint(0, 50, 2500)
    dst = np.random.randint(0, 50, 2500)

    g = dgl.heterograph(
        {
            ("user", "follows", "user"): (src, dst),
            ("user", "plays", "game"): (src, dst),
            ("user", "wishes", "game"): (src, dst),
            ("developer", "develops", "game"): (src, dst),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    assert g.idtype == idtype
    assert g.device == F.ctx()
    return g


@parametrize_idtype
def test_unary_copy_u(idtype):
    """Check ``multi_update_all`` against ``update_all`` for ``copy_u``.

    Node feature ``h`` on 'user' and 'developer' nodes is propagated in two ways:
    once per edge type with ``multi_update_all``, using the reducer's name as
    the cross-type reducer, and once with ``update_all``. The resulting ``y``
    on 'game', 'user' and 'player' nodes must match between the two passes, as
    must the gradients w.r.t. the input features.
    """

    def _test(mfunc, rfunc):
        g = create_test_heterograph_2(idtype)
        # Called only for its side effect: this helper is wrapped in
        # ``unittest.skipIf`` and raises SkipTest on non-PyTorch backends,
        # which skips this test there.
        create_test_heterograph(idtype)
        # The reducer's name (e.g. "sum") doubles as the cross-etype reducer.
        cross_reducer = rfunc.__name__

        x1 = F.randn((g.num_nodes("user"), feat_size))
        x2 = F.randn((g.num_nodes("developer"), feat_size))
        F.attach_grad(x1)
        F.attach_grad(x2)
        g.nodes["user"].data["h"] = x1
        g.nodes["developer"].data["h"] = x2

        #################################################################
        #  multi_update_all(): call msg_passing separately for each etype
        #################################################################

        with F.record_grad():
            g.multi_update_all(
                {
                    etype: (mfunc("h", "m"), rfunc("m", "y"))
                    for etype in g.canonical_etypes
                },
                cross_reducer,
            )
            r1 = g.nodes["game"].data["y"].clone()
            r2 = g.nodes["user"].data["y"].clone()
            r3 = g.nodes["player"].data["y"].clone()
            loss = r1.sum() + r2.sum() + r3.sum()
            F.backward(loss)
            # Clone so the second pass over the same leaf tensors cannot
            # change these reference gradients.
            n_grad1 = F.grad(g.nodes["user"].data["h"]).clone()
            n_grad2 = F.grad(g.nodes["developer"].data["h"]).clone()

        # Drop every node feature (including "y") before the second pass.
        for ntype in g.ntypes:
            g.nodes[ntype].data.clear()

        #################################################################
        #  update_all(): call msg_passing for all etypes
        #################################################################

        F.attach_grad(x1)
        F.attach_grad(x2)
        g.nodes["user"].data["h"] = x1
        g.nodes["developer"].data["h"] = x2

        with F.record_grad():
            g.update_all(mfunc("h", "m"), rfunc("m", "y"))
            r4 = g.nodes["game"].data["y"]
            r5 = g.nodes["user"].data["y"]
            r6 = g.nodes["player"].data["y"]
            loss = r4.sum() + r5.sum() + r6.sum()
            F.backward(loss)
            n_grad3 = F.grad(g.nodes["user"].data["h"])
            n_grad4 = F.grad(g.nodes["developer"].data["h"])

        assert F.allclose(r1, r4)
        assert F.allclose(r2, r5)
        assert F.allclose(r3, r6)
        assert F.allclose(n_grad1, n_grad3)
        assert F.allclose(n_grad2, n_grad4)

    for reducer in (fn.sum, fn.max, fn.min):
        _test(fn.copy_u, reducer)
    # The mean reducer is currently disabled: _test(fn.copy_u, fn.mean)


@parametrize_idtype
def test_unary_copy_e(idtype):
    """Check ``multi_update_all`` against ``update_all`` for ``copy_e``.

    Edge feature ``eid`` on every relation is propagated in two ways: once per
    edge type with ``multi_update_all``, using the reducer's name as the
    cross-type reducer, and once with ``update_all``. The resulting ``y`` on
    'game' and 'user' nodes must match between the two passes, as must the
    per-relation gradients w.r.t. the edge features.
    """

    def _test(mfunc, rfunc):
        g = create_test_heterograph_large(idtype)
        # Called only for its side effect: this helper is wrapped in
        # ``unittest.skipIf`` and raises SkipTest on non-PyTorch backends,
        # which skips this test there.
        create_test_heterograph(idtype)
        # The reducer's name (e.g. "sum") doubles as the cross-etype reducer.
        cross_reducer = rfunc.__name__

        # Every relation of the large test graph, in a fixed order.
        etypes = ["plays", "follows", "develops", "wishes"]
        feats = {
            etype: F.randn((g.num_edges(etype), feat_size)) for etype in etypes
        }

        def _attach_edge_feats():
            # (Re)enable gradient tracking and store the features as "eid".
            for x in feats.values():
                F.attach_grad(x)
            for etype, x in feats.items():
                g[etype].edata["eid"] = x

        _attach_edge_feats()

        #################################################################
        #  multi_update_all(): call msg_passing separately for each etype
        #################################################################

        with F.record_grad():
            g.multi_update_all(
                {etype: (mfunc("eid", "m"), rfunc("m", "y")) for etype in etypes},
                cross_reducer,
            )
            r1 = g.nodes["game"].data["y"].clone()
            r2 = g.nodes["user"].data["y"].clone()
            loss = r1.sum() + r2.sum()
            F.backward(loss)
            # Clone so the second pass over the same leaf tensors cannot
            # change these reference gradients.
            grads_multi = {
                etype: F.grad(g[etype].edata["eid"]).clone() for etype in etypes
            }

        # A plain loop: the original built a throwaway dict only for the
        # side effect of clearing each relation's edge data.
        for _, etype, _ in g.canonical_etypes:
            g[etype].edata.clear()

        #################################################################
        #  update_all(): call msg_passing for all etypes
        #################################################################

        # TODO(Israt): output type can be None in multi_update and empty
        _attach_edge_feats()

        with F.record_grad():
            g.update_all(mfunc("eid", "m"), rfunc("m", "y"))
            r3 = g.nodes["game"].data["y"]
            r4 = g.nodes["user"].data["y"]
            loss = r3.sum() + r4.sum()
            F.backward(loss)
            grads_all = {etype: F.grad(g[etype].edata["eid"]) for etype in etypes}

        assert F.allclose(r1, r3)
        assert F.allclose(r2, r4)
        for etype in etypes:
            assert F.allclose(grads_multi[etype], grads_all[etype])

    for reducer in (fn.sum, fn.max, fn.min):
        _test(fn.copy_e, reducer)
    # The mean reducer is currently disabled: _test(fn.copy_e, fn.mean)


@parametrize_idtype
def test_binary_op(idtype):
    """Check ``multi_update_all`` against ``update_all`` for binary built-ins.

    The test covers every ``<lhs>_<op>_<rhs>`` message function whose two
    operands are distinct members of {u, v, e}, with op in
    {add, sub, mul, div}, using the ``sum`` reducer. Only the forward 'game'
    outputs are compared. The backward pass is run as a smoke test, but its
    gradients are not compared (see the TODO below).
    """

    def _test(lhs, rhs, binary_op, reducer):
        g = create_test_heterograph(idtype)

        # Node feature "h" on every node type.
        node_feats = {
            ntype: F.randn((g.num_nodes(ntype), feat_size))
            for ntype in ("user", "developer", "game")
        }
        for x in node_feats.values():
            F.attach_grad(x)
        for ntype, x in node_feats.items():
            g.nodes[ntype].data["h"] = x

        # Edge feature "h" on every relation. Sizes come from the graph
        # instead of hard-coded edge counts.
        edge_feats = {
            etype: F.randn((g.num_edges(etype), feat_size))
            for etype in ("plays", "follows", "develops", "wishes")
        }
        for x in edge_feats.values():
            F.attach_grad(x)
        for etype, x in edge_feats.items():
            g[etype].edata["h"] = x

        builtin_msg = getattr(fn, "{}_{}_{}".format(lhs, binary_op, rhs))
        builtin_red = getattr(fn, reducer)

        #################################################################
        #  multi_update_all(): call msg_passing separately for each etype
        #################################################################

        with F.record_grad():
            g.multi_update_all(
                {
                    etype: (builtin_msg("h", "h", "m"), builtin_red("m", "y"))
                    for etype in g.canonical_etypes
                },
                "sum",
            )
            r1 = g.nodes["game"].data["y"]
            F.backward(r1, F.ones(r1.shape))

        #################################################################
        #  update_all(): call msg_passing for all etypes
        #################################################################

        with F.record_grad():
            g.update_all(builtin_msg("h", "h", "m"), builtin_red("m", "y"))
            r2 = g.nodes["game"].data["y"]
            F.backward(r2, F.ones(r2.shape))

        # correctness check
        def _print_error(a, b):
            # Print each element pair that differs, to help diagnose failures.
            for i, (x, y) in enumerate(
                zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())
            ):
                if not np.allclose(x, y):
                    print("@{} {} v.s. {}".format(i, x, y))

        if not F.allclose(r1, r2):
            _print_error(r1, r2)
        assert F.allclose(r1, r2)
        # TODO(Israt): gradients are not compared yet because r1 and r2 have
        # different grad functions associated with them.

    # permutations(..., 2) yields every ordered pair of distinct operands.
    for lhs, rhs in itertools.permutations(["u", "v", "e"], 2):
        for binary_op in ["add", "sub", "mul", "div"]:
            # TODO(Israt): Add support for reduce func "max", "min", "mean"
            for reducer in ["sum"]:
                print(lhs, rhs, binary_op, reducer)
                _test(lhs, rhs, binary_op, reducer)


if __name__ == "__main__":
    # Each test needs an `idtype` argument. Under pytest it is presumably
    # supplied by @parametrize_idtype; when run as a script, pass each ID
    # dtype explicitly.
    for idtype in (F.int32, F.int64):
        test_unary_copy_u(idtype)
        test_unary_copy_e(idtype)
        test_binary_op(idtype)