import itertools
import unittest
from collections import Counter
from itertools import product

import backend as F

import dgl
import dgl.function as fn
import networkx as nx
import numpy as np
import pytest
import scipy.sparse as ssp
from dgl import DGLError
from scipy.sparse import rand
from utils import get_cases, parametrize_idtype

# Builtin reducer functions, looked up by name (not referenced by the tests
# in this module).
rfuncs = dict(sum=fn.sum, max=fn.max, min=fn.min, mean=fn.mean)
# Per-reducer fill values. NOTE(review): presumably the expected result for
# nodes that receive no messages; unused here, confirm before relying on it.
fill_value = dict(sum=0, max=float("-inf"))
# Width of the random node/edge features the tests create.
feat_size = 2


@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
def create_test_heterograph(idtype):
    """Build the small heterograph that every test in this module uses.

    Node counts: 3 ``user``, 2 ``game``, 2 ``developer``.
    Relations (edge counts in parentheses)::

        ('user', 'follows', 'user')        (4)
        ('user', 'plays', 'game')          (4)
        ('user', 'wishes', 'game')         (3)
        ('developer', 'develops', 'game')  (3)

    The ``unittest.skipIf`` decorator is on this helper rather than on the
    tests. Under a non-PyTorch backend, calling the helper raises
    ``unittest.SkipTest``, which pytest reports as a skip of the calling test.

    Parameters
    ----------
    idtype
        ID dtype for the graph structure.

    Returns
    -------
    DGLGraph
        Heterograph with the requested ``idtype``, placed on ``F.ctx()``.
    """
    relations = {}
    relations[("user", "follows", "user")] = ([0, 1, 2, 1], [0, 0, 1, 1])
    relations[("user", "plays", "game")] = ([0, 1, 2, 1], [0, 0, 1, 1])
    relations[("user", "wishes", "game")] = ([0, 1, 1], [0, 0, 1])
    relations[("developer", "develops", "game")] = ([0, 1, 0], [0, 1, 1])

    graph = dgl.heterograph(relations, idtype=idtype, device=F.ctx())
    # Sanity-check that construction honoured the requested dtype/device.
    assert graph.idtype == idtype
    assert graph.device == F.ctx()
    return graph


@parametrize_idtype
def test_unary_copy_u(idtype):
    """Per-relation and all-relation ``apply_edges`` with ``copy_u`` must agree.

    Runs ``copy_u`` two ways: once relation by relation, and once over all
    relations in a single call. Checks that the ``plays`` edge outputs match,
    and that the gradients they send back to the ``user`` features match.
    """

    def _test(mfunc):
        g = create_test_heterograph(idtype)

        # Place features on F.ctx(), the device the graph was built on.
        x1 = F.copy_to(F.randn((g.num_nodes("user"), feat_size)), F.ctx())
        x2 = F.copy_to(F.randn((g.num_nodes("developer"), feat_size)), F.ctx())
        F.attach_grad(x1)
        F.attach_grad(x2)
        g.nodes["user"].data["h"] = x1
        g.nodes["developer"].data["h"] = x2

        #################################################################
        #  apply_edges() is called on each relation type separately
        #################################################################
        with F.record_grad():
            for rel in g.canonical_etypes:
                g.apply_edges(mfunc("h", "m"), etype=rel)
            r1 = g["plays"].edata["m"]
            F.backward(F.sum(r1.view(-1), 0))
        # Copy the gradient. PyTorch accumulates into x1.grad in place, so a
        # bare reference would be the same tensor read as n_grad2 below, and
        # the gradient comparison could never fail.
        n_grad1 = F.clone(F.grad(x1))

        # Remove the first pass's outputs so the second pass must recompute
        # them. Clearing the dict returned by g.edata["m"] does not remove
        # the stored features.
        for rel in g.canonical_etypes:
            g.edges[rel].data.pop("m", None)
        # Reset x1's gradient before the second backward pass.
        # NOTE(review): assumes F.attach_grad zeroes an existing .grad.
        F.attach_grad(x1)

        #################################################################
        #  apply_edges() is called on all relation types
        #################################################################
        with F.record_grad():
            g.apply_edges(mfunc("h", "m"))
            r2 = g["plays"].edata["m"]
            F.backward(F.sum(r2.view(-1), 0))
        n_grad2 = F.grad(x1)

        # correctness check
        def _print_error(a, b):
            """Print every flattened position where ``a`` and ``b`` differ."""
            for i, (x, y) in enumerate(
                zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())
            ):
                if not np.allclose(x, y):
                    print("@{} {} v.s. {}".format(i, x, y))

        if not F.allclose(r1, r2):
            _print_error(r1, r2)
        assert F.allclose(r1, r2)
        if not F.allclose(n_grad1, n_grad2):
            print("node grad")
            _print_error(n_grad1, n_grad2)
        assert F.allclose(n_grad1, n_grad2)

    _test(fn.copy_u)


