# bessel_basis_layer.py
import numpy as np
import torch
import torch.nn as nn
from modules.envelope import Envelope

class BesselBasisLayer(nn.Module):
    """Radial basis layer that expands edge distances into sine features.

    For every edge, the distance stored in ``g.edata["d"]`` is divided by
    ``cutoff`` and mapped to ``num_radial`` features::

        rbf_n = envelope(d / cutoff) * sin(freq_n * d / cutoff)

    The frequencies ``freq_n`` are a learnable parameter, initialised to
    ``n * pi`` for ``n = 1 .. num_radial``.

    Args:
        num_radial: Number of basis functions (output features per edge).
        cutoff: Value the raw distances are divided by before expansion.
        envelope_exponent: Exponent passed to ``Envelope`` when building the
            envelope applied to the scaled distances.
    """

    def __init__(self, num_radial, cutoff, envelope_exponent=5):
        super().__init__()

        self.cutoff = cutoff
        self.envelope = Envelope(envelope_exponent)
        # Allocated uninitialised here; reset_params() fills in the values.
        self.frequencies = nn.Parameter(torch.empty(num_radial))
        self.reset_params()

    def reset_params(self):
        """Initialise the frequencies to ``pi * [1, 2, ..., num_radial]``.

        The in-place write is done under ``torch.no_grad()`` so autograd does
        not record it. Gradient tracking is then (re-)enabled on the parameter.
        """
        with torch.no_grad():
            freqs = self.frequencies
            # Build the multiples of pi with the parameter's own dtype/device
            # and copy them in, rather than relying on arange(out=...).
            freqs.copy_(
                torch.arange(
                    1,
                    freqs.numel() + 1,
                    dtype=freqs.dtype,
                    device=freqs.device,
                )
                * np.pi
            )
        self.frequencies.requires_grad_()

    def forward(self, g):
        """Compute the radial basis for every edge of ``g``.

        Reads ``g.edata["d"]`` and stores the result in ``g.edata["rbf"]``.

        Args:
            g: Object exposing an ``edata`` mapping. Presumably a DGL graph
                whose ``edata["d"]`` is a 1-D tensor of per-edge distances
                (TODO confirm against callers).

        Returns:
            The same ``g`` object, with ``edata["rbf"]`` set.
        """
        d_scaled = g.edata["d"] / self.cutoff
        # Add a trailing axis so the distances broadcast against the 1-D
        # (num_radial,) frequency vector, e.g. (E,) -> (E, 1) -> (E, num_radial).
        d_scaled = torch.unsqueeze(d_scaled, -1)
        d_cutoff = self.envelope(d_scaled)
        g.edata["rbf"] = d_cutoff * torch.sin(self.frequencies * d_scaled)
        return g