import itertools
import math
import unittest
from collections import Counter

import backend as F
import networkx as nx
import numpy as np
import pytest
import scipy.sparse as ssp
import test_utils
from scipy.sparse import rand
from test_utils import get_cases, parametrize_idtype

import dgl
import dgl.function as fn
from dgl import DGLError
from dgl.ops import edge_softmax

rfuncs = {"sum": fn.sum, "max": fn.max, "min": fn.min, "mean": fn.mean}
fill_value = {"sum": 0, "max": float("-inf")}
feat_size = 2


def create_test_heterograph(idtype):
    # test heterograph from the docstring, plus a user -- wishes -- game relation
    # 3 users, 2 games, 2 developers
    # metagraph:
    #    ('user', 'follows', 'user'),
    #    ('user', 'plays', 'game'),
    #    ('user', 'wishes', 'game'),
    #    ('developer', 'develops', 'game')

    g = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 1, 2, 1, 1], [0, 0, 1, 1, 2]),
            ("user", "plays", "game"): ([0, 1, 2, 1], [0, 0, 1, 1]),
            ("user", "wishes", "game"): ([0, 1, 1], [0, 0, 1]),
            ("developer", "develops", "game"): ([0, 1, 0], [0, 1, 1]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
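    # edge counts per relation: follows=5, plays=4, wishes=3, develops=3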
    assert g.idtype == idtype
    assert g.device == F.ctx()
    return g


@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
def test_edge_softmax_unidirectional():
    g = dgl.heterograph(
        {
            ("A", "AB", "B"): (
                [1, 2, 3, 1, 2, 3, 1, 2, 3],
                [0, 0, 0, 1, 1, 1, 2, 2, 2],
            ),
            ("B", "BB", "B"): (
                [0, 1, 2, 0, 1, 2, 0, 1, 2],
                [0, 0, 0, 1, 1, 1, 2, 2, 2],
            ),
        }
    )
    g = g.to(F.ctx())
    g.edges["AB"].data["x"] = F.ones(9) * 2
    g.edges["BB"].data["x"] = F.ones(9)
    result = dgl.ops.edge_softmax(
        g, {"AB": g.edges["AB"].data["x"], "BB": g.edges["BB"].data["x"]}
    )

    ab = result["A", "AB", "B"]
    bb = result["B", "BB", "B"]
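    # expected values: each 'B' node's softmax denominator is
    # 3*exp(2) + 3*exp(1), shared across both incoming edge types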
    e2 = F.zeros_like(ab) + math.exp(2) / ((math.exp(2) + math.exp(1)) * 3)
    e1 = F.zeros_like(bb) + math.exp(1) / ((math.exp(2) + math.exp(1)) * 3)
    assert F.allclose(ab, e2)
    assert F.allclose(bb, e1)


@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
@pytest.mark.parametrize("g", get_cases(["clique"]))
@pytest.mark.parametrize("norm_by", ["src", "dst"])
# @pytest.mark.parametrize('shp', edge_softmax_shapes)
@parametrize_idtype
def test_edge_softmax(g, norm_by, idtype):
    print("params", norm_by, idtype)

    g = create_test_heterograph(idtype)
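    # note: the parametrized `g` (a clique case) is immediately replaced by
    # the test heterograph above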

    x1 = F.randn((g.num_edges("plays"), feat_size))
    x2 = F.randn((g.num_edges("follows"), feat_size))
    x3 = F.randn((g.num_edges("develops"), feat_size))
    x4 = F.randn((g.num_edges("wishes"), feat_size))

    F.attach_grad(F.clone(x1))
    F.attach_grad(F.clone(x2))
    F.attach_grad(F.clone(x3))
    F.attach_grad(F.clone(x4))

    g["plays"].edata["eid"] = x1
    g["follows"].edata["eid"] = x2
    g["develops"].edata["eid"] = x3
    g["wishes"].edata["eid"] = x4

    #################################################################
    #  edge_softmax() on homogeneous graph
    #################################################################

    with F.record_grad():
        hm_g = dgl.to_homogeneous(g)
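        # to_homogeneous lays out edges grouped by etype; g.etypes order is
        # ('develops', 'follows', 'plays', 'wishes'), hence x3, x2, x1, x4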
        hm_x = F.cat((x3, x2, x1, x4), 0)
        hm_e = F.attach_grad(F.clone(hm_x))
        score_hm = edge_softmax(hm_g, hm_e, norm_by=norm_by)
        hm_g.edata["score"] = score_hm
        ht_g = dgl.to_heterogeneous(hm_g, g.ntypes, g.etypes)
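        # map the homogeneous scores back onto the heterograph so they can
        # be read out per canonical etype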
        r1 = ht_g.edata["score"][("user", "plays", "game")]
        r2 = ht_g.edata["score"][("user", "follows", "user")]
        r3 = ht_g.edata["score"][("developer", "develops", "game")]
        r4 = ht_g.edata["score"][("user", "wishes", "game")]
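        # backprop through the 'plays' and 'follows' scores only; the
        # heterograph pass below must reproduce these gradients exactly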
        F.backward(F.reduce_sum(r1) + F.reduce_sum(r2))
        grad_edata_hm = F.grad(hm_e)

    #################################################################
    #  edge_softmax() on heterogeneous graph
    #################################################################

    e1 = F.attach_grad(F.clone(x1))
    e2 = F.attach_grad(F.clone(x2))
    e3 = F.attach_grad(F.clone(x3))
    e4 = F.attach_grad(F.clone(x4))
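    # pack fresh leaf tensors (cloned from x1..x4 above) into a dict keyed
    # by canonical etype; gradients here are independent of the first pass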
    e = {
        ("user", "follows", "user"): e2,
        ("user", "plays", "game"): e1,
        ("user", "wishes", "game"): e4,
        ("developer", "develops", "game"): e3,
    }
    with F.record_grad():
        score = edge_softmax(g, e, norm_by=norm_by)
        r5 = score[("user", "plays", "game")]
        r6 = score[("user", "follows", "user")]
        r7 = score[("developer", "develops", "game")]
        r8 = score[("user", "wishes", "game")]
        F.backward(F.reduce_sum(r5) + F.reduce_sum(r6))
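        # gather per-etype gradients in the same etype order used for hm_x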
        grad_edata_ht = F.cat(
            (F.grad(e3), F.grad(e2), F.grad(e1), F.grad(e4)), 0
        )
        # correctness check
        assert F.allclose(r1, r5)
        assert F.allclose(r2, r6)
        assert F.allclose(r3, r7)
        assert F.allclose(r4, r8)
        assert F.allclose(grad_edata_hm, grad_edata_ht)


if __name__ == "__main__":
    test_edge_softmax_unidirectional()
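    # test_edge_softmax is parametrized and is meant to be run via pytest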