# test_heterograph-apply-edges.py: tests for apply_edges() on DGL heterographs.
1
import itertools
2
3
import unittest
from collections import Counter
4
from itertools import product
5

6
import backend as F
7
8
9

import dgl
import dgl.function as fn
10
import networkx as nx
11
12
import numpy as np
import pytest
13
14
15
import scipy.sparse as spsp
import torch

16
from dgl import DGLError
17
from scipy.sparse import rand
18
from utils import get_cases, parametrize_idtype
19
20
21

# Builtin reducers keyed by name, and (for "sum" and "max" only) a matching
# fill value -- presumably the value for nodes that receive no message.
# NOTE(review): neither table is referenced by the tests in this file.
rfuncs = dict(sum=fn.sum, max=fn.max, min=fn.min, mean=fn.mean)
fill_value = dict(sum=0, max=float("-inf"))
# Width of the random node/edge features built by the tests below.
feat_size = 2


@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
def create_test_heterograph(idtype):
    """Build the small fixture heterograph shared by the apply_edges tests.

    This is the heterograph from the DGL docstring plus a
    user -- wishes -- game relation: 3 users, 2 games, 2 developers, and
    the metagraph::

        ("user", "follows", "user")        4 edges
        ("user", "plays", "game")          4 edges
        ("user", "wishes", "game")         3 edges
        ("developer", "develops", "game")  3 edges

    The ``skipIf`` decorator sits on this helper rather than on a test.
    On non-PyTorch backends, calling the helper raises
    ``unittest.SkipTest``, so every test that builds this fixture is
    skipped.

    Parameters
    ----------
    idtype :
        Integer ID type for the graph, as supplied by ``parametrize_idtype``.

    Returns
    -------
    DGLGraph
        The heterograph, created on ``F.ctx()``.
    """
    g = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 1, 2, 1], [0, 0, 1, 1]),
            ("user", "plays", "game"): ([0, 1, 2, 1], [0, 0, 1, 1]),
            ("user", "wishes", "game"): ([0, 1, 1], [0, 0, 1]),
            ("developer", "develops", "game"): ([0, 1, 0], [0, 1, 1]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    # Sanity-check that the requested ID type and device were honoured.
    assert g.idtype == idtype
    assert g.device == F.ctx()
    return g


def create_random_hetero_with_single_source_node_type(idtype):
    """Build a heterograph in which every relation starts at node type "n1".

    Each relation is complete bipartite. ``spsp.random`` with ``density=1``
    produces an entry for every (src, dst) pair. Only the coordinates of
    those entries are kept as edges; the random values are discarded.

    Parameters
    ----------
    idtype :
        ID type of the returned graph.

    Returns
    -------
    DGLGraph
        Heterograph with node types n1 (5 nodes), n2 (10) and n3 (15) and
        relations r1: n1->n2, r2: n1->n3, r3: n1->n2, created on ``F.ctx()``.
    """
    node_counts = {"n1": 5, "n2": 10, "n3": 15}
    relations = [("n1", "r1", "n2"), ("n1", "r2", "n3"), ("n1", "r3", "n2")]

    def _all_pairs(src_type, dst_type):
        # A fixed seed for every call keeps the fixture deterministic.
        coo = spsp.random(
            node_counts[src_type],
            node_counts[dst_type],
            density=1,
            format="coo",
            random_state=100,
        )
        return coo.row, coo.col

    data_dict = {rel: _all_pairs(rel[0], rel[2]) for rel in relations}
    return dgl.heterograph(data_dict, idtype=idtype, device=F.ctx())


@parametrize_idtype
def test_unary_copy_u(idtype):
    """Compare per-relation and whole-graph ``apply_edges`` for ``copy_u``.

    Two results must be identical whether ``apply_edges`` is called once per
    canonical edge type or once on the whole heterograph:

    - the "plays" messages;
    - the gradient w.r.t. the "user" node features.
    """

    def _print_error(a, b):
        # Report each flattened position at which ``a`` and ``b`` disagree.
        for i, (x, y) in enumerate(
            zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())
        ):
            if not np.allclose(x, y):
                print("@{} {} v.s. {}".format(i, x, y))

    def _test(mfunc):
        g = create_test_heterograph(idtype)

        x1 = F.randn((g.num_nodes("user"), feat_size))
        x2 = F.randn((g.num_nodes("developer"), feat_size))

        F.attach_grad(x1)
        F.attach_grad(x2)

        g.nodes["user"].data["h"] = x1
        g.nodes["developer"].data["h"] = x2

        #################################################################
        #  apply_edges() is called on each relation type separately
        #################################################################

        with F.record_grad():
            for rel in g.canonical_etypes:
                g.apply_edges(mfunc("h", "m"), etype=rel)
            r1 = g["plays"].edata["m"]
            F.backward(r1, F.ones(r1.shape))
            # Keep a private copy. PyTorch accumulates into an existing
            # ``.grad`` in place (and F.attach_grad resets it), so a bare
            # reference would alias n_grad2 and the check below would
            # always pass.
            n_grad1 = F.grad(g.ndata["h"]["user"]).clone()

        # Remove the per-relation results so the second pass must recompute
        # them. Calling ``.clear()`` on ``g.edata["m"]`` did not remove the
        # data from the graph.
        for rel in g.canonical_etypes:
            del g.edges[rel].data["m"]

        #################################################################
        #  apply_edges() is called on all relation types
        #################################################################

        # Reset the gradients accumulated by the first backward pass.
        F.attach_grad(x1)
        F.attach_grad(x2)

        with F.record_grad():
            g.apply_edges(mfunc("h", "m"))
            r2 = g["plays"].edata["m"]
            F.backward(r2, F.ones(r2.shape))
            n_grad2 = F.grad(g.nodes["user"].data["h"])

        # correctness check
        if not F.allclose(r1, r2):
            _print_error(r1, r2)
        assert F.allclose(r1, r2)
        if not F.allclose(n_grad1, n_grad2):
            print("node grad")
            _print_error(n_grad1, n_grad2)
        assert F.allclose(n_grad1, n_grad2)

    _test(fn.copy_u)