@parametrize_idtype
def test_unary_copy_e(idtype):
    """Per-relation and all-relation ``apply_edges`` with ``copy_e`` must agree.

    Runs ``copy_e`` two ways: once relation by relation, and once over all
    relations in a single call. Checks that the ``develops`` edge outputs
    match, and that the gradients reaching the ``develops`` edge features
    match.
    """

    def _test(mfunc):
        g = create_test_heterograph(idtype)

        # One random feature row per edge, sized from the graph rather than
        # hard-coded, and placed on F.ctx() to match the graph's device.
        x1 = F.copy_to(F.randn((g.num_edges("plays"), feat_size)), F.ctx())
        x2 = F.copy_to(F.randn((g.num_edges("follows"), feat_size)), F.ctx())
        x3 = F.copy_to(F.randn((g.num_edges("develops"), feat_size)), F.ctx())
        x4 = F.copy_to(F.randn((g.num_edges("wishes"), feat_size)), F.ctx())
        for x in (x1, x2, x3, x4):
            F.attach_grad(x)

        g["plays"].edata["eid"] = x1
        g["follows"].edata["eid"] = x2
        g["develops"].edata["eid"] = x3
        g["wishes"].edata["eid"] = x4

        #################################################################
        #  apply_edges() is called on each relation type separately
        #################################################################
        with F.record_grad():
            for rel in g.canonical_etypes:
                g.apply_edges(mfunc("eid", "m"), etype=rel)
            r1 = g["develops"].edata["m"]
            F.backward(F.sum(r1.view(-1), 0))
        # Copy the gradient. PyTorch accumulates into x3.grad in place, so a
        # bare reference would alias e_grad2 and the check would be vacuous.
        e_grad1 = F.clone(F.grad(x3))

        # Remove the first pass's outputs and reset x3's gradient.
        # NOTE(review): assumes F.attach_grad zeroes an existing .grad.
        for rel in g.canonical_etypes:
            g.edges[rel].data.pop("m", None)
        F.attach_grad(x3)

        #################################################################
        #  apply_edges() is called on all relation types
        #################################################################
        with F.record_grad():
            g.apply_edges(mfunc("eid", "m"))
            r2 = g["develops"].edata["m"]
            F.backward(F.sum(r2.view(-1), 0))
        e_grad2 = F.grad(x3)

        # correctness check
        def _print_error(a, b):
            """Print every flattened position where ``a`` and ``b`` differ."""
            for i, (x, y) in enumerate(
                zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())
            ):
                if not np.allclose(x, y):
                    print("@{} {} v.s. {}".format(i, x, y))

        if not F.allclose(r1, r2):
            _print_error(r1, r2)
        assert F.allclose(r1, r2)
        if not F.allclose(e_grad1, e_grad2):
            print("edge grad")
            _print_error(e_grad1, e_grad2)
        assert F.allclose(e_grad1, e_grad2)

    _test(fn.copy_e)


