import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GATConv


# Semantic attention in the metapath-based aggregation (the same as in HAN)
class SemanticAttention(nn.Module):
    def __init__(self, in_size, hidden_size=128):
        super(SemanticAttention, self).__init__()

        self.project = nn.Sequential(
            nn.Linear(in_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1, bias=False),
        )

    def forward(self, z):
        """
        Shape of z: (N, M, D*K)
        N: number of nodes
        M: number of metapath patterns
        D: hidden_size
        K: number of heads
        """
        w = self.project(z).mean(0)  # (M, 1)
        beta = torch.softmax(w, dim=0)  # (M, 1)
        beta = beta.expand((z.shape[0],) + beta.shape)  # (N, M, 1)
        return (beta * z).sum(1)  # (N, D * K)


# Metapath-based aggregation (the same as the HANLayer in HAN)
class HANLayer(nn.Module):
    def __init__(
        self, meta_path_patterns, in_size, out_size, layer_num_heads, dropout
    ):
        super(HANLayer, self).__init__()

        # One GAT layer for each metapath-based adjacency matrix
        self.gat_layers = nn.ModuleList()
        for i in range(len(meta_path_patterns)):
            self.gat_layers.append(
                GATConv(
                    in_size,
                    out_size,
                    layer_num_heads,
                    dropout,  # feat_drop
                    dropout,  # attn_drop
                    activation=F.elu,
                    allow_zero_in_degree=True,
                )
            )
        self.semantic_attention = SemanticAttention(
            in_size=out_size * layer_num_heads
        )
        self.meta_path_patterns = list(
            tuple(meta_path_pattern) for meta_path_pattern in meta_path_patterns
        )

        self._cached_graph = None
        self._cached_coalesced_graph = {}

    def forward(self, g, h):
        semantic_embeddings = []
        # obtain the metapath-reachable graphs, cached per input graph
        if self._cached_graph is not g:
            self._cached_graph = g
            self._cached_coalesced_graph.clear()
            for meta_path_pattern in self.meta_path_patterns:
                self._cached_coalesced_graph[
                    meta_path_pattern
                ] = dgl.metapath_reachable_graph(g, meta_path_pattern)

        for i, meta_path_pattern in enumerate(self.meta_path_patterns):
            new_g = self._cached_coalesced_graph[meta_path_pattern]
            semantic_embeddings.append(self.gat_layers[i](new_g, h).flatten(1))
        semantic_embeddings = torch.stack(
            semantic_embeddings, dim=1
        )  # (N, M, D * K)

        return self.semantic_attention(semantic_embeddings)  # (N, D * K)
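
# A minimal sketch (hypothetical schema, not from the original module) of what
# each GATConv above operates on: for a user-item heterograph, the metapath
# ("ui", "iu") collapses user -> item -> user reachability into a homogeneous
# user-user graph.
#
#     g = dgl.heterograph({
#         ("user", "ui", "item"): ([0, 0, 1], [0, 1, 1]),
#         ("item", "iu", "user"): ([0, 1, 1], [0, 0, 1]),
#     })
#     uu = dgl.metapath_reachable_graph(g, ("ui", "iu"))  # user-user graph
#
# Users unreachable via the metapath end up with zero in-degree, which is why
# the GATConv layers are built with allow_zero_in_degree=True.
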
# Relational neighbor aggregation
class RelationalAGG(nn.Module):
    def __init__(self, g, in_size, out_size, dropout=0.1):
        super(RelationalAGG, self).__init__()
        self.in_size = in_size
        self.out_size = out_size

        # Transform weights for different types of edges
        self.W_T = nn.ModuleDict(
            {
                name: nn.Linear(in_size, out_size, bias=False)
                for name in g.etypes
            }
        )

        # Attention weights for different types of edges
        self.W_A = nn.ModuleDict(
            {name: nn.Linear(out_size, 1, bias=False) for name in g.etypes}
        )

        # layernorm
        self.layernorm = nn.LayerNorm(out_size)

        # dropout layer
        self.dropout = nn.Dropout(dropout)

    def forward(self, g, feat_dict):
        funcs = {}
        for srctype, etype, dsttype in g.canonical_etypes:
            g.nodes[dsttype].data["h"] = feat_dict[
                dsttype
            ]  # dst nodes' original features
            g.nodes[srctype].data["h"] = feat_dict[srctype]
            g.nodes[srctype].data["t_h"] = self.W_T[etype](
                feat_dict[srctype]
            )  # src nodes' transformed features

            # compute the attention numerator (exp)
            g.apply_edges(fn.u_mul_v("t_h", "h", "x"), etype=etype)
            g.edges[etype].data["x"] = torch.exp(
                self.W_A[etype](g.edges[etype].data["x"])
            )

            # first update to compute the attention denominator (\sum exp)
            funcs[etype] = (fn.copy_e("x", "m"), fn.sum("m", "att"))
        g.multi_update_all(funcs, "sum")

        funcs = {}
        for srctype, etype, dsttype in g.canonical_etypes:
            g.apply_edges(
                fn.e_div_v("x", "att", "att"), etype=etype
            )  # compute attention weights (numerator / denominator)
            funcs[etype] = (
                fn.u_mul_e("h", "att", "m"),
                fn.sum("m", "h"),
            )  # \sum(h0 * att) -> h1
        # second update to obtain h1
        g.multi_update_all(funcs, "sum")

        # apply activation, layernorm, and dropout
        feat_dict = {}
        for ntype in g.ntypes:
            feat_dict[ntype] = self.dropout(
                self.layernorm(F.relu_(g.nodes[ntype].data["h"]))
            )

        return feat_dict


class TAHIN(nn.Module):
    def __init__(
        self, g, meta_path_patterns, in_size, out_size, num_heads, dropout
    ):
        super(TAHIN, self).__init__()

        # embeddings for different types of nodes, h0
        self.initializer = nn.init.xavier_uniform_
        self.feature_dict = nn.ParameterDict(
            {
                ntype: nn.Parameter(
                    self.initializer(torch.empty(g.num_nodes(ntype), in_size))
                )
                for ntype in g.ntypes
            }
        )

        # relational neighbor aggregation, this produces h1
        self.RelationalAGG = RelationalAGG(g, in_size, out_size)

        # metapath-based aggregation modules for user and item, this produces h2
        self.meta_path_patterns = meta_path_patterns
        # one HANLayer for user, one HANLayer for item
        self.hans = nn.ModuleDict(
            {
                key: HANLayer(value, in_size, out_size, num_heads, dropout)
                for key, value in self.meta_path_patterns.items()
            }
        )

        # layers to combine h0, h1, and h2, used to update node embeddings
        self.user_layer1 = nn.Linear(
            (num_heads + 1) * out_size, out_size, bias=True
        )
        self.user_layer2 = nn.Linear(2 * out_size, out_size, bias=True)
        self.item_layer1 = nn.Linear(
            (num_heads + 1) * out_size, out_size, bias=True
        )
        self.item_layer2 = nn.Linear(2 * out_size, out_size, bias=True)

        # layernorm
        self.layernorm = nn.LayerNorm(out_size)

        # network to score the node pairs
        self.pred = nn.Linear(out_size, out_size)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(out_size, 1)

    def forward(self, g, user_key, item_key, user_idx, item_idx):
        # relational neighbor aggregation, h1
        h1 = self.RelationalAGG(g, self.feature_dict)

        # metapath-based aggregation, h2
        h2 = {}
        for key in self.meta_path_patterns.keys():
            h2[key] = self.hans[key](g, self.feature_dict[key])

        # update node embeddings
        user_emb = torch.cat((h1[user_key], h2[user_key]), 1)
        item_emb = torch.cat((h1[item_key], h2[item_key]), 1)
        user_emb = self.user_layer1(user_emb)
        item_emb = self.item_layer1(item_emb)
        user_emb = self.user_layer2(
            torch.cat((user_emb, self.feature_dict[user_key]), 1)
        )
        item_emb = self.item_layer2(
            torch.cat((item_emb, self.feature_dict[item_key]), 1)
        )

        # ReLU
        user_emb = F.relu_(user_emb)
        item_emb = F.relu_(item_emb)

        # layernorm
        user_emb = self.layernorm(user_emb)
        item_emb = self.layernorm(item_emb)

        # obtain users/items embeddings and their interactions
        user_feat = user_emb[user_idx]
        item_feat = item_emb[item_idx]
        interaction = user_feat * item_feat

        # score the node pairs
        pred = self.pred(interaction)
        pred = self.dropout(pred)
        pred = self.fc(pred)
        pred = torch.sigmoid(pred)

        return pred.squeeze(1)
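
# A minimal usage sketch, not part of the original module. The toy graph,
# metapath patterns, and hyperparameters below are illustrative assumptions.
# Note that RelationalAGG multiplies transformed source features (out_size)
# with raw destination features (in_size) via fn.u_mul_v, so in_size must
# equal out_size for the shapes to broadcast.
if __name__ == "__main__":
    g = dgl.heterograph(
        {
            ("user", "ui", "item"): ([0, 0, 1], [0, 1, 1]),
            ("item", "iu", "user"): ([0, 1, 1], [0, 0, 1]),
        }
    )
    meta_path_patterns = {
        "user": [["ui", "iu"]],  # user -> item -> user
        "item": [["iu", "ui"]],  # item -> user -> item
    }
    model = TAHIN(
        g, meta_path_patterns, in_size=16, out_size=16, num_heads=2, dropout=0.1
    )
    # score two (user, item) pairs: (0, 1) and (1, 0)
    scores = model(
        g, "user", "item", torch.tensor([0, 1]), torch.tensor([1, 0])
    )
    print(scores.shape)  # torch.Size([2]); one sigmoid score per pair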