# dimenet_pp.py
import torch
import torch.nn as nn
from modules.activations import swish
from modules.bessel_basis_layer import BesselBasisLayer
from modules.embedding_block import EmbeddingBlock
from modules.interaction_pp_block import InteractionPPBlock
from modules.output_pp_block import OutputPPBlock
from modules.spherical_basis_layer import SphericalBasisLayer


class DimeNetPP(nn.Module):
    """
    DimeNet++ model.

    Parameters
    ----------
    emb_size
        Embedding size used for the messages
    out_emb_size
        Embedding size used for atoms in the output block
    int_emb_size
        Embedding size used for interaction triplets
    basis_emb_size
        Embedding size used inside the basis transformation
    num_blocks
        Number of building blocks to be stacked
    num_spherical
        Number of spherical harmonics
    num_radial
        Number of radial basis functions
    cutoff
        Cutoff distance for interatomic interactions
    envelope_exponent
        Shape of the smooth cutoff
    num_before_skip
        Number of residual layers in interaction block before skip connection
    num_after_skip
        Number of residual layers in interaction block after skip connection
    num_dense_output
        Number of dense layers for the output blocks
    num_targets
        Number of targets to predict
    activation
        Activation function
    extensive
        Whether the output should be extensive (proportional to the number of atoms)
    output_init
        Weight initialization function for the final output layer
    """

    def __init__(
        self,
        emb_size,
        out_emb_size,
        int_emb_size,
        basis_emb_size,
        num_blocks,
        num_spherical,
        num_radial,
        cutoff=5.0,
        envelope_exponent=5,
        num_before_skip=1,
        num_after_skip=2,
        num_dense_output=3,
        num_targets=12,
        activation=swish,
        extensive=True,
        output_init=nn.init.zeros_,
    ):
        super(DimeNetPP, self).__init__()

        self.num_blocks = num_blocks
        self.num_radial = num_radial

        # radial basis function (Bessel) expansion layer for interatomic distances
        self.rbf_layer = BesselBasisLayer(
            num_radial=num_radial,
            cutoff=cutoff,
            envelope_exponent=envelope_exponent,
        )

        self.sbf_layer = SphericalBasisLayer(
            num_spherical=num_spherical,
            num_radial=num_radial,
            cutoff=cutoff,
            envelope_exponent=envelope_exponent,
        )

        # embedding block
        self.emb_block = EmbeddingBlock(
            emb_size=emb_size,
            num_radial=num_radial,
            bessel_funcs=self.sbf_layer.get_bessel_funcs(),
            cutoff=cutoff,
            envelope_exponent=envelope_exponent,
            activation=activation,
        )

        # output blocks: one after the embedding stage plus one per interaction block
        self.output_blocks = nn.ModuleList(
            [
                OutputPPBlock(
                    emb_size=emb_size,
                    out_emb_size=out_emb_size,
                    num_radial=num_radial,
                    num_dense=num_dense_output,
                    num_targets=num_targets,
                    activation=activation,
                    extensive=extensive,
                    output_init=output_init,
                )
                for _ in range(num_blocks + 1)
            ]
        )

        # interaction block
        self.interaction_blocks = nn.ModuleList(
            [
                InteractionPPBlock(
                    emb_size=emb_size,
                    int_emb_size=int_emb_size,
                    basis_emb_size=basis_emb_size,
                    num_radial=num_radial,
                    num_spherical=num_spherical,
                    num_before_skip=num_before_skip,
                    num_after_skip=num_after_skip,
                    activation=activation,
                )
                for _ in range(num_blocks)
            ]
        )

    def edge_init(self, edges):
        # Calculate angles k -> j -> i
        R1, R2 = edges.src["o"], edges.dst["o"]
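        # The triplet angle is recovered as atan2(|R1 x R2|, R1 . R2), which
        # is numerically stabler near 0 and pi than taking acos of the
        # normalized dot product.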
        x = torch.sum(R1 * R2, dim=-1)
        y = torch.cross(R1, R2, dim=-1)
        y = torch.norm(y, dim=-1)
        angle = torch.atan2(y, x)
        # Evaluate the spherical basis functions at the triplet angles
        cbf = [f(angle) for f in self.sbf_layer.get_sph_funcs()]
        cbf = torch.stack(cbf, dim=1)  # [None, num_spherical]
        cbf = cbf.repeat_interleave(self.num_radial, dim=1)  # [None, num_spherical * num_radial]
        # Notice: it's dst, not src
        sbf = edges.dst["rbf_env"] * cbf  # [None, num_spherical * num_radial]
        return {"sbf": sbf}

    def forward(self, g, l_g):
        # add RBF features of shape [num_radial,] to every edge of the batched graph
        g = self.rbf_layer(g)
        # Embedding block
        g = self.emb_block(g)
        # Output block
        P = self.output_blocks[0](g)  # [batch_size, num_targets]
        # Prepare the sbf features required by the interaction blocks
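        # Each node of the line graph l_g corresponds to an edge of g, so
        # copying g's edge features onto l_g's nodes lets edge_init read
        # per-edge quantities through l_g's src/dst node fields.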
        for k, v in g.edata.items():
            l_g.ndata[k] = v

        l_g.apply_edges(self.edge_init)
        # Interaction blocks
        for i in range(self.num_blocks):
            g = self.interaction_blocks[i](g, l_g)
            P += self.output_blocks[i + 1](g)

        return P
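

# ---------------------------------------------------------------------------
# Usage sketch (illustration only): builds the model with the hyperparameters
# reported in the DimeNet++ paper -- these values are assumptions, not
# necessarily this repo's defaults. A forward pass additionally needs a
# DGLGraph `g` over atoms (carrying edge data such as "o") and its line graph
# `l_g`; constructing those is dataset-specific and omitted here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = DimeNetPP(
        emb_size=128,
        out_emb_size=256,
        int_emb_size=64,
        basis_emb_size=8,
        num_blocks=4,
        num_spherical=7,
        num_radial=6,
        cutoff=5.0,
        num_targets=12,
    )
    num_params = sum(p.numel() for p in model.parameters())
    print(f"DimeNet++ initialized with {num_params:,} parameters")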