# output_pp_block.py
import dgl
import dgl.function as fn
# Last committed by: Hongzhi (Steve), Chen
import torch.nn as nn
from modules.initializers import GlorotOrthogonal


class OutputPPBlock(nn.Module):
    """Readout block mapping edge messages to per-graph target values.

    The block works in five steps:

    1. Each edge message ``g.edata["m"]`` is gated by a bias-free linear
       projection of the edge's radial-basis features ``g.edata["rbf"]``.
    2. The gated messages are summed onto the source node of each edge.
    3. The per-node sums are up-projected to ``out_emb_size``.
    4. The result passes through ``num_dense`` dense layers, each optionally
       followed by ``activation``, then a final bias-free linear layer
       producing ``num_targets`` values.
    5. Node values are pooled per graph (sum or mean).

    Args:
        emb_size: Output width of the radial-basis projection. It is
            multiplied element-wise with ``g.edata["m"]``, so the message
            features must match (or broadcast with) this width.
        out_emb_size: Width of the up-projection and of every hidden dense
            layer.
        num_radial: Feature size of ``g.edata["rbf"]``.
        num_dense: Number of hidden ``nn.Linear`` layers (with bias) applied
            after the up-projection.
        num_targets: Number of output values per node / per graph.
        activation: Optional callable applied after each hidden dense layer.
            ``None`` means no non-linearity.
        output_init: In-place initializer applied to the final layer's
            weight. Defaults to ``nn.init.zeros_``.
        extensive: If ``True``, node outputs are summed per graph; otherwise
            they are averaged.
    """

    def __init__(
        self,
        emb_size,
        out_emb_size,
        num_radial,
        num_dense,
        num_targets,
        activation=None,
        output_init=nn.init.zeros_,
        extensive=True,
    ):
        super().__init__()

        self.activation = activation
        self.output_init = output_init
        self.extensive = extensive
        self.dense_rbf = nn.Linear(num_radial, emb_size, bias=False)
        self.up_projection = nn.Linear(emb_size, out_emb_size, bias=False)
        self.dense_layers = nn.ModuleList(
            [nn.Linear(out_emb_size, out_emb_size) for _ in range(num_dense)]
        )
        self.dense_final = nn.Linear(out_emb_size, num_targets, bias=False)
        self.reset_params()

    def reset_params(self):
        """(Re-)initialize the weight matrices of every linear layer.

        ``GlorotOrthogonal`` is applied to the radial projection, the
        up-projection and each hidden dense layer. ``self.output_init`` is
        applied to the final layer. Biases of the hidden dense layers are not
        touched here and keep whatever values they already have. The
        initializers are assumed to work in place (TODO confirm for
        ``GlorotOrthogonal``).
        """
        GlorotOrthogonal(self.dense_rbf.weight)
        GlorotOrthogonal(self.up_projection.weight)
        for layer in self.dense_layers:
            GlorotOrthogonal(layer.weight)
        self.output_init(self.dense_final.weight)

    def forward(self, g):
        """Compute per-graph outputs from the edge data of ``g``.

        Args:
            g: A DGL graph (possibly batched) whose ``edata`` holds ``"m"``
                (edge messages) and ``"rbf"`` (radial-basis features).

        Returns:
            The result of ``dgl.readout_nodes`` over the per-node outputs,
            with one row per graph in ``g``. Node values are summed when
            ``self.extensive`` is true and averaged otherwise.
        """
        # local_scope keeps the temporary "tmp"/"t" features from leaking
        # into the caller's graph.
        with g.local_scope():
            g.edata["tmp"] = g.edata["m"] * self.dense_rbf(g.edata["rbf"])
            # Aggregating over the reversed graph sums every edge's value onto
            # its *source* node in g (the destination in g_reverse).
            g_reverse = dgl.reverse(g, copy_edata=True)
            g_reverse.update_all(fn.copy_e("tmp", "x"), fn.sum("x", "t"))

            # Keep the node features in a local tensor through the dense stack
            # and write them back to the graph once, right before readout.
            h = self.up_projection(g_reverse.ndata["t"])
            for layer in self.dense_layers:
                h = layer(h)
                if self.activation is not None:
                    h = self.activation(h)
            g.ndata["t"] = self.dense_final(h)
            return dgl.readout_nodes(
                g, "t", op="sum" if self.extensive else "mean"
            )