import torch
import torch.nn as nn
import torch.nn.functional as F
from modules import MLP, MessageNorm
from ogb.graphproppred.mol_encoder import BondEncoder

import dgl.function as fn
from dgl.nn.functional import edge_softmax


class GENConv(nn.Module):
    r"""
    Description
    -----------
    Generalized Message Aggregator was introduced in
    `DeeperGCN: All You Need to Train Deeper GCNs <https://arxiv.org/abs/2006.07739>`_.
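
    Following the paper, given messages :math:`m_u` from the neighborhood
    :math:`\mathcal{N}(v)` of a node :math:`v`, the two aggregators
    implemented below are

    .. math::
        \mathrm{SoftMax\_Agg}_{\beta}(\cdot) = \sum_{u \in \mathcal{N}(v)}
        \frac{\exp(\beta m_u)}{\sum_{i \in \mathcal{N}(v)} \exp(\beta m_i)}
        \cdot m_u

    .. math::
        \mathrm{PowerMean\_Agg}_{p}(\cdot) = \Big( \frac{1}{|\mathcal{N}(v)|}
        \sum_{u \in \mathcal{N}(v)} m_u^{p} \Big)^{1/p}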

    Parameters
    ----------
    in_dim: int
        Input size.
    out_dim: int
        Output size.
    aggregator: str
        Type of aggregation, either 'softmax' or 'power'. Default is 'softmax'.
    beta: float
        Inverse temperature for softmax aggregation. Default is 1.0.
    learn_beta: bool
        Whether beta is a learnable variable or not. Default is False.
    p: float
        Initial power for power mean aggregation. Default is 1.0.
    learn_p: bool
        Whether p is a learnable variable or not. Default is False.
    msg_norm: bool
        Whether message normalization is used. Default is False.
    learn_msg_scale: bool
        Whether the scaling factor s in message normalization is learnable. Default is False.
    mlp_layers: int
        The number of MLP layers. Default is 1.
    eps: float
        A small positive constant in message construction function. Default is 1e-7.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        aggregator="softmax",
        beta=1.0,
        learn_beta=False,
        p=1.0,
        learn_p=False,
        msg_norm=False,
        learn_msg_scale=False,
        mlp_layers=1,
        eps=1e-7,
    ):
        super().__init__()

        self.aggr = aggregator
        self.eps = eps

        channels = [in_dim]
        for _ in range(mlp_layers - 1):
            channels.append(in_dim * 2)
        channels.append(out_dim)

        self.mlp = MLP(channels)
        self.msg_norm = MessageNorm(learn_msg_scale) if msg_norm else None

        self.beta = (
            nn.Parameter(torch.Tensor([beta]), requires_grad=True)
            if learn_beta and self.aggr == "softmax"
            else beta
        )
        self.p = (
            nn.Parameter(torch.Tensor([p]), requires_grad=True)
            if learn_p
            else p
        )

        self.edge_encoder = BondEncoder(in_dim)

    def forward(self, g, node_feats, edge_feats):
        with g.local_scope():
            # Node and edge feature sizes must match: BondEncoder embeds the
            # raw bond features into in_dim so they can be added to node
            # features in the message function below.
            g.ndata["h"] = node_feats
            g.edata["h"] = self.edge_encoder(edge_feats)
            g.apply_edges(fn.u_add_e("h", "h", "m"))

            if self.aggr == "softmax":
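                # SoftMax aggregation: messages are made positive with
                # ReLU + eps, weighted by edge_softmax(beta * m) over each
                # node's incoming edges, then summed.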
                g.edata["m"] = F.relu(g.edata["m"]) + self.eps
                g.edata["a"] = edge_softmax(g, g.edata["m"] * self.beta)
                g.update_all(
                    lambda edge: {"x": edge.data["m"] * edge.data["a"]},
                    fn.sum("x", "m"),
                )

            elif self.aggr == "power":
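                # Power-mean aggregation: clamp messages for numerical
                # stability, then compute (mean over in-edges of m^p)^(1/p).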
                minv, maxv = 1e-7, 1e1
                torch.clamp_(g.edata["m"], minv, maxv)
                g.update_all(
                    lambda edge: {"x": torch.pow(edge.data["m"], self.p)},
                    fn.mean("x", "m"),
                )
                torch.clamp_(g.ndata["m"], minv, maxv)
                g.ndata["m"] = torch.pow(g.ndata["m"], 1 / self.p)

            else:
                raise NotImplementedError(
                    f"Aggregator {self.aggr} is not supported."
                )

            if self.msg_norm is not None:
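                # MessageNorm (as defined in the paper): scale the
                # unit-normalized aggregated message by s * ||node_feats||_2,
                # where s is learnable when learn_msg_scale=True.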
                g.ndata["m"] = self.msg_norm(node_feats, g.ndata["m"])

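            # Residual connection followed by the update MLP.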
            feats = node_feats + g.ndata["m"]

            return self.mlp(feats)
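

if __name__ == "__main__":
    # Minimal smoke test; a sketch, not part of the original module. It
    # assumes DGL is installed and fabricates OGB-style integer bond
    # features of shape (num_edges, 3), the layout BondEncoder expects for
    # the ogbg-mol* datasets. All sizes and values here are illustrative.
    import dgl

    g = dgl.graph(([0, 1, 2], [1, 2, 0]))  # 3-node directed cycle
    node_feats = torch.randn(g.num_nodes(), 16)
    edge_feats = torch.zeros(g.num_edges(), 3, dtype=torch.long)

    conv = GENConv(in_dim=16, out_dim=16, aggregator="softmax")
    out = conv(g, node_feats, edge_feats)
    print(out.shape)  # expected: torch.Size([3, 16])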