@parametrize_idtype
def test_unary_copy_e(idtype):
    """Compare per-relation and whole-graph ``apply_edges`` for ``copy_e``.

    Two results must be identical whether ``apply_edges`` is called once per
    canonical edge type or once on the whole heterograph:

    - the "develops" messages;
    - the gradient w.r.t. the "develops" edge features.
    """

    def _print_error(a, b):
        # Report each flattened position at which ``a`` and ``b`` disagree.
        for i, (x, y) in enumerate(
            zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())
        ):
            if not np.allclose(x, y):
                print("@{} {} v.s. {}".format(i, x, y))

    def _test(mfunc):
        g = create_test_heterograph(idtype)

        # Size each edge feature from the graph instead of hard-coding counts.
        x1 = F.randn((g.num_edges("plays"), feat_size))
        x2 = F.randn((g.num_edges("follows"), feat_size))
        x3 = F.randn((g.num_edges("develops"), feat_size))
        x4 = F.randn((g.num_edges("wishes"), feat_size))
        edge_feats = (x1, x2, x3, x4)

        for feat in edge_feats:
            F.attach_grad(feat)

        g["plays"].edata["eid"] = x1
        g["follows"].edata["eid"] = x2
        g["develops"].edata["eid"] = x3
        g["wishes"].edata["eid"] = x4

        #################################################################
        #  apply_edges() is called on each relation type separately
        #################################################################
        with F.record_grad():
            for rel in g.canonical_etypes:
                g.apply_edges(mfunc("eid", "m"), etype=rel)
            r1 = g["develops"].edata["m"]
            F.backward(r1, F.ones(r1.shape))
            # Keep a private copy. PyTorch accumulates into an existing
            # ``.grad`` in place (and F.attach_grad resets it), so a bare
            # reference would alias e_grad2 and the check below would
            # always pass.
            e_grad1 = F.grad(g["develops"].edata["eid"]).clone()

        # Remove the per-relation results so the second pass must recompute
        # them.
        for rel in g.canonical_etypes:
            del g.edges[rel].data["m"]

        #################################################################
        #  apply_edges() is called on all relation types
        #################################################################

        # Reset the gradients accumulated by the first backward pass.
        for feat in edge_feats:
            F.attach_grad(feat)

        with F.record_grad():
            g.apply_edges(mfunc("eid", "m"))
            r2 = g["develops"].edata["m"]
            F.backward(r2, F.ones(r2.shape))
            e_grad2 = F.grad(g["develops"].edata["eid"])

        # correctness check
        if not F.allclose(r1, r2):
            _print_error(r1, r2)
        assert F.allclose(r1, r2)
        if not F.allclose(e_grad1, e_grad2):
            print("edge grad")
            _print_error(e_grad1, e_grad2)
        assert F.allclose(e_grad1, e_grad2)

    _test(fn.copy_e)


