# test_heterograph-apply-edges.py
1
import itertools
2
3
import unittest
from collections import Counter
4
from itertools import product
5

6
7
import backend as F
import networkx as nx
8
9
10
import numpy as np
import pytest
import scipy.sparse as ssp
11
12
import test_utils
from scipy.sparse import rand
13
14
15
16
17
from test_utils import get_cases, parametrize_idtype

import dgl
import dgl.function as fn
from dgl import DGLError
18

19
20
# Reducer name -> DGL builtin reduce function (not referenced by the tests
# visible in this module).
rfuncs = dict(sum=fn.sum, max=fn.max, min=fn.min, mean=fn.mean)
# Per-reducer fill value; only "sum" and "max" have entries.
fill_value = dict(sum=0, max=float("-inf"))

# Width of the random node/edge feature tensors created by the tests.
feat_size = 2


24
25
26
# NOTE: the skip decorator sits on this shared helper rather than on each test.
# unittest.skipIf replaces the function with one that raises unittest.SkipTest
# when called, which pytest reports as a skip, so every test that builds its
# graph here is skipped on non-PyTorch backends (the tests use PyTorch-only
# calls such as ``Tensor.view``).
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
def create_test_heterograph(idtype):
    """Build the small heterograph shared by the apply_edges tests.

    This is the heterograph from the docstring plus a user -- wishes -- game
    relation: 3 ``user``, 2 ``game`` and 2 ``developer`` nodes.

    Metagraph (canonical edge types, with their edge counts)::

        ('user', 'follows', 'user')          4 edges
        ('user', 'plays', 'game')            4 edges
        ('user', 'wishes', 'game')           3 edges
        ('developer', 'develops', 'game')    3 edges

    Parameters
    ----------
    idtype
        Backend integer dtype used for node/edge IDs.

    Returns
    -------
    DGLGraph
        The heterograph, placed on ``F.ctx()``.
    """
    g = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 1, 2, 1], [0, 0, 1, 1]),
            ("user", "plays", "game"): ([0, 1, 2, 1], [0, 0, 1, 1]),
            ("user", "wishes", "game"): ([0, 1, 1], [0, 0, 1]),
            ("developer", "develops", "game"): ([0, 1, 0], [0, 1, 1]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    assert g.idtype == idtype
    assert g.device == F.ctx()
    return g


@parametrize_idtype
def test_unary_copy_u(idtype):
    """Per-relation and whole-graph ``apply_edges`` with ``copy_u`` must agree.

    Compares the ``plays`` edge messages, and the gradient they induce on the
    ``user`` node features, between the two calling styles.
    """

    def _print_error(a, b):
        # Report every element where the two tensors disagree.
        for i, (x, y) in enumerate(
            zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())
        ):
            if not np.allclose(x, y):
                print("@{} {} v.s. {}".format(i, x, y))

    def _fresh_leaf(feat):
        # A new leaf tensor per pass. If both passes shared one leaf, the
        # second backward would accumulate into the first pass's ``.grad``
        # tensor in place, and the gradient comparison would compare a tensor
        # with itself.
        leaf = F.clone(feat)
        F.attach_grad(leaf)
        return leaf

    def _test(mfunc):
        # Probe graph, used only to size the random features shared by both passes.
        probe = create_test_heterograph(idtype)
        user_feat = F.randn((probe.num_nodes("user"), feat_size))
        dev_feat = F.randn((probe.num_nodes("developer"), feat_size))

        def _run(per_relation):
            # Fresh graph per pass so no edge data from an earlier pass can
            # be mistaken for this pass's output.
            g = create_test_heterograph(idtype)
            user_h = _fresh_leaf(user_feat)
            g.nodes["user"].data["h"] = user_h
            g.nodes["developer"].data["h"] = _fresh_leaf(dev_feat)
            with F.record_grad():
                if per_relation:
                    # apply_edges() is called on each relation type separately.
                    for rel in g.canonical_etypes:
                        g.apply_edges(mfunc("h", "m"), etype=rel)
                else:
                    # apply_edges() is called on all relation types at once.
                    g.apply_edges(mfunc("h", "m"))
                out = g["plays"].edata["m"]
                F.backward(out, F.ones(out.shape))
            return out, F.grad(user_h)

        r1, n_grad1 = _run(per_relation=True)
        r2, n_grad2 = _run(per_relation=False)

        # correctness check
        if not F.allclose(r1, r2):
            _print_error(r1, r2)
        assert F.allclose(r1, r2)
        if not F.allclose(n_grad1, n_grad2):
            print("node grad")
            _print_error(n_grad1, n_grad2)
        assert F.allclose(n_grad1, n_grad2)

    _test(fn.copy_u)
106
107


