import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
import dgl.function as fn
from dgl.nn.pytorch import GATConv

# Semantic attention in the metapath-based aggregation (the same as that in HAN)
class SemanticAttention(nn.Module):
    def __init__(self, in_size, hidden_size=128):
        super(SemanticAttention, self).__init__()

        self.project = nn.Sequential(
            nn.Linear(in_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1, bias=False)
        )

    def forward(self, z):
        '''
        Shape of z: (N, M, D * K)
        N: number of nodes
        M: number of metapath patterns
        D: hidden_size
        K: number of attention heads
        '''
        w = self.project(z).mean(0)                    # (M, 1)
        beta = torch.softmax(w, dim=0)                 # (M, 1)
        beta = beta.expand((z.shape[0],) + beta.shape) # (N, M, 1)

        return (beta * z).sum(1)                       # (N, D * K)

# Metapath-based aggregation (the same as the HANLayer in HAN)
class HANLayer(nn.Module):
    def __init__(self, meta_path_patterns, in_size, out_size, layer_num_heads, dropout):
        super(HANLayer, self).__init__()

        # One GAT layer for each metapath-based adjacency matrix
        self.gat_layers = nn.ModuleList()
        for i in range(len(meta_path_patterns)):
            self.gat_layers.append(GATConv(in_size, out_size, layer_num_heads,
                                           dropout, dropout, activation=F.elu,
                                           allow_zero_in_degree=True))
        self.semantic_attention = SemanticAttention(in_size=out_size * layer_num_heads)
        self.meta_path_patterns = list(tuple(meta_path_pattern) for meta_path_pattern in meta_path_patterns)

        self._cached_graph = None
        self._cached_coalesced_graph = {}

    def forward(self, g, h):
        semantic_embeddings = []
        # Obtain the metapath-reachable graph for each pattern (cached per input graph)
        if self._cached_graph is None or self._cached_graph is not g:
            self._cached_graph = g
            self._cached_coalesced_graph.clear()
            for meta_path_pattern in self.meta_path_patterns:
                self._cached_coalesced_graph[meta_path_pattern] = dgl.metapath_reachable_graph(
                        g, meta_path_pattern)

        for i, meta_path_pattern in enumerate(self.meta_path_patterns):
            new_g = self._cached_coalesced_graph[meta_path_pattern]
            semantic_embeddings.append(self.gat_layers[i](new_g, h).flatten(1))
        semantic_embeddings = torch.stack(semantic_embeddings, dim=1)                  # (N, M, D * K)

        return self.semantic_attention(semantic_embeddings)                            # (N, D * K)

# Relational neighbor aggregation
class RelationalAGG(nn.Module):
    def __init__(self, g, in_size, out_size, dropout=0.1):
        super(RelationalAGG, self).__init__()
        self.in_size = in_size
        self.out_size = out_size

        # Transformation weights for different types of edges
        self.W_T = nn.ModuleDict({
            name: nn.Linear(in_size, out_size, bias=False) for name in g.etypes
        })

        # Attention weights for different types of edges
        self.W_A = nn.ModuleDict({
            name: nn.Linear(out_size, 1, bias=False) for name in g.etypes
        })

        # Layer normalization
        self.layernorm = nn.LayerNorm(out_size)

        # Dropout layer
        self.dropout = nn.Dropout(dropout)

    def forward(self, g, feat_dict):
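        '''
        Edge-type-aware attention aggregation: for each edge type, compute
        unnormalized attention scores exp(W_A(W_T(h_src) * h_dst)) on the edges,
        normalize them per destination node, and aggregate the attention-weighted
        source features to produce h1.
        '''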
        funcs = {}
        for srctype, etype, dsttype in g.canonical_etypes:
            g.nodes[dsttype].data['h'] = feat_dict[dsttype]  # dst nodes' original features
            g.nodes[srctype].data['h'] = feat_dict[srctype]  # src nodes' original features
            g.nodes[srctype].data['t_h'] = self.W_T[etype](feat_dict[srctype])  # src nodes' transformed features

            # Compute the attention numerator (exp)
            g.apply_edges(fn.u_mul_v('t_h', 'h', 'x'), etype=etype)
            g.edges[etype].data['x'] = torch.exp(self.W_A[etype](g.edges[etype].data['x']))

            # First message passing to compute the attention denominator (\sum exp)
            funcs[etype] = (fn.copy_e('x', 'm'), fn.sum('m', 'att'))
        g.multi_update_all(funcs, 'sum')

        funcs = {}
        for srctype, etype, dsttype in g.canonical_etypes:
            g.apply_edges(fn.e_div_v('x', 'att', 'att'), etype=etype)  # attention weights = numerator / denominator
            funcs[etype] = (fn.u_mul_e('h', 'att', 'm'), fn.sum('m', 'h'))  # \sum(h0 * att) -> h1
        # Second message passing to obtain h1
        g.multi_update_all(funcs, 'sum')

        # Apply activation, layer normalization, and dropout
        feat_dict = {}
        for ntype in g.ntypes:
            feat_dict[ntype] = self.dropout(self.layernorm(F.relu_(g.nodes[ntype].data['h'])))

        return feat_dict

class TAHIN(nn.Module):
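    '''
    TAHIN combines three node representations: the learnable initial embeddings (h0),
    relation-aware neighbor aggregation (h1), and metapath-based aggregation (h2),
    then scores user-item pairs from the fused embeddings.
    '''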
    def __init__(self, g, meta_path_patterns, in_size, out_size, num_heads, dropout):
        super(TAHIN, self).__init__()

        # Embeddings for different types of nodes, h0
        self.initializer = nn.init.xavier_uniform_
        self.feature_dict = nn.ParameterDict({
            ntype: nn.Parameter(self.initializer(torch.empty(g.num_nodes(ntype), in_size))) for ntype in g.ntypes
        })

        # Relational neighbor aggregation, producing h1
        self.RelationalAGG = RelationalAGG(g, in_size, out_size)

        # Metapath-based aggregation modules for users and items, producing h2
        self.meta_path_patterns = meta_path_patterns
        # One HANLayer for users, one HANLayer for items
        self.hans = nn.ModuleDict({
            key: HANLayer(value, in_size, out_size, num_heads, dropout) for key, value in self.meta_path_patterns.items()
        })

        # Layers to combine h0, h1, and h2, used to update node embeddings
        self.user_layer1 = nn.Linear((num_heads + 1) * out_size, out_size, bias=True)
        self.user_layer2 = nn.Linear(2 * out_size, out_size, bias=True)
        self.item_layer1 = nn.Linear((num_heads + 1) * out_size, out_size, bias=True)
        self.item_layer2 = nn.Linear(2 * out_size, out_size, bias=True)

        # Layer normalization
        self.layernorm = nn.LayerNorm(out_size)

        # Network to score the node pairs
        self.pred = nn.Linear(out_size, out_size)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(out_size, 1)

    def forward(self, g, user_key, item_key, user_idx, item_idx):
        # Relational neighbor aggregation, h1
        h1 = self.RelationalAGG(g, self.feature_dict)

        # Metapath-based aggregation, h2
        h2 = {}
        for key in self.meta_path_patterns.keys():
            h2[key] = self.hans[key](g, self.feature_dict[key])

        # Combine h0, h1, and h2 to update the node embeddings
        user_emb = torch.cat((h1[user_key], h2[user_key]), 1)
        item_emb = torch.cat((h1[item_key], h2[item_key]), 1)
        user_emb = self.user_layer1(user_emb)
        item_emb = self.item_layer1(item_emb)
        user_emb = self.user_layer2(torch.cat((user_emb, self.feature_dict[user_key]), 1))
        item_emb = self.item_layer2(torch.cat((item_emb, self.feature_dict[item_key]), 1))

        # ReLU
        user_emb = F.relu_(user_emb)
        item_emb = F.relu_(item_emb)

        # Layer normalization
        user_emb = self.layernorm(user_emb)
        item_emb = self.layernorm(item_emb)

        # Obtain the selected user/item embeddings and their interactions
        user_feat = user_emb[user_idx]
        item_feat = item_emb[item_idx]
        interaction = user_feat * item_feat

        # Score the node pairs
        pred = self.pred(interaction)
        pred = self.dropout(pred)
        pred = self.fc(pred)
        pred = torch.sigmoid(pred)

        return pred.squeeze(1)
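
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original model definition). The toy
# heterograph, edge-type names, metapath patterns, and sizes below are
# illustrative assumptions; real datasets define their own schema. Note that
# RelationalAGG multiplies transformed source features with raw destination
# features, so this sketch keeps in_size == out_size.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Toy user-item heterograph with a forward and a reverse relation (assumed names)
    src = torch.tensor([0, 1, 2, 3, 0])
    dst = torch.tensor([0, 1, 2, 3, 4])
    g = dgl.heterograph(
        {
            ("user", "rate", "item"): (src, dst),
            ("item", "rated_by", "user"): (dst, src),
        },
        num_nodes_dict={"user": 4, "item": 6},
    )

    # Metapath patterns expressed as sequences of edge types:
    # user -> item -> user and item -> user -> item
    meta_path_patterns = {
        "user": [["rate", "rated_by"]],
        "item": [["rated_by", "rate"]],
    }

    model = TAHIN(g, meta_path_patterns, in_size=16, out_size=16, num_heads=2, dropout=0.1)

    # Score a few user-item pairs
    user_idx = torch.tensor([0, 1, 2])
    item_idx = torch.tensor([0, 3, 5])
    scores = model(g, "user", "item", user_idx, item_idx)
    print(scores.shape)  # torch.Size([3])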