import torch as th
import torch.nn as nn
from utils import local_global_loss_

from dgl.nn.pytorch import GraphConv
from dgl.nn.pytorch.glob import SumPooling


class MLP(nn.Module):
    """Three-layer MLP with PReLU activations and a linear shortcut
    (residual) connection, used as the projection head."""

    def __init__(self, in_dim, out_dim):
        super(MLP, self).__init__()
        self.fcs = nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.PReLU(),
            nn.Linear(out_dim, out_dim),
            nn.PReLU(),
            nn.Linear(out_dim, out_dim),
            nn.PReLU(),
        )
        self.linear_shortcut = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        return self.fcs(x) + self.linear_shortcut(x)


class GCN(nn.Module):
    """GraphConv encoder. The graph-level representation is the
    concatenation of the sum-pooled node features from every layer, so
    its size is ``num_layers * out_dim``."""

    def __init__(self, in_dim, out_dim, num_layers, norm):
        super(GCN, self).__init__()
        self.num_layers = num_layers
        self.layers = nn.ModuleList()
        self.layers.append(
            GraphConv(
                in_dim, out_dim, bias=False, norm=norm, activation=nn.PReLU()
            )
        )
        self.pooling = SumPooling()

        for _ in range(num_layers - 1):
            self.layers.append(
                GraphConv(
                    out_dim,
                    out_dim,
                    bias=False,
                    norm=norm,
                    activation=nn.PReLU(),
                )
            )

    def forward(self, graph, feat, edge_weight=None):
        h = self.layers[0](graph, feat, edge_weight=edge_weight)
        hg = self.pooling(graph, h)

        for idx in range(self.num_layers - 1):
            h = self.layers[idx + 1](graph, h, edge_weight=edge_weight)
            hg = th.cat((hg, self.pooling(graph, h)), -1)

        return h, hg


class MVGRL(nn.Module):
    r"""
    MVGRL model

    Parameters
    ----------
    in_dim: int
        Input feature size.
    out_dim: int
        Output feature size.
    num_layers: int
        Number of GNN encoder layers.

    Functions
    ---------
    forward(graph1, graph2, feat, edge_weight, graph_id):
        graph1: DGLGraph
            The original graph.
        graph2: DGLGraph
            The diffusion graph.
        feat: tensor
            Node features.
        edge_weight: tensor
            Edge weights of the diffusion graph.
        graph_id: tensor
            ID of the graph each node belongs to.
    """

    def __init__(self, in_dim, out_dim, num_layers):
        super(MVGRL, self).__init__()
        self.local_mlp = MLP(out_dim, out_dim)
        self.global_mlp = MLP(num_layers * out_dim, out_dim)
        # The diffusion graph supplies its own edge weights, so GraphConv
        # normalization is disabled for its encoder.
        self.encoder1 = GCN(in_dim, out_dim, num_layers, norm="both")
        self.encoder2 = GCN(in_dim, out_dim, num_layers, norm="none")

    def get_embedding(self, graph1, graph2, feat, edge_weight):
        # Combine the graph-level views from both encoders; detached for
        # downstream evaluation.
        local_v1, global_v1 = self.encoder1(graph1, feat)
        local_v2, global_v2 = self.encoder2(
            graph2, feat, edge_weight=edge_weight
        )

        global_v1 = self.global_mlp(global_v1)
        global_v2 = self.global_mlp(global_v2)

        return (global_v1 + global_v2).detach()

    def forward(self, graph1, graph2, feat, edge_weight, graph_id):
        # Calculate node embeddings and graph embeddings.
        local_v1, global_v1 = self.encoder1(graph1, feat)
        local_v2, global_v2 = self.encoder2(
            graph2, feat, edge_weight=edge_weight
        )

        local_v1 = self.local_mlp(local_v1)
        local_v2 = self.local_mlp(local_v2)

        global_v1 = self.global_mlp(global_v1)
        global_v2 = self.global_mlp(global_v2)

        # Contrast each view's node embeddings against the other view's
        # graph embedding.
        loss1 = local_global_loss_(local_v1, global_v2, graph_id)
        loss2 = local_global_loss_(local_v2, global_v1, graph_id)

        loss = loss1 + loss2

        return loss
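
# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original module): exercises
# get_embedding on two small random DGL graphs standing in for the original
# and diffusion views. The graph sizes, feature dimension, and random edge
# weights below are illustrative assumptions, not values from MVGRL itself.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import dgl

    num_nodes, in_dim, out_dim, num_layers = 8, 16, 32, 3

    # Self-loops guard against zero-in-degree errors under norm="both".
    graph1 = dgl.add_self_loop(dgl.rand_graph(num_nodes, 24))  # "original" view
    graph2 = dgl.add_self_loop(dgl.rand_graph(num_nodes, 24))  # "diffusion" view
    feat = th.randn(num_nodes, in_dim)
    edge_weight = th.rand(graph2.num_edges())  # stand-in diffusion weights

    model = MVGRL(in_dim, out_dim, num_layers)
    emb = model.get_embedding(graph1, graph2, feat, edge_weight)
    print(emb.shape)  # one graph in the batch -> th.Size([1, out_dim])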