# spherical_basis_layer.py
import sympy as sym
import torch
import torch.nn as nn
from modules.basis_utils import bessel_basis, real_sph_harm
from modules.envelope import Envelope

class SphericalBasisLayer(nn.Module):
    """Build torch-callable radial (Bessel) and angular (spherical-harmonic) basis functions.

    Symbolic formulas are obtained from ``bessel_basis`` and ``real_sph_harm``.
    They are then converted into plain Python callables with ``sympy.lambdify``.
    ``sin``/``cos`` are mapped to ``torch.sin``/``torch.cos``, so the callables
    can be applied to torch tensors.

    Args:
        num_spherical: Number of spherical orders. One angular function and
            ``num_radial`` radial functions are built per order.
        num_radial: Number of radial functions per spherical order (must be
            <= 64).
        cutoff: Stored as ``self.cutoff``; not used within this class.
        envelope_exponent: Passed to ``Envelope``; the resulting module is
            stored as ``self.envelope``.

    Attributes:
        sph_funcs: List of ``num_spherical`` callables of ``theta``.
        bessel_funcs: List of ``num_spherical * num_radial`` callables of ``x``,
            flattened order-major (entry ``i * num_radial + j`` corresponds
            to ``bessel_formulas[i][j]``).
    """

    def __init__(self, num_spherical, num_radial, cutoff, envelope_exponent=5):
        super().__init__()

        # NOTE(review): the 64 limit presumably reflects how many Bessel roots
        # bessel_basis can supply — confirm against modules.basis_utils.
        assert num_radial <= 64, "num_radial must be <= 64"
        self.num_radial = num_radial
        self.num_spherical = num_spherical
        self.cutoff = cutoff
        self.envelope = Envelope(envelope_exponent)

        # Retrieve the symbolic formulas.
        # bessel_formulas[i][j]: sympy expression in `x`.
        self.bessel_formulas = bessel_basis(num_spherical, num_radial)
        # sph_harm_formulas[i][0]: sympy expression in `theta`; only the first
        # entry per order is used (presumably the m=0 term — confirm against
        # real_sph_harm).
        self.sph_harm_formulas = real_sph_harm(num_spherical)

        self.sph_funcs = []
        self.bessel_funcs = []

        # Convert to torch-compatible callables.
        x = sym.symbols("x")
        theta = sym.symbols("theta")
        modules = {"sin": torch.sin, "cos": torch.cos}

        for i in range(num_spherical):
            if i == 0:
                # The order-0 formula is evaluated once at theta=0 and then
                # broadcast to the input's shape. Presumably it is constant in
                # theta, in which case the lambdified function would return a
                # scalar rather than a tensor shaped like its argument.
                first_sph = sym.lambdify(
                    [theta], self.sph_harm_formulas[i][0], modules
                )(0)
                self.sph_funcs.append(
                    lambda tensor: torch.zeros_like(tensor) + first_sph
                )
            else:
                self.sph_funcs.append(
                    sym.lambdify([theta], self.sph_harm_formulas[i][0], modules)
                )
            for j in range(num_radial):
                self.bessel_funcs.append(
                    sym.lambdify([x], self.bessel_formulas[i][j], modules)
                )

    def get_bessel_funcs(self):
        """Return the flattened list of radial callables (``self.bessel_funcs``)."""
        return self.bessel_funcs

    def get_sph_funcs(self):
        """Return the list of angular callables (``self.sph_funcs``)."""
        return self.sph_funcs