@parametrize_idtype
def test_unary_copy_e(idtype):
    """Per-relation and whole-graph ``apply_edges`` with ``copy_e`` must agree.

    Compares the ``develops`` edge messages, and the gradient they induce on
    the ``develops`` edge features, between the two calling styles.
    """

    def _print_error(a, b):
        # Report every element where the two tensors disagree.
        for i, (x, y) in enumerate(
            zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())
        ):
            if not np.allclose(x, y):
                print("@{} {} v.s. {}".format(i, x, y))

    def _fresh_leaf(feat):
        # A new leaf tensor per pass, so the two passes never share (and
        # accumulate into) the same ``.grad`` tensor.
        leaf = F.clone(feat)
        F.attach_grad(leaf)
        return leaf

    def _test(mfunc):
        # Random "eid" edge features, one tensor per relation, sized from a
        # probe graph and shared (by value) between both passes.
        probe = create_test_heterograph(idtype)
        edge_feats = {
            etype: F.randn((probe.num_edges(etype), feat_size))
            for etype in ("plays", "follows", "develops", "wishes")
        }

        def _run(per_relation):
            # Fresh graph per pass so no stale "m" data can survive.
            g = create_test_heterograph(idtype)
            leaves = {}
            for etype, feat in edge_feats.items():
                leaves[etype] = _fresh_leaf(feat)
                g[etype].edata["eid"] = leaves[etype]
            with F.record_grad():
                if per_relation:
                    # apply_edges() is called on each relation type separately.
                    for rel in g.canonical_etypes:
                        g.apply_edges(mfunc("eid", "m"), etype=rel)
                else:
                    # apply_edges() is called on all relation types at once.
                    g.apply_edges(mfunc("eid", "m"))
                out = g["develops"].edata["m"]
                F.backward(out, F.ones(out.shape))
            return out, F.grad(leaves["develops"])

        r1, e_grad1 = _run(per_relation=True)
        r2, e_grad2 = _run(per_relation=False)

        # correctness check
        if not F.allclose(r1, r2):
            _print_error(r1, r2)
        assert F.allclose(r1, r2)
        if not F.allclose(e_grad1, e_grad2):
            print("edge grad")
            _print_error(e_grad1, e_grad2)
        assert F.allclose(e_grad1, e_grad2)

    _test(fn.copy_e)


@parametrize_idtype
def test_binary_op(idtype):
    """Per-relation and whole-graph ``apply_edges`` must agree for every
    builtin binary message ``{lhs}_{op}_{rhs}`` over distinct targets
    ``u``/``v``/``e``.

    Compares the ``plays`` edge messages, and the gradient on the ``game``
    node features, between the two calling styles.
    """

    def _print_error(a, b):
        # Report every element where the two tensors disagree.
        for i, (x, y) in enumerate(
            zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())
        ):
            if not np.allclose(x, y):
                print("@{} {} v.s. {}".format(i, x, y))

    def _fresh_leaf(feat):
        # A new leaf tensor per pass. Re-attaching grad to one shared tensor
        # resets its ``.grad`` in place, so a gradient captured from the first
        # pass would be the very tensor the second pass writes into.
        leaf = F.clone(feat)
        F.attach_grad(leaf)
        return leaf

    def _test(lhs, rhs, binary_op):
        # e.g. ("u", "add", "v") -> fn.u_add_v
        builtin_msg = getattr(fn, "{}_{}_{}".format(lhs, binary_op, rhs))

        # Random features sized from a probe graph, shared (by value) by both passes.
        probe = create_test_heterograph(idtype)
        node_feats = {
            ntype: F.randn((probe.num_nodes(ntype), feat_size))
            for ntype in ("user", "developer", "game")
        }
        edge_feats = {
            etype: F.randn((probe.num_edges(etype), feat_size))
            for etype in ("plays", "follows", "develops", "wishes")
        }

        def _run(per_relation):
            # Fresh graph per pass so no stale "m" data can survive.
            g = create_test_heterograph(idtype)
            node_leaves = {}
            for ntype, feat in node_feats.items():
                node_leaves[ntype] = _fresh_leaf(feat)
                g.nodes[ntype].data["h"] = node_leaves[ntype]
            for etype, feat in edge_feats.items():
                g[etype].edata["h"] = _fresh_leaf(feat)
            with F.record_grad():
                if per_relation:
                    # apply_edges() is called on each relation type separately.
                    for rel in g.canonical_etypes:
                        g.apply_edges(builtin_msg("h", "h", "m"), etype=rel)
                else:
                    # apply_edges() is called on all relation types at once.
                    g.apply_edges(builtin_msg("h", "h", "m"))
                out = g["plays"].edata["m"]
                F.backward(F.sum(out.view(-1), 0))
            # May be None when the message does not read the game features.
            return out, F.grad(node_leaves["game"])

        r1, n_grad1 = _run(per_relation=True)
        r2, n_grad2 = _run(per_relation=False)

        # correctness check
        if not F.allclose(r1, r2):
            _print_error(r1, r2)
        assert F.allclose(r1, r2)
        # The game features must receive a gradient in both passes or in neither.
        assert (n_grad1 is None) == (n_grad2 is None)
        if n_grad1 is not None:
            if not F.allclose(n_grad1, n_grad2):
                print("node grad")
                _print_error(n_grad1, n_grad2)
            assert F.allclose(n_grad1, n_grad2)

    # Every ordered pair of distinct targets, e.g. (u, v), (v, u), (u, e), ...
    for lhs, rhs in itertools.permutations(["u", "v", "e"], 2):
        for binary_op in ["add", "sub", "mul", "div", "dot"]:
            print(lhs, rhs, binary_op)
            _test(lhs, rhs, binary_op)
264
265


266
if __name__ == "__main__":
    # Each test takes the graph's ID dtype as its argument (pytest supplies it
    # through ``parametrize_idtype``); when run as a script, exercise both
    # 32- and 64-bit IDs explicitly.
    for _idtype in (F.int32, F.int64):
        test_unary_copy_u(_idtype)
        test_unary_copy_e(_idtype)
        test_binary_op(_idtype)