# interaction_pp_block.py
import torch.nn as nn
from modules.initializers import GlorotOrthogonal
from modules.residual_layer import ResidualLayer

import dgl
import dgl.function as fn


class InteractionPPBlock(nn.Module):
    """DimeNet++ interaction block operating on a graph and its line graph.

    Messages live on the edges of ``g`` (``g.edata["m"]``). Each message is
    refined with information aggregated over the line graph ``l_g``, whose
    nodes correspond to the edges of ``g``.

    Args:
        emb_size: Feature size of the edge messages ``g.edata["m"]``.
        int_emb_size: Size of the down-projected embedding used for the
            triplet interaction.
        basis_emb_size: Bottleneck size used when embedding the Bessel
            (``rbf``) and spherical (``sbf``) basis features.
        num_radial: Feature size of ``g.edata["rbf"]``.
        num_spherical: Together with ``num_radial`` defines the feature size
            of ``l_g.edata["sbf"]`` (``num_radial * num_spherical``).
        num_before_skip: Number of residual layers before the skip connection.
        num_after_skip: Number of residual layers after the skip connection.
        activation: Optional callable applied at the block's activation
            points (and passed to every ``ResidualLayer``). ``None`` means
            no activation is applied at those points.
    """

    def __init__(
        self,
        emb_size,
        int_emb_size,
        basis_emb_size,
        num_radial,
        num_spherical,
        num_before_skip,
        num_after_skip,
        activation=None,
    ):
        super().__init__()

        self.activation = activation
        # Transformations of Bessel and spherical basis representations
        self.dense_rbf1 = nn.Linear(num_radial, basis_emb_size, bias=False)
        self.dense_rbf2 = nn.Linear(basis_emb_size, emb_size, bias=False)
        self.dense_sbf1 = nn.Linear(
            num_radial * num_spherical, basis_emb_size, bias=False
        )
        self.dense_sbf2 = nn.Linear(basis_emb_size, int_emb_size, bias=False)
        # Dense transformations of input messages
        self.dense_ji = nn.Linear(emb_size, emb_size)
        self.dense_kj = nn.Linear(emb_size, emb_size)
        # Embedding projections for interaction triplets
        self.down_projection = nn.Linear(emb_size, int_emb_size, bias=False)
        self.up_projection = nn.Linear(int_emb_size, emb_size, bias=False)
        # Residual layers before skip connection
        self.layers_before_skip = nn.ModuleList(
            [
                ResidualLayer(emb_size, activation=activation)
                for _ in range(num_before_skip)
            ]
        )
        self.final_before_skip = nn.Linear(emb_size, emb_size)
        # Residual layers after skip connection
        self.layers_after_skip = nn.ModuleList(
            [
                ResidualLayer(emb_size, activation=activation)
                for _ in range(num_after_skip)
            ]
        )

        self.reset_params()

    def _act(self, x):
        """Return ``self.activation(x)``, or ``x`` unchanged when no activation is set."""
        return x if self.activation is None else self.activation(x)

    def reset_params(self):
        """(Re-)initialise the dense layers owned directly by this block.

        Weights are initialised with ``GlorotOrthogonal`` and biases (where a
        layer has one) are zeroed. The residual layers are not touched here.
        """
        # Order matters for reproducibility: the first ten layers keep the
        # original initialisation order; final_before_skip was previously left
        # at PyTorch's default init and is now initialised like its siblings.
        for layer in (
            self.dense_rbf1,
            self.dense_rbf2,
            self.dense_sbf1,
            self.dense_sbf2,
            self.dense_ji,
            self.dense_kj,
            self.down_projection,
            self.up_projection,
            self.final_before_skip,
        ):
            GlorotOrthogonal(layer.weight)
            # nn.Linear(..., bias=False) stores ``bias`` as None.
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def edge_transfer(self, edges):
        """Compute per-edge inputs for the triplet interaction.

        Reads ``edges.data["rbf"]`` and ``edges.data["m"]``.

        Returns:
            dict with ``"x_kj"`` (rbf-gated, down-projected message) and
            ``"x_ji"`` (transformed message), stored as new edge features.
        """
        # Transform from Bessel basis to dense vector
        rbf = self.dense_rbf2(self.dense_rbf1(edges.data["rbf"]))
        # Initial transformation of the incoming messages
        m = edges.data["m"]
        x_ji = self._act(self.dense_ji(m))
        x_kj = self._act(self.dense_kj(m))
        # Gate by the radial basis embedding, then project to int_emb_size
        x_kj = self._act(self.down_projection(x_kj * rbf))
        return {"x_kj": x_kj, "x_ji": x_ji}

    def msg_func(self, edges):
        """Message function on the reversed line graph.

        Multiplies the source node's ``x_kj`` by the embedded spherical
        basis features ``edges.data["sbf"]``.
        """
        sbf = self.dense_sbf2(self.dense_sbf1(edges.data["sbf"]))
        return {"x_kj": edges.src["x_kj"] * sbf}

    def forward(self, g, l_g):
        """Run one interaction step, updating ``g.edata["m"]``.

        Args:
            g: graph whose edges carry ``"m"`` (messages) and ``"rbf"``.
            l_g: line graph of ``g`` -- its nodes correspond to the edges of
                ``g``; its edges must carry ``"sbf"``.

        Returns:
            ``g`` itself. ``g.edata`` gains ``"x_kj"``, ``"x_ji"`` and
            ``"m_update"``, and ``"m"`` is replaced by the updated messages.
            As a side effect every ``g.edata`` entry is copied into
            ``l_g.ndata``.
        """
        g.apply_edges(self.edge_transfer)

        # Nodes of l_g correspond to edges of g, so every edge feature of g
        # becomes a node feature of l_g (e.g. d, rbf, o, rbf_env, x_kj, x_ji).
        for k, v in g.edata.items():
            l_g.ndata[k] = v

        # Aggregate triplet messages along the reversed line-graph edges.
        l_g_reverse = dgl.reverse(l_g, copy_edata=True)
        l_g_reverse.update_all(self.msg_func, fn.sum("x_kj", "m_update"))

        m_update = self._act(self.up_projection(l_g_reverse.ndata["m_update"]))
        # Transformations before skip connection
        m_update = m_update + g.edata["x_ji"]
        for layer in self.layers_before_skip:
            m_update = layer(m_update)
        m_update = self._act(self.final_before_skip(m_update))
        g.edata["m_update"] = m_update

        # Skip connection
        m = g.edata["m"] + m_update
        # Transformations after skip connection
        for layer in self.layers_after_skip:
            m = layer(m)
        g.edata["m"] = m

        return g