# interaction_block.py
# Originally committed by Hongzhi (Steve), Chen.
1
import dgl.function as fn
2
3
import torch
import torch.nn as nn
4
5
6
from modules.initializers import GlorotOrthogonal
from modules.residual_layer import ResidualLayer

7
8

class InteractionBlock(nn.Module):
    """Interaction block that updates edge messages using a graph and its line graph.

    Each call to :meth:`forward` updates the edge messages ``g.edata["m"]``:

    1. Every edge of ``g`` projects its message ``"m"`` into ``x_ji`` and
       ``x_kj``; ``x_kj`` is gated by a projection of the radial basis
       features ``"rbf"``.
    2. On the line graph ``l_g`` (one node per edge of ``g``), every node
       receives bilinear messages built from its neighbours' ``x_kj`` and the
       spherical basis features ``"sbf"`` on the connecting ``l_g`` edge.
       These are summed into ``"m_update"``.
    3. ``m_update + x_ji`` passes through residual layers and a dense layer
       and is added to ``"m"`` (skip connection). Further residual layers
       follow.

    Args:
        emb_size: Width of the edge message embeddings ``"m"``.
        num_radial: Width of the radial basis features ``"rbf"``.
        num_spherical: Number of spherical basis functions. The ``"sbf"``
            feature width is ``num_radial * num_spherical``.
        num_bilinear: Inner dimension of the bilinear interaction tensor.
        num_before_skip: Number of ``ResidualLayer`` s before the skip
            connection.
        num_after_skip: Number of ``ResidualLayer`` s after the skip connection.
        activation: Optional activation callable. ``None`` disables the
            activations applied in this block (it is also forwarded to the
            residual layers).
    """

    def __init__(
        self,
        emb_size,
        num_radial,
        num_spherical,
        num_bilinear,
        num_before_skip,
        num_after_skip,
        activation=None,
    ):
        super().__init__()

        self.activation = activation
        # Transformations of Bessel and spherical basis representations
        self.dense_rbf = nn.Linear(num_radial, emb_size, bias=False)
        self.dense_sbf = nn.Linear(
            num_radial * num_spherical, num_bilinear, bias=False
        )
        # Dense transformations of input messages
        self.dense_ji = nn.Linear(emb_size, emb_size)
        self.dense_kj = nn.Linear(emb_size, emb_size)
        # Bilinear layer: W[i, j, l] couples output dim i, basis dim j and
        # input dim l. Entries are drawn from N(0, (2 / emb_size) ** 2).
        bilin_initializer = torch.empty(
            (emb_size, num_bilinear, emb_size)
        ).normal_(mean=0, std=2 / emb_size)
        self.W_bilin = nn.Parameter(bilin_initializer)
        # Residual layers before skip connection
        self.layers_before_skip = nn.ModuleList(
            [
                ResidualLayer(emb_size, activation=activation)
                for _ in range(num_before_skip)
            ]
        )
        self.final_before_skip = nn.Linear(emb_size, emb_size)
        # Residual layers after skip connection
        self.layers_after_skip = nn.ModuleList(
            [
                ResidualLayer(emb_size, activation=activation)
                for _ in range(num_after_skip)
            ]
        )

        self.reset_params()

    def reset_params(self):
        """Apply ``GlorotOrthogonal`` to the weights of the dense layers.

        Biases, ``W_bilin`` and the residual layers are not touched here.
        ``W_bilin`` keeps the normal init from ``__init__``.
        """
        GlorotOrthogonal(self.dense_rbf.weight)
        GlorotOrthogonal(self.dense_sbf.weight)
        GlorotOrthogonal(self.dense_ji.weight)
        GlorotOrthogonal(self.dense_kj.weight)
        GlorotOrthogonal(self.final_before_skip.weight)

    def edge_transfer(self, edges):
        """Per-edge UDF for ``g.apply_edges``: compute ``x_kj`` and ``x_ji``.

        Reads ``edges.data["rbf"]`` (width ``num_radial``) and
        ``edges.data["m"]`` (width ``emb_size``).

        Returns:
            dict with ``"x_kj"`` (activated ``dense_kj(m)`` multiplied
            element-wise by ``dense_rbf(rbf)``) and ``"x_ji"`` (activated
            ``dense_ji(m)``).
        """
        # Transform from Bessel basis to dense vector
        rbf = self.dense_rbf(edges.data["rbf"])
        # Initial transformation
        x_ji = self.dense_ji(edges.data["m"])
        x_kj = self.dense_kj(edges.data["m"])
        if self.activation is not None:
            x_ji = self.activation(x_ji)
            x_kj = self.activation(x_kj)

        # w: W * e_RBF \bigodot \sigma(W * m + b)
        return {"x_kj": x_kj * rbf, "x_ji": x_ji}

    def msg_func(self, edges):
        """Message UDF on the line graph: bilinear mix of ``sbf`` and ``x_kj``.

        Reads the ``l_g`` edge feature ``"sbf"`` and the source node's
        ``"x_kj"``. Returns ``{"x_kj": message}`` with one ``emb_size``
        vector per ``l_g`` edge.
        """
        sbf = self.dense_sbf(edges.data["sbf"])
        # Apply bilinear layer to interactions and basis function activation:
        # sbf [E, num_bilinear] x x_kj [E, emb_size]
        #   x W_bilin [emb_size, num_bilinear, emb_size] -> [E, emb_size]
        # where E is the number of line-graph edges in this batch.
        x_kj = torch.einsum(
            "wj,wl,ijl->wi", sbf, edges.src["x_kj"], self.W_bilin
        )
        return {"x_kj": x_kj}

    def forward(self, g, l_g):
        """Run one interaction step and return ``g`` with updated ``"m"``.

        Args:
            g: DGL graph whose edges carry ``"m"`` and ``"rbf"``.
            l_g: Line graph of ``g``.
                - Each node of ``l_g`` corresponds to an edge of ``g``.
                - Its edges carry ``"sbf"``.
                - NOTE(review): the feature copies below assume node ``i``
                  of ``l_g`` is edge ``i`` of ``g`` (same count and order);
                  confirm with the caller.

        Returns:
            ``g``, mutated in place.
                - ``"m"`` holds the new messages.
                - ``"x_kj"``, ``"x_ji"`` and ``"m_update"`` remain in
                  ``g.edata``.
                - ``l_g.ndata`` keeps the edge features copied from ``g``
                  before the skip-connection update.
        """
        g.apply_edges(self.edge_transfer)

        # nodes correspond to edges and edges correspond to nodes in the original graphs
        # node: d, rbf, o, rbf_env, x_kj, x_ji
        # All edge features are copied, not only x_kj, so l_g.ndata keeps
        # g's full edge state.
        for k, v in g.edata.items():
            l_g.ndata[k] = v

        # Sum the bilinear messages arriving at each l_g node into m_update.
        l_g.update_all(self.msg_func, fn.sum("x_kj", "m_update"))

        # Bring m_update (and the unchanged copies) back onto g's edges.
        for k, v in l_g.ndata.items():
            g.edata[k] = v

        # Transformations before skip connection
        g.edata["m_update"] = g.edata["m_update"] + g.edata["x_ji"]
        for layer in self.layers_before_skip:
            g.edata["m_update"] = layer(g.edata["m_update"])
        g.edata["m_update"] = self.final_before_skip(g.edata["m_update"])
        if self.activation is not None:
            g.edata["m_update"] = self.activation(g.edata["m_update"])

        # Skip connection
        g.edata["m"] = g.edata["m"] + g.edata["m_update"]

        # Transformations after skip connection
        for layer in self.layers_after_skip:
            g.edata["m"] = layer(g.edata["m"])

        return g