@parametrize_idtype
def test_binary_op(idtype):
    """Check builtin binary message functions under both apply_edges modes.

    Each builtin ``{lhs}_{binary_op}_{rhs}`` is tested, with lhs != rhs drawn
    from u/v/e and op drawn from add/sub/mul/div/dot. Two results must match
    between calling ``apply_edges`` per canonical edge type and calling it
    once on the whole graph:

    - the "plays" messages;
    - the gradient w.r.t. the "game" node features, when one exists.
    """

    def _print_error(a, b):
        # Report each flattened position at which ``a`` and ``b`` disagree.
        for i, (x, y) in enumerate(
            zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())
        ):
            if not np.allclose(x, y):
                print("@{} {} v.s. {}".format(i, x, y))

    def _test(lhs, rhs, binary_op):
        g = create_test_heterograph(idtype)

        n1 = F.randn((g.num_nodes("user"), feat_size))
        n2 = F.randn((g.num_nodes("developer"), feat_size))
        n3 = F.randn((g.num_nodes("game"), feat_size))

        x1 = F.randn((g.num_edges("plays"), feat_size))
        x2 = F.randn((g.num_edges("follows"), feat_size))
        x3 = F.randn((g.num_edges("develops"), feat_size))
        x4 = F.randn((g.num_edges("wishes"), feat_size))

        leaves = (n1, n2, n3, x1, x2, x3, x4)

        def _reset_grads():
            # F.attach_grad is relied on to enable gradients on first use and
            # to reset any gradient left by a previous backward pass.
            for leaf in leaves:
                F.attach_grad(leaf)

        def _snapshot_grad(t):
            # The gradient is None when ``t`` did not feed the loss (e.g. no
            # "v" operand). Otherwise copy it, because ``.grad`` is updated
            # in place by later passes.
            grad = F.grad(t)
            return None if grad is None else grad.clone()

        builtin_msg_name = "{}_{}_{}".format(lhs, binary_op, rhs)
        builtin_msg = getattr(fn, builtin_msg_name)

        _reset_grads()
        g.nodes["user"].data["h"] = n1
        g.nodes["developer"].data["h"] = n2
        g.nodes["game"].data["h"] = n3
        g["plays"].edata["h"] = x1
        g["follows"].edata["h"] = x2
        g["develops"].edata["h"] = x3
        g["wishes"].edata["h"] = x4

        #################################################################
        #  apply_edges() is called on each relation type separately
        #################################################################

        with F.record_grad():
            for rel in g.canonical_etypes:
                g.apply_edges(builtin_msg("h", "h", "m"), etype=rel)
            r1 = g["plays"].edata["m"]
            loss = F.sum(r1.view(-1), 0)
            F.backward(loss)
            n_grad1 = _snapshot_grad(g.nodes["game"].data["h"])

        # Remove the per-relation results so the second pass must recompute
        # them.
        for rel in g.canonical_etypes:
            del g.edges[rel].data["m"]

        #################################################################
        #  apply_edges() is called on all relation types
        #################################################################

        # The graph still holds the same leaf tensors; only the gradients
        # accumulated by the first pass need resetting.
        _reset_grads()

        with F.record_grad():
            g.apply_edges(builtin_msg("h", "h", "m"))
            r2 = g["plays"].edata["m"]
            loss = F.sum(r2.view(-1), 0)
            F.backward(loss)
            n_grad2 = F.grad(g.nodes["game"].data["h"])

        # correctness check
        if not F.allclose(r1, r2):
            _print_error(r1, r2)
        assert F.allclose(r1, r2)
        if n_grad1 is not None or n_grad2 is not None:
            if not F.allclose(n_grad1, n_grad2):
                print("node grad")
                _print_error(n_grad1, n_grad2)
            assert F.allclose(n_grad1, n_grad2)

    target = ["u", "v", "e"]
    for lhs, rhs in product(target, target):
        if lhs == rhs:
            continue
        for binary_op in ["add", "sub", "mul", "div", "dot"]:
            print(lhs, rhs, binary_op)
            _test(lhs, rhs, binary_op)


# Here we test a heterograph with only a single source node type, because in
# that case the source node feature is exposed as a plain tensor rather than
# a per-node-type dict.
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
@parametrize_idtype
def test_heterograph_with_single_source_node_type_apply_edges(idtype):
    """Run ``u_add_v`` on a heterograph whose relations share one source type.

    Also checks that every relation's output equals ``h[src] + h[dst]``.
    """
    hg = create_random_hetero_with_single_source_node_type(idtype)

    hg.nodes["n1"].data["h"] = F.randn((hg.num_nodes("n1"), 1))
    hg.nodes["n2"].data["h"] = F.randn((hg.num_nodes("n2"), 1))
    hg.nodes["n3"].data["h"] = F.randn((hg.num_nodes("n3"), 1))

    # With a single source node type, srcdata yields a tensor, not a dict.
    assert isinstance(hg.srcdata["h"], torch.Tensor)
    hg.apply_edges(fn.u_add_v("h", "h", "x"))

    # Verify the computed messages, not just that apply_edges ran.
    for etype in hg.canonical_etypes:
        src_ntype, _, dst_ntype = etype
        src, dst = hg.edges(etype=etype)
        expected = F.gather_row(
            hg.nodes[src_ntype].data["h"], src
        ) + F.gather_row(hg.nodes[dst_ntype].data["h"], dst)
        assert F.allclose(hg.edges[etype].data["x"], expected)


if __name__ == "__main__":
    # Under pytest, ``parametrize_idtype`` supplies ``idtype``. When run as a
    # script, pass each ID type explicitly; calling these tests with no
    # argument raises TypeError.
    for _idtype in (F.int32, F.int64):
        test_unary_copy_u(_idtype)
        test_unary_copy_e(_idtype)
        test_binary_op(_idtype)
        test_heterograph_with_single_source_node_type_apply_edges(_idtype)