dimenet.py 4.58 KB
Newer Older
1
2
3
4
5
6
import torch
import torch.nn as nn
from modules.activations import swish
from modules.bessel_basis_layer import BesselBasisLayer
from modules.embedding_block import EmbeddingBlock
from modules.interaction_block import InteractionBlock
7
8
9
from modules.output_block import OutputBlock
from modules.spherical_basis_layer import SphericalBasisLayer

10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43

class DimeNet(nn.Module):
    """
    DimeNet model.

    Parameters
    ----------
    emb_size
        Embedding size used throughout the model
    num_blocks
        Number of interaction blocks to be stacked (one extra output block is
        created for the embedding stage, so there are ``num_blocks + 1``
        output blocks in total)
    num_bilinear
        Third dimension of the bilinear layer tensor
    num_spherical
        Number of spherical harmonics
    num_radial
        Number of radial basis functions
    cutoff
        Cutoff distance for interatomic interactions
    envelope_exponent
        Exponent controlling the shape of the smooth cutoff envelope
    num_before_skip
        Number of residual layers in interaction block before skip connection
    num_after_skip
        Number of residual layers in interaction block after skip connection
    num_dense_output
        Number of dense layers for the output blocks
    num_targets
        Number of targets to predict
    activation
        Activation function
    output_init
        Weight initializer forwarded to every ``OutputBlock``
    """

    def __init__(
        self,
        emb_size,
        num_blocks,
        num_bilinear,
        num_spherical,
        num_radial,
        cutoff=5.0,
        envelope_exponent=5,
        num_before_skip=1,
        num_after_skip=2,
        num_dense_output=3,
        num_targets=12,
        activation=swish,
        output_init=nn.init.zeros_,
    ):
        super().__init__()

        self.num_blocks = num_blocks
        self.num_radial = num_radial

        # Radial (Bessel) basis expansion layer.
        self.rbf_layer = BesselBasisLayer(
            num_radial=num_radial,
            cutoff=cutoff,
            envelope_exponent=envelope_exponent,
        )

        # Spherical basis expansion layer; also supplies the angular functions
        # used in edge_init and the Bessel functions used by the embedding block.
        self.sbf_layer = SphericalBasisLayer(
            num_spherical=num_spherical,
            num_radial=num_radial,
            cutoff=cutoff,
            envelope_exponent=envelope_exponent,
        )

        # Embedding block
        self.emb_block = EmbeddingBlock(
            emb_size=emb_size,
            num_radial=num_radial,
            bessel_funcs=self.sbf_layer.get_bessel_funcs(),
            cutoff=cutoff,
            envelope_exponent=envelope_exponent,
            activation=activation,
        )

        # Output blocks: index 0 follows the embedding block, index i + 1
        # follows interaction block i. A *list* comprehension is required here:
        # a set comprehension would order modules by hash (object id), making
        # the index -> initialized-parameters mapping nondeterministic.
        self.output_blocks = nn.ModuleList(
            [
                OutputBlock(
                    emb_size=emb_size,
                    num_radial=num_radial,
                    num_dense=num_dense_output,
                    num_targets=num_targets,
                    activation=activation,
                    output_init=output_init,
                )
                for _ in range(num_blocks + 1)
            ]
        )

        # Interaction blocks, kept in creation order (list, not set).
        self.interaction_blocks = nn.ModuleList(
            [
                InteractionBlock(
                    emb_size=emb_size,
                    num_radial=num_radial,
                    num_spherical=num_spherical,
                    num_bilinear=num_bilinear,
                    num_before_skip=num_before_skip,
                    num_after_skip=num_after_skip,
                    activation=activation,
                )
                for _ in range(num_blocks)
            ]
        )

    def edge_init(self, edges):
        """Compute spherical-basis features for each line-graph edge.

        For every line-graph edge (src -> dst) the angle between the source
        node's and the destination node's ``"o"`` vectors is computed as
        ``atan2(|R1 x R2|, R1 . R2)`` (angle k -> j -> i). It is expanded with
        the spherical functions of ``self.sbf_layer`` and scaled by the source
        node's ``"rbf_env"`` feature.

        Parameters
        ----------
        edges
            Edge batch of the line graph passed by ``apply_edges``. Assumes
            ``src["o"]`` / ``dst["o"]`` are ``(E, 3)`` vectors -- TODO confirm.

        Returns
        -------
        dict
            ``{"sbf": Tensor}`` with shape ``(E, n_sph * num_radial)``, where
            ``n_sph`` is the number of spherical functions (assuming
            ``rbf_env`` broadcasts against that width).
        """
        R1, R2 = edges.src["o"], edges.dst["o"]
        x = torch.sum(R1 * R2, dim=-1)
        # Explicit dim: without it torch.cross uses the first dimension of
        # size 3, which is the edge dimension when there are exactly 3 edges.
        y = torch.norm(torch.cross(R1, R2, dim=-1), dim=-1)
        angle = torch.atan2(y, x)

        # Angular expansion: one column per spherical function, each repeated
        # num_radial times so it lines up with the radial features.
        cbf = torch.stack(
            [f(angle) for f in self.sbf_layer.get_sph_funcs()], dim=1
        )  # [E, n_sph]
        cbf = cbf.repeat_interleave(self.num_radial, dim=1)  # [E, n_sph * num_radial]

        sbf = edges.src["rbf_env"] * cbf  # [E, n_sph * num_radial]
        return {"sbf": sbf}

    def forward(self, g, l_g):
        """Run the model on a (batched) graph and its line graph.

        Parameters
        ----------
        g
            Graph processed by ``rbf_layer``, ``emb_block`` and the
            interaction blocks.
        l_g
            Presumably the line graph of ``g``: every feature in ``g.edata``
            is copied into ``l_g.ndata``, so its nodes must correspond
            one-to-one with the edges of ``g``.

        Returns
        -------
        Tensor
            Sum of the outputs of all ``num_blocks + 1`` output blocks
            (presumably ``[batch_size, num_targets]``).
        """
        # Add radial basis features to the edges of g.
        g = self.rbf_layer(g)
        # Embedding block
        g = self.emb_block(g)
        # First output block reads the embedding-stage representation.
        P = self.output_blocks[0](g)

        # Expose g's edge features as l_g node features so edge_init can read
        # them, then compute the spherical basis features on l_g's edges.
        for key, value in g.edata.items():
            l_g.ndata[key] = value
        l_g.apply_edges(self.edge_init)

        # Interaction block i is followed by output block i + 1.
        for interaction, output in zip(
            self.interaction_blocks, self.output_blocks[1:]
        ):
            g = interaction(g, l_g)
            # Out-of-place add: do not mutate a tensor autograd may have saved.
            P = P + output(g)

        return P