@parametrize_idtype
def test_binary_op(idtype):
    """Per-relation and all-relation ``apply_edges`` with binary builtins agree.

    Covers every ordered pair of distinct operands from ``u``/``v``/``e`` and
    every operator in add/sub/mul/div/dot. For each combination, the builtin
    ``fn.<lhs>_<op>_<rhs>`` is applied two ways: relation by relation, and to
    all relations at once. The ``plays`` outputs must match, and so must the
    gradients reaching the ``game`` node features.
    """

    def _test(lhs, rhs, binary_op):
        g = create_test_heterograph(idtype)

        # Random "h" features for every node and edge type, placed on
        # F.ctx() to match the graph's device.
        node_feats = {
            ntype: F.copy_to(F.randn((g.num_nodes(ntype), feat_size)), F.ctx())
            for ntype in ("user", "developer", "game")
        }
        edge_feats = {
            etype: F.copy_to(F.randn((g.num_edges(etype), feat_size)), F.ctx())
            for etype in ("plays", "follows", "develops", "wishes")
        }
        all_feats = list(
            itertools.chain(node_feats.values(), edge_feats.values())
        )

        def _reset_grads():
            # Enable gradients on every input feature, resetting any gradient
            # left over from a previous backward pass.
            # NOTE(review): assumes F.attach_grad zeroes an existing .grad.
            for feat in all_feats:
                F.attach_grad(feat)

        _reset_grads()
        for ntype, feat in node_feats.items():
            g.nodes[ntype].data["h"] = feat
        for etype, feat in edge_feats.items():
            g[etype].edata["h"] = feat

        builtin_msg = getattr(fn, "{}_{}_{}".format(lhs, binary_op, rhs))
        game_h = node_feats["game"]

        #################################################################
        #  apply_edges() is called on each relation type separately
        #################################################################
        with F.record_grad():
            for rel in g.canonical_etypes:
                g.apply_edges(builtin_msg("h", "h", "m"), etype=rel)
            r1 = g["plays"].edata["m"]
            F.backward(F.sum(r1.view(-1), 0))
        n_grad1 = F.grad(game_h)
        # Copy the gradient. PyTorch accumulates into .grad in place, and
        # _reset_grads() below zeroes it, so a bare reference would alias
        # n_grad2 and make the gradient comparison vacuous.
        if n_grad1 is not None:
            n_grad1 = F.clone(n_grad1)

        # Drop the first-pass outputs and gradients before the second pass.
        for rel in g.canonical_etypes:
            g.edges[rel].data.pop("m", None)
        _reset_grads()

        #################################################################
        #  apply_edges() is called on all relation types
        #################################################################
        with F.record_grad():
            g.apply_edges(builtin_msg("h", "h", "m"))
            r2 = g["plays"].edata["m"]
            F.backward(F.sum(r2.view(-1), 0))
        n_grad2 = F.grad(game_h)

        # correctness check
        def _print_error(a, b):
            """Print every flattened position where ``a`` and ``b`` differ."""
            for i, (x, y) in enumerate(
                zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())
            ):
                if not np.allclose(x, y):
                    print("@{} {} v.s. {}".format(i, x, y))

        if not F.allclose(r1, r2):
            _print_error(r1, r2)
        assert F.allclose(r1, r2)
        # "game" is the destination type of "plays", so the game features get
        # a gradient only when the operator reads "v". Both passes must agree
        # on whether a gradient exists before the values are compared.
        assert (n_grad1 is None) == (n_grad2 is None)
        if n_grad1 is not None:
            if not F.allclose(n_grad1, n_grad2):
                print("node grad")
                _print_error(n_grad1, n_grad2)
            assert F.allclose(n_grad1, n_grad2)

    target = ["u", "v", "e"]
    # Ordered pairs of distinct operands: (u,v), (u,e), (v,u), (v,e), (e,u), (e,v).
    for lhs, rhs in itertools.permutations(target, 2):
        for binary_op in ["add", "sub", "mul", "div", "dot"]:
            print(lhs, rhs, binary_op)
            _test(lhs, rhs, binary_op)
262


if __name__ == "__main__":
    # Each test takes an ``idtype`` argument, which ``parametrize_idtype``
    # normally injects under pytest, so it must be passed explicitly here.
    # NOTE(review): int32/int64 presumably mirror the dtypes that
    # parametrize_idtype uses; confirm.
    for idtype in (F.int32, F.int64):
        test_unary_copy_u(idtype)
        test_unary_copy_e(idtype)
        test_binary_op(idtype)