from typing import Optional

import dgl
import torch
import torch.nn
from dgl import DGLGraph
from dgl.nn import GraphConv
from torch import Tensor


class GraphConvWithDropout(GraphConv):
    """
    A GraphConv that applies dropout to the input features before
    the convolution.
    """

    def __init__(
        self,
        in_feats,
        out_feats,
        dropout=0.3,
        norm="both",
        weight=True,
        bias=True,
        activation=None,
        allow_zero_in_degree=False,
    ):
        super().__init__(
            in_feats,
            out_feats,
            norm,
            weight,
            bias,
            activation,
            allow_zero_in_degree,
        )
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, graph, feat, weight=None):
        # Apply dropout to the input features, then run the standard GraphConv.
        feat = self.dropout(feat)
        return super().forward(graph, feat, weight)


class Discriminator(torch.nn.Module):
    """
    Description
    -----------
    A discriminator used to let the network discriminate between
    positive (neighborhood of the center node) and negative (any
    neighborhood in the graph) samples.

    Parameters
    ----------
    feat_dim : int
        The number of channels of node features.
    """

    def __init__(self, feat_dim: int):
        super().__init__()
        self.affine = torch.nn.Bilinear(feat_dim, feat_dim, 1)
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.affine.weight)
        torch.nn.init.zeros_(self.affine.bias)

    def forward(
        self,
        h_x: Tensor,
        h_pos: Tensor,
        h_neg: Tensor,
        bias_pos: Optional[Tensor] = None,
        bias_neg: Optional[Tensor] = None,
    ):
        """
        Parameters
        ----------
        h_x : torch.Tensor
            Node features, shape: :obj:`(num_nodes, feat_dim)`
        h_pos : torch.Tensor
            The node features of positive samples.
            It has the same shape as :obj:`h_x`.
        h_neg : torch.Tensor
            The node features of negative samples.
            It has the same shape as :obj:`h_x`.
        bias_pos : torch.Tensor, optional
            Bias parameter vector for positive scores,
            shape: :obj:`(num_nodes,)`
        bias_neg : torch.Tensor, optional
            Bias parameter vector for negative scores,
            shape: :obj:`(num_nodes,)`

        Returns
        -------
        (torch.Tensor, torch.Tensor)
            The concatenated positive and negative scores with shape
            :obj:`(2 * num_nodes,)`, and the positive scores with shape
            :obj:`(num_nodes,)`.
        """
        score_pos = self.affine(h_pos, h_x).squeeze()
        score_neg = self.affine(h_neg, h_x).squeeze()
        if bias_pos is not None:
            score_pos = score_pos + bias_pos
        if bias_neg is not None:
            score_neg = score_neg + bias_neg
        logits = torch.cat((score_pos, score_neg), 0)
        return logits, score_pos
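
# The sketch below is illustrative and not part of the original module: a
# minimal smoke test of ``Discriminator`` on random features, checking the
# shapes documented in ``forward``. The function name ``_demo_discriminator``
# is hypothetical.
def _demo_discriminator() -> None:
    num_nodes, feat_dim = 5, 16
    disc = Discriminator(feat_dim)
    h_x = torch.randn(num_nodes, feat_dim)    # center-node features
    h_pos = torch.randn(num_nodes, feat_dim)  # positive-sample features
    h_neg = torch.randn(num_nodes, feat_dim)  # negative-sample features
    logits, score_pos = disc(h_x, h_pos, h_neg)
    assert logits.shape == (2 * num_nodes,)
    assert score_pos.shape == (num_nodes,)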
class DenseLayer(torch.nn.Module):
    """
    Description
    -----------
    Dense layer with a linear layer and an activation function.
    """

    def __init__(
        self, in_dim: int, out_dim: int, act: str = "prelu", bias=True
    ):
        super().__init__()
        self.lin = torch.nn.Linear(in_dim, out_dim, bias=bias)
        self.act_type = act.lower()
        if self.act_type == "prelu":
            self.act = torch.nn.PReLU()
        else:
            self.act = torch.relu
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.lin.weight)
        if self.lin.bias is not None:
            torch.nn.init.zeros_(self.lin.bias)

    def forward(self, x):
        x = self.lin(x)
        return self.act(x)


class IndexSelect(torch.nn.Module):
    """
    Description
    -----------
    The index selection layer used by VIPool.

    Parameters
    ----------
    pool_ratio : float
        The pooling ratio (for keeping nodes). For example, if
        ``pool_ratio=0.8``, 80% of the nodes will be preserved.
    hidden_dim : int
        The number of channels in node features.
    act : str, optional
        The activation function type. Default: :obj:`'prelu'`
    dist : int, optional
        DO NOT USE THIS PARAMETER
    """

    def __init__(
        self,
        pool_ratio: float,
        hidden_dim: int,
        act: str = "prelu",
        dist: int = 1,
    ):
        super().__init__()
        self.pool_ratio = pool_ratio
        self.dist = dist
        self.dense = DenseLayer(hidden_dim, hidden_dim, act)
        self.discriminator = Discriminator(hidden_dim)
        self.gcn = GraphConvWithDropout(hidden_dim, hidden_dim)

    def forward(
        self,
        graph: DGLGraph,
        h_pos: Tensor,
        h_neg: Tensor,
        bias_pos: Optional[Tensor] = None,
        bias_neg: Optional[Tensor] = None,
    ):
        """
        Description
        -----------
        Perform index selection.

        Parameters
        ----------
        graph : dgl.DGLGraph
            Input graph.
        h_pos : torch.Tensor
            The node features of positive samples,
            shape: :obj:`(num_nodes, hidden_dim)`
        h_neg : torch.Tensor
            The node features of negative samples.
            It has the same shape as :obj:`h_pos`.
        bias_pos : torch.Tensor, optional
            Bias parameter vector for positive scores,
            shape: :obj:`(num_nodes,)`
        bias_neg : torch.Tensor, optional
            Bias parameter vector for negative scores,
            shape: :obj:`(num_nodes,)`
        """
        # compute scores
        h_pos = self.dense(h_pos)
        h_neg = self.dense(h_neg)
        embed = self.gcn(graph, h_pos)
        h_center = torch.sigmoid(embed)
        logit, logit_pos = self.discriminator(
            h_center, h_pos, h_neg, bias_pos, bias_neg
        )
        scores = torch.sigmoid(logit_pos)

        # sort scores
        scores, idx = torch.sort(scores, descending=True)

        # select top-k
        num_nodes = graph.num_nodes()
        num_select_nodes = int(self.pool_ratio * num_nodes)
        size_list = [num_select_nodes, num_nodes - num_select_nodes]
        select_scores, _ = torch.split(scores, size_list, dim=0)
        select_idx, non_select_idx = torch.split(idx, size_list, dim=0)

        return logit, select_scores, select_idx, non_select_idx, embed


class GraphPool(torch.nn.Module):
    """
    Description
    -----------
    The pooling module for graphs.

    Parameters
    ----------
    hidden_dim : int
        The number of channels of node features.
    use_gcn : bool, optional
        Whether to use a GCN in the down-sampling process.
        Default: :obj:`False`
    """

    def __init__(self, hidden_dim: int, use_gcn=False):
        super().__init__()
        self.use_gcn = use_gcn
        self.down_sample_gcn = (
            GraphConvWithDropout(hidden_dim, hidden_dim) if use_gcn else None
        )

    def forward(
        self,
        graph: DGLGraph,
        feat: Tensor,
        select_idx: Tensor,
        non_select_idx: Optional[Tensor] = None,
        scores: Optional[Tensor] = None,
        pool_graph=False,
    ):
        """
        Description
        -----------
        Perform graph pooling.

        Parameters
        ----------
        graph : dgl.DGLGraph
            The input graph.
        feat : torch.Tensor
            The input node features.
        select_idx : torch.Tensor
            The indices, in the fine graph, of the nodes kept in the
            coarse graph; obtained from a previous index-selection layer.
        non_select_idx : torch.Tensor, optional
            The indices of the nodes not included in the output graph.
            Default: :obj:`None`
        scores : torch.Tensor, optional
            Node scores used for pooling and scaling.
            Default: :obj:`None`
        pool_graph : bool, optional
            Whether to also pool the graph topology.
            Default: :obj:`False`
        """
        if self.use_gcn:
            feat = self.down_sample_gcn(graph, feat)

        feat = feat[select_idx]
        if scores is not None:
            feat = feat * scores.unsqueeze(-1)

        if pool_graph:
            num_node_batch = graph.batch_num_nodes()
            graph = dgl.node_subgraph(graph, select_idx)
            graph.set_batch_num_nodes(num_node_batch)
            return feat, graph
        else:
            return feat
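
# The sketch below is illustrative and not part of the original module: it
# feeds random features through ``IndexSelect`` and pools the kept nodes with
# ``GraphPool``. The function name is hypothetical, and the shuffled features
# merely stand in for the sampled negative neighborhoods described in the
# ``Discriminator`` docstring.
def _demo_index_select_and_pool() -> None:
    g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))  # toy 4-node cycle
    h_pos = torch.randn(4, 8)
    h_neg = h_pos[torch.randperm(4)]  # stand-in negatives: shuffled features
    select = IndexSelect(pool_ratio=0.5, hidden_dim=8)
    logit, scores, keep_idx, drop_idx, embed = select(g, h_pos, h_neg)
    assert logit.shape == (8,)     # (2 * num_nodes,)
    assert keep_idx.numel() == 2   # 50% of the 4 nodes are kept

    # Pool the embeddings down to the selected nodes, scaled by their scores.
    pool = GraphPool(hidden_dim=8)
    pooled = pool(g, embed, keep_idx, scores=scores)
    assert pooled.shape == (2, 8)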
""" def __init__(self, hidden_dim: int): super(GraphUnpool, self).__init__() self.up_sample_gcn = GraphConvWithDropout(hidden_dim, hidden_dim) def forward(self, graph: DGLGraph, feat: Tensor, select_idx: Tensor): """ Description ----------- Perform graph unpooling Parameters ---------- graph : dgl.DGLGraph The input graph feat : torch.Tensor The input node feature select_idx : torch.Tensor The index in fine graph of node from coarse graph, this is obtained from previous graph pooling layers. """ fine_feat = torch.zeros( (graph.num_nodes(), feat.size(-1)), device=feat.device ) fine_feat[select_idx] = feat fine_feat = self.up_sample_gcn(graph, fine_feat) return fine_feat