import sympy as sym import torch import torch.nn as nn from modules.basis_utils import bessel_basis, real_sph_harm from modules.envelope import Envelope class SphericalBasisLayer(nn.Module): def __init__(self, num_spherical, num_radial, cutoff, envelope_exponent=5): super(SphericalBasisLayer, self).__init__() assert num_radial <= 64 self.num_radial = num_radial self.num_spherical = num_spherical self.cutoff = cutoff self.envelope = Envelope(envelope_exponent) # retrieve formulas self.bessel_formulas = bessel_basis(num_spherical, num_radial) # x, [num_spherical, num_radial] sympy functions self.sph_harm_formulas = real_sph_harm(num_spherical) # theta, [num_spherical, ] sympy functions self.sph_funcs = [] self.bessel_funcs = [] # convert to torch functions x = sym.symbols('x') theta = sym.symbols('theta') modules = {'sin': torch.sin, 'cos': torch.cos} for i in range(num_spherical): if i == 0: first_sph = sym.lambdify([theta], self.sph_harm_formulas[i][0], modules)(0) self.sph_funcs.append(lambda tensor: torch.zeros_like(tensor) + first_sph) else: self.sph_funcs.append(sym.lambdify([theta], self.sph_harm_formulas[i][0], modules)) for j in range(num_radial): self.bessel_funcs.append(sym.lambdify([x], self.bessel_formulas[i][j], modules)) def get_bessel_funcs(self): return self.bessel_funcs def get_sph_funcs(self): return self.sph_funcs