layer.py 2.64 KB
Newer Older
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
1
import dgl
2
3
import torch
import torch.nn.functional as F
4
from dgl.nn import AvgPooling, GraphConv, MaxPooling
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
5
from utils import get_batch_id, topk
6
7
8


class SAGPool(torch.nn.Module):
    """The Self-Attention Pooling layer in paper
    `Self Attention Graph Pooling <https://arxiv.org/pdf/1904.08082.pdf>`

    A graph convolution produces one score per node. ``utils.topk`` then
    selects which nodes of each graph in the batch to keep, given ``ratio``.
    The kept node features are gated by ``non_linearity(score)``.

    Args:
        in_dim (int): The dimension of node feature.
        ratio (float, optional): The pool ratio which determines how many nodes
            remain after pooling. (default: :obj:`0.5`)
        conv_op (torch.nn.Module, optional): The graph convolution layer in dgl
            used to compute a score for each node; it is constructed as
            ``conv_op(in_dim, 1)``. (default: :obj:`dgl.nn.GraphConv`)
        non_linearity (Callable, optional): The non-linearity applied to the
            node scores, a pytorch function. (default: :obj:`torch.tanh`)
    """

    def __init__(
        self,
        in_dim: int,
        ratio=0.5,
        conv_op=GraphConv,
        non_linearity=torch.tanh,
    ):
        super(SAGPool, self).__init__()
        self.in_dim = in_dim
        self.ratio = ratio
        # One output channel: a single scalar score per node.
        self.score_layer = conv_op(in_dim, 1)
        self.non_linearity = non_linearity

    def forward(self, graph: dgl.DGLGraph, feature: torch.Tensor):
        """Score the nodes of a (batched) graph, keep the selected ones, and gate them.

        Args:
            graph (dgl.DGLGraph): The input graph, possibly a batch of graphs.
            feature (torch.Tensor): Node features; presumably of shape
                ``(num_nodes, in_dim)`` -- TODO confirm against callers.

        Returns:
            tuple: ``(graph, feature, perm)`` where
            - ``perm`` holds the indices (in the input graph) of the kept nodes;
            - ``graph`` is the node-induced subgraph on ``perm``, with
              ``batch_num_nodes`` set from ``topk``;
            - ``feature`` holds the kept node features scaled by
              ``non_linearity(score)``.
        """
        # squeeze(-1) instead of squeeze() removes only the single score
        # channel. A batch with exactly one node therefore still yields a 1-D
        # score tensor that topk and score[perm] can index, rather than a
        # 0-d scalar.
        score = self.score_layer(graph, feature).squeeze(-1)

        batch_num_nodes = graph.batch_num_nodes()
        perm, next_batch_num_nodes = topk(
            score,
            self.ratio,
            get_batch_id(batch_num_nodes),
            batch_num_nodes,
        )
        # Gate each kept feature row by its (non-linearly transformed) score.
        feature = feature[perm] * self.non_linearity(score[perm]).view(-1, 1)
        graph = dgl.node_subgraph(graph, perm)

        # node_subgraph currently does not support batch-graph,
        # the 'batch_num_nodes' of the result subgraph is None.
        # So we manually set the 'batch_num_nodes' here.
        # Since global pooling has nothing to do with 'batch_num_edges',
        # we can leave it to be None or unchanged.
        graph.set_batch_num_nodes(next_batch_num_nodes)

        return graph, feature, perm


class ConvPoolBlock(torch.nn.Module):
    """A combination of GCN layer and SAGPool layer,
    followed by a concatenated (mean||max) readout operation.

    Args:
        in_dim (int): The dimension of the input node features.
        out_dim (int): The output dimension of the GraphConv layer, which is
            also the feature dimension seen by the SAGPool layer.
        pool_ratio (float, optional): The pool ratio passed to SAGPool.
            (default: :obj:`0.8`)
    """

    def __init__(self, in_dim: int, out_dim: int, pool_ratio=0.8):
        super(ConvPoolBlock, self).__init__()
        self.conv = GraphConv(in_dim, out_dim)
        self.pool = SAGPool(out_dim, ratio=pool_ratio)
        self.avgpool = AvgPooling()
        self.maxpool = MaxPooling()

    def forward(self, graph, feature):
        """Convolve, pool, and read out a (batched) graph.

        Args:
            graph (dgl.DGLGraph): The input graph, possibly a batch of graphs.
            feature (torch.Tensor): The input node features.

        Returns:
            tuple: ``(graph, out, g_out)`` where
            - ``graph`` is the pooled graph returned by SAGPool;
            - ``out`` holds the pooled node features;
            - ``g_out`` is the per-graph readout: mean-pooled and max-pooled
              features concatenated along the last dimension (presumably of
              shape ``(batch_size, 2 * out_dim)`` -- TODO confirm).
        """
        out = F.relu(self.conv(graph, feature))
        # SAGPool also returns the kept-node indices; they are not needed here.
        graph, out, _ = self.pool(graph, out)
        g_out = torch.cat(
            [self.avgpool(graph, out), self.maxpool(graph, out)], dim=-1
        )
        return graph, out, g_out