import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
import dgl.function as fn
from dgl.nn.pytorch import GATConv


# Semantic attention in the metapath-based aggregation (the same as in HAN)
class SemanticAttention(nn.Module):
    def __init__(self, in_size, hidden_size=128):
        super(SemanticAttention, self).__init__()

        self.project = nn.Sequential(
            nn.Linear(in_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1, bias=False),
        )

    def forward(self, z):
        """Aggregate metapath-specific embeddings with learned semantic attention.

        Shape of z: (N, M, D * K)
        N: number of nodes
        M: number of metapath patterns
        D: hidden_size
        K: number of heads
        """
        w = self.project(z).mean(0)  # (M, 1)
        beta = torch.softmax(w, dim=0)  # (M, 1)
        beta = beta.expand((z.shape[0],) + beta.shape)  # (N, M, 1)

        return (beta * z).sum(1)  # (N, D * K)


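# A minimal shape-check sketch for SemanticAttention (illustrative only, not
# part of the original model); the sizes below are hypothetical placeholders.
def _semantic_attention_example():
    attn = SemanticAttention(in_size=8)
    z = torch.randn(3, 2, 8)  # N=3 nodes, M=2 metapath patterns, D*K=8
    out = attn(z)  # fuse the metapath dimension with learned attention weights
    assert out.shape == (3, 8)  # (N, D*K)
    return out

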
# Metapath-based aggregation (the same as the HANLayer)
class HANLayer(nn.Module):
    def __init__(
        self, meta_path_patterns, in_size, out_size, layer_num_heads, dropout
    ):
        super(HANLayer, self).__init__()

        # One GAT layer for each meta path based adjacency matrix
        self.gat_layers = nn.ModuleList()
        for i in range(len(meta_path_patterns)):
            self.gat_layers.append(
                GATConv(
                    in_size,
                    out_size,
                    layer_num_heads,
                    dropout,
                    dropout,
                    activation=F.elu,
                    allow_zero_in_degree=True,
                )
            )
        self.semantic_attention = SemanticAttention(
            in_size=out_size * layer_num_heads
        )
        self.meta_path_patterns = list(
            tuple(meta_path_pattern) for meta_path_pattern in meta_path_patterns
        )

        self._cached_graph = None
        self._cached_coalesced_graph = {}

    def forward(self, g, h):
        semantic_embeddings = []
        # obtain metapath reachable graph
        if self._cached_graph is None or self._cached_graph is not g:
            self._cached_graph = g
            self._cached_coalesced_graph.clear()
            for meta_path_pattern in self.meta_path_patterns:
                self._cached_coalesced_graph[
                    meta_path_pattern
                ] = dgl.metapath_reachable_graph(g, meta_path_pattern)

        for i, meta_path_pattern in enumerate(self.meta_path_patterns):
            new_g = self._cached_coalesced_graph[meta_path_pattern]
            semantic_embeddings.append(self.gat_layers[i](new_g, h).flatten(1))
        semantic_embeddings = torch.stack(
            semantic_embeddings, dim=1
        )  # (N, M, D * K)

        return self.semantic_attention(semantic_embeddings)  # (N, D * K)


# Relational neighbor aggregation
class RelationalAGG(nn.Module):
    def __init__(self, g, in_size, out_size, dropout=0.1):
        super(RelationalAGG, self).__init__()
        self.in_size = in_size
        self.out_size = out_size

        # Transform weights for different types of edges
        self.W_T = nn.ModuleDict(
            {
                name: nn.Linear(in_size, out_size, bias=False)
                for name in g.etypes
            }
        )

        # Attention weights for different types of edges
        self.W_A = nn.ModuleDict(
            {name: nn.Linear(out_size, 1, bias=False) for name in g.etypes}
        )

        # layernorm
        self.layernorm = nn.LayerNorm(out_size)

        # dropout layer
        self.dropout = nn.Dropout(dropout)

    def forward(self, g, feat_dict):
        funcs = {}  # (message, reduce) functions keyed by edge type
        for srctype, etype, dsttype in g.canonical_etypes:
            g.nodes[dsttype].data["h"] = feat_dict[
                dsttype
            ]  # nodes' original feature
            g.nodes[srctype].data["h"] = feat_dict[srctype]
            g.nodes[srctype].data["t_h"] = self.W_T[etype](
                feat_dict[srctype]
            )  # src nodes' transformed feature

            # compute the attention numerator (exp); the element-wise u_mul_v
            # product below assumes in_size == out_size
            g.apply_edges(fn.u_mul_v("t_h", "h", "x"), etype=etype)
            g.edges[etype].data["x"] = torch.exp(
                self.W_A[etype](g.edges[etype].data["x"])
            )

            # first update to compute the attention denominator (\sum exp)
            funcs[etype] = (fn.copy_e("x", "m"), fn.sum("m", "att"))
        g.multi_update_all(funcs, "sum")

        funcs = {}
        for srctype, etype, dsttype in g.canonical_etypes:
            g.apply_edges(
                fn.e_div_v("x", "att", "att"), etype=etype
            )  # compute attention weights (numerator/denominator)
            funcs[etype] = (
                fn.u_mul_e("h", "att", "m"),
                fn.sum("m", "h"),
            )  # \sum(h0*att) -> h1
        # second update to obtain h1
        g.multi_update_all(funcs, "sum")

        # apply activation, layernorm, and dropout
        feat_dict = {}
        for ntype in g.ntypes:
            feat_dict[ntype] = self.dropout(
                self.layernorm(F.relu_(g.nodes[ntype].data["h"]))
            )  # apply activation, layernorm, and dropout

        return feat_dict


class TAHIN(nn.Module):
    def __init__(
        self, g, meta_path_patterns, in_size, out_size, num_heads, dropout
    ):
        super(TAHIN, self).__init__()

        # embeddings for different types of nodes, h0
        self.initializer = nn.init.xavier_uniform_
        self.feature_dict = nn.ParameterDict(
            {
                ntype: nn.Parameter(
                    self.initializer(torch.empty(g.num_nodes(ntype), in_size))
                )
                for ntype in g.ntypes
            }
        )

        # relational neighbor aggregation, this produces h1
        self.RelationalAGG = RelationalAGG(g, in_size, out_size)

        # metapath-based aggregation modules for user and item, this produces h2
        self.meta_path_patterns = meta_path_patterns
        # one HANLayer for user, one HANLayer for item
        self.hans = nn.ModuleDict(
            {
                key: HANLayer(value, in_size, out_size, num_heads, dropout)
                for key, value in self.meta_path_patterns.items()
            }
        )

        # layers to combine h0, h1, and h2
        # used to update node embeddings
        self.user_layer1 = nn.Linear(
            (num_heads + 1) * out_size, out_size, bias=True
        )
        self.user_layer2 = nn.Linear(2 * out_size, out_size, bias=True)
        self.item_layer1 = nn.Linear(
            (num_heads + 1) * out_size, out_size, bias=True
        )
        self.item_layer2 = nn.Linear(2 * out_size, out_size, bias=True)

        # layernorm
        self.layernorm = nn.LayerNorm(out_size)

        # network to score the node pairs
        self.pred = nn.Linear(out_size, out_size)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(out_size, 1)

    def forward(self, g, user_key, item_key, user_idx, item_idx):
        # relational neighbor aggregation, h1
        h1 = self.RelationalAGG(g, self.feature_dict)

        # metapath-based aggregation, h2
        h2 = {}
        for key in self.meta_path_patterns.keys():
            h2[key] = self.hans[key](g, self.feature_dict[key])

        # update node embeddings
        user_emb = torch.cat((h1[user_key], h2[user_key]), 1)
        item_emb = torch.cat((h1[item_key], h2[item_key]), 1)
        user_emb = self.user_layer1(user_emb)
        item_emb = self.item_layer1(item_emb)
        user_emb = self.user_layer2(
            torch.cat((user_emb, self.feature_dict[user_key]), 1)
        )
        item_emb = self.item_layer2(
            torch.cat((item_emb, self.feature_dict[item_key]), 1)
        )

        # ReLU activation
        user_emb = F.relu_(user_emb)
        item_emb = F.relu_(item_emb)

        # layer norm
        user_emb = self.layernorm(user_emb)
        item_emb = self.layernorm(item_emb)

        # obtain users/items embeddings and their interactions
        user_feat = user_emb[user_idx]
        item_feat = item_emb[item_idx]
        interaction = user_feat * item_feat

        # score the node pairs
        pred = self.pred(interaction)
        pred = self.dropout(pred)  # dropout
        pred = self.fc(pred)
        pred = torch.sigmoid(pred)

        return pred.squeeze(1)
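

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original model). The toy graph,
    # node/edge type names, metapath patterns, and sizes below are hypothetical
    # placeholders chosen only to illustrate the expected inputs; note that the
    # relational aggregation implicitly assumes in_size == out_size.
    toy_g = dgl.heterograph(
        {
            ("user", "ui", "item"): ([0, 1, 2, 3], [0, 1, 2, 3]),
            ("item", "iu", "user"): ([0, 1, 2, 3], [0, 1, 2, 3]),
        },
        num_nodes_dict={"user": 4, "item": 5},
    )
    toy_patterns = {
        "user": [["ui", "iu"]],  # user -> item -> user
        "item": [["iu", "ui"]],  # item -> user -> item
    }
    model = TAHIN(
        toy_g, toy_patterns, in_size=16, out_size=16, num_heads=2, dropout=0.1
    )
    scores = model(
        toy_g, "user", "item", torch.tensor([0, 1]), torch.tensor([0, 2])
    )
    print(scores.shape)  # expected: torch.Size([2])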