layers.py
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.functional import edge_softmax
from modules import MessageNorm, MLP
from ogb.graphproppred.mol_encoder import BondEncoder


class GENConv(nn.Module):
    r"""
    Description
    -----------
    Generalized message aggregator introduced in `DeeperGCN: All You Need to Train Deeper GCNs <https://arxiv.org/abs/2006.07739>`_.

    Parameters
    ----------
    in_dim: int
        Input size.
    out_dim: int
        Output size.
    aggregator: str
        Type of aggregation to use, either 'softmax' or 'power'. Default is 'softmax'.
    beta: float
        Inverse temperature for the softmax aggregation. Default is 1.0.
    learn_beta: bool
        Whether beta is a learnable variable or not. Default is False.
    p: float
        Initial power for power mean aggregation. Default is 1.0.
    learn_p: bool
        Whether p is a learnable variable or not. Default is False.
    msg_norm: bool
        Whether message normalization is used. Default is False.
    learn_msg_scale: bool
        Whether the scaling factor s in message normalization is learnable. Default is False.
    mlp_layers: int
        The number of MLP layers. Default is 1.
    eps: float
        A small positive constant added in the message construction function to keep messages strictly positive. Default is 1e-7.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        aggregator="softmax",
        beta=1.0,
        learn_beta=False,
        p=1.0,
        learn_p=False,
        msg_norm=False,
        learn_msg_scale=False,
        mlp_layers=1,
        eps=1e-7,
    ):
        super(GENConv, self).__init__()

        self.aggr = aggregator
        self.eps = eps

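        # MLP channels: in_dim -> (mlp_layers - 1) hidden layers of width
        # 2 * in_dim -> out_dim.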
        channels = [in_dim]
        for _ in range(mlp_layers - 1):
            channels.append(in_dim * 2)
        channels.append(out_dim)

        self.mlp = MLP(channels)
        self.msg_norm = MessageNorm(learn_msg_scale) if msg_norm else None

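        # beta is the (optionally learnable) inverse temperature of the
        # softmax aggregator; p is the exponent of the power-mean aggregator.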
        self.beta = (
            nn.Parameter(torch.Tensor([beta]), requires_grad=True)
            if learn_beta and self.aggr == "softmax"
            else beta
        )
        self.p = (
            nn.Parameter(torch.Tensor([p]), requires_grad=True)
            if learn_p
            else p
        )

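        # BondEncoder embeds OGB categorical bond features into in_dim channels.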
        self.edge_encoder = BondEncoder(in_dim)

    def forward(self, g, node_feats, edge_feats):
        with g.local_scope():
            # Node and edge feature sizes need to match.
            g.ndata["h"] = node_feats
            g.edata["h"] = self.edge_encoder(edge_feats)
            g.apply_edges(fn.u_add_e("h", "h", "m"))

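            # SoftMax aggregation from the DeeperGCN paper:
            #   h_v = sum_{u in N(v)} softmax_u(beta * m_u) * m_u
            # with messages m_u = ReLU(h_u + h_e) + eps kept strictly positive.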
            if self.aggr == "softmax":
                g.edata["m"] = F.relu(g.edata["m"]) + self.eps
                g.edata["a"] = edge_softmax(g, g.edata["m"] * self.beta)
                g.update_all(
                    lambda edge: {"x": edge.data["m"] * edge.data["a"]},
                    fn.sum("x", "m"),
                )

            elif self.aggr == "power":
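                # PowerMean aggregation from the DeeperGCN paper:
                #   h_v = (mean_{u in N(v)} m_u^p)^(1/p)
                # Messages are clamped to [minv, maxv] before and after the
                # mean to keep the fractional powers numerically stable.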
                minv, maxv = 1e-7, 1e1
                torch.clamp_(g.edata["m"], minv, maxv)
                g.update_all(
                    lambda edge: {"x": torch.pow(edge.data["m"], self.p)},
                    fn.mean("x", "m"),
                )
                torch.clamp_(g.ndata["m"], minv, maxv)
                # Invert the exponent to complete the power mean:
                # (mean(m^p))^(1/p), matching the formula above.
                g.ndata["m"] = torch.pow(g.ndata["m"], 1 / self.p)

            else:
                raise NotImplementedError(
                    f"Aggregator {self.aggr} is not supported."
                )

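            # MessageNorm (see modules.py) rescales the aggregated message
            # relative to the norm of the input node features before the
            # residual connection.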
            if self.msg_norm is not None:
                g.ndata["m"] = self.msg_norm(node_feats, g.ndata["m"])

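            # Residual connection, then the MLP update.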
            feats = node_feats + g.ndata["m"]

            return self.mlp(feats)
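

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): run GENConv once on a tiny
# graph. The sizes, random node features, and all-zero 3-column integer edge
# features (the categorical format ogb's BondEncoder expects) are assumptions
# chosen purely for illustration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import dgl

    g = dgl.graph(([0, 1, 2], [1, 2, 0]))  # 3 nodes in a directed cycle
    node_feats = torch.randn(3, 16)  # assumed in_dim = 16
    edge_feats = torch.zeros(3, 3, dtype=torch.long)  # OGB-style bond features

    conv = GENConv(in_dim=16, out_dim=16, aggregator="softmax")
    out = conv(g, node_feats, edge_feats)
    print(out.shape)  # expected: torch.Size([3, 16])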