Unverified Commit acd53a97 authored by Boris Bonev's avatar Boris Bonev Committed by GitHub
Browse files

Bbonev/readability improvements (#13)

* Updated Changelog

* improved readability of legendre.py

* bugfix in computation of dlegpoly

* Moving conversion from numpy to torch from the legendre module to the sht module

* renaming to vdm in tests.py
parent 7e30c071
......@@ -5,6 +5,7 @@
### v0.6.3
* Adding gradient check in unit tests
* Temporary work-around for NCCL contiguous issues with distributed SHT
* Updated SFNO example
### v0.6.2
......
......@@ -36,8 +36,8 @@ import torch.nn as nn
import torch.fft
import torch.nn.functional as F
from torch_harmonics.quadrature import *
from torch_harmonics.legendre import *
from torch_harmonics.quadrature import legendre_gauss_weights, lobatto_weights, clenshaw_curtiss_weights
from torch_harmonics.legendre import _precompute_legpoly, _precompute_dlegpoly
from torch_harmonics.distributed import polar_group_size, azimuth_group_size, distributed_transpose_azimuth, distributed_transpose_polar
from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank
......@@ -112,7 +112,8 @@ class DistributedRealSHT(nn.Module):
# combine quadrature weights with the legendre weights
weights = torch.from_numpy(w)
pct = precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
pct = _precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
pct = torch.from_numpy(pct)
weights = torch.einsum('mlk,k->mlk', pct, weights)
# we need to split in m, pad before:
......@@ -255,7 +256,8 @@ class DistributedInverseRealSHT(nn.Module):
self.mpad = mdist * self.comm_size_azimuth - self.mmax
# compute legende polynomials
pct = precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
pct = _precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
pct = torch.from_numpy(pct)
# split in m
pct = F.pad(pct, [0, 0, 0, 0, 0, self.mpad], mode="constant")
......@@ -404,7 +406,8 @@ class DistributedRealVectorSHT(nn.Module):
self.mpad = mdist * self.comm_size_azimuth - self.mmax
weights = torch.from_numpy(w)
dpct = precompute_dlegpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
dpct = _precompute_dlegpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
dpct = torch.from_numpy(dpct)
# combine integration weights, normalization factor in to one:
l = torch.arange(0, self.lmax)
......@@ -566,11 +569,12 @@ class DistributedInverseRealVectorSHT(nn.Module):
self.mpad = mdist * self.comm_size_azimuth - self.mmax
# compute legende polynomials
dpct = precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
dpct = _precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
dpct = torch.from_numpy(dpct)
# split in m
pct = F.pad(pct, [0, 0, 0, 0, 0, self.mpad], mode="constant")
pct = torch.split(pct, (self.mmax+self.mpad) // self.comm_size_azimuth, dim=0)[self.comm_rank_azimuth]
dpct = F.pad(dpct, [0, 0, 0, 0, 0, self.mpad], mode="constant")
dpct = torch.split(dpct, (self.mmax+self.mpad) // self.comm_size_azimuth, dim=0)[self.comm_rank_azimuth]
# register buffer
self.register_buffer('dpct', dpct, persistent=False)
......
......@@ -30,7 +30,6 @@
#
import numpy as np
import torch
def clm(l, m):
"""
......@@ -38,12 +37,11 @@ def clm(l, m):
"""
return np.sqrt((2*l + 1) / 4 / np.pi) * np.sqrt(np.math.factorial(l-m) / np.math.factorial(l+m))
def precompute_legpoly(mmax, lmax, t, norm="ortho", inverse=False, csphase=True):
def legpoly(mmax, lmax, x, norm="ortho", inverse=False, csphase=True):
r"""
Computes the values of (-1)^m c^l_m P^l_m(\cos \theta) at the positions specified by x (theta)
The resulting tensor has shape (mmax, lmax, len(x)).
The Condon-Shortley Phase (-1)^m can be turned off optionally
Computes the values of (-1)^m c^l_m P^l_m(x) at the positions specified by x.
The resulting tensor has shape (mmax, lmax, len(x)). The Condon-Shortley Phase (-1)^m
can be turned off optionally.
method of computation follows
[1] Schaeffer, N.; Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
......@@ -54,57 +52,69 @@ def precompute_legpoly(mmax, lmax, t, norm="ortho", inverse=False, csphase=True)
# compute the tensor P^m_n:
nmax = max(mmax,lmax)
pct = np.zeros((nmax, nmax, len(t)), dtype=np.float64)
sint = np.sin(t)
cost = np.cos(t)
vdm = np.zeros((nmax, nmax, len(x)), dtype=np.float64)
norm_factor = 1. if norm == "ortho" else np.sqrt(4 * np.pi)
norm_factor = 1. / norm_factor if inverse else norm_factor
# initial values to start the recursion
pct[0,0,:] = norm_factor / np.sqrt(4 * np.pi)
vdm[0,0,:] = norm_factor / np.sqrt(4 * np.pi)
# fill the diagonal and the lower diagonal
for l in range(1, nmax):
pct[l-1, l, :] = np.sqrt(2*l + 1) * cost * pct[l-1, l-1, :]
pct[l, l, :] = np.sqrt( (2*l + 1) * (1 + cost) * (1 - cost) / 2 / l ) * pct[l-1, l-1, :]
vdm[l-1, l, :] = np.sqrt(2*l + 1) * x * vdm[l-1, l-1, :]
vdm[l, l, :] = np.sqrt( (2*l + 1) * (1 + x) * (1 - x) / 2 / l ) * vdm[l-1, l-1, :]
# fill the remaining values on the upper triangle and multiply b
for l in range(2, nmax):
for m in range(0, l-1):
pct[m, l, :] = cost * np.sqrt((2*l - 1) / (l - m) * (2*l + 1) / (l + m)) * pct[m, l-1, :] \
- np.sqrt((l + m - 1) / (l - m) * (2*l + 1) / (2*l - 3) * (l - m - 1) / (l + m)) * pct[m, l-2, :]
vdm[m, l, :] = x * np.sqrt((2*l - 1) / (l - m) * (2*l + 1) / (l + m)) * vdm[m, l-1, :] \
- np.sqrt((l + m - 1) / (l - m) * (2*l + 1) / (2*l - 3) * (l - m - 1) / (l + m)) * vdm[m, l-2, :]
if norm == "schmidt":
for l in range(0, nmax):
if inverse:
pct[:, l, : ] = pct[:, l, : ] * np.sqrt(2*l + 1)
vdm[:, l, : ] = vdm[:, l, : ] * np.sqrt(2*l + 1)
else:
pct[:, l, : ] = pct[:, l, : ] / np.sqrt(2*l + 1)
vdm[:, l, : ] = vdm[:, l, : ] / np.sqrt(2*l + 1)
pct = pct[:mmax, :lmax]
vdm = vdm[:mmax, :lmax]
if csphase:
for m in range(1, mmax, 2):
pct[m] *= -1
vdm[m] *= -1
return vdm
def _precompute_legpoly(mmax, lmax, t, norm="ortho", inverse=False, csphase=True):
r"""
Computes the values of (-1)^m c^l_m P^l_m(\cos \theta) at the positions specified by t (theta).
The resulting tensor has shape (mmax, lmax, len(x)). The Condon-Shortley Phase (-1)^m
can be turned off optionally.
method of computation follows
[1] Schaeffer, N.; Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
[2] Rapp, R.H.; A Fortran Program for the Computation of Gravimetric Quantities from High Degree Spherical Harmonic Expansions, Ohio State University Columbus; report; 1982;
https://apps.dtic.mil/sti/citations/ADA123406
[3] Schrama, E.; Orbit integration based upon interpolated gravitational gradients
"""
return torch.from_numpy(pct)
return legpoly(mmax, lmax, np.cos(t), norm=norm, inverse=inverse, csphase=csphase)
def precompute_dlegpoly(mmax, lmax, x, norm="ortho", inverse=False, csphase=True):
def _precompute_dlegpoly(mmax, lmax, t, norm="ortho", inverse=False, csphase=True):
r"""
Computes the values of the derivatives $\frac{d}{d \theta} P^m_l(\cos \theta)$
at the positions specified by x (theta), as well as $\frac{1}{\sin \theta} P^m_l(\cos \theta)$,
at the positions specified by t (theta), as well as $\frac{1}{\sin \theta} P^m_l(\cos \theta)$,
needed for the computation of the vector spherical harmonics. The resulting tensor has shape
(2, mmax, lmax, len(x)).
(2, mmax, lmax, len(t)).
computation follows
[2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
"""
pct = precompute_legpoly(mmax+1, lmax+1, x, norm=norm, inverse=inverse, csphase=False)
pct = _precompute_legpoly(mmax+1, lmax+1, t, norm=norm, inverse=inverse, csphase=False)
dpct = torch.zeros((2, mmax, lmax, len(x)), dtype=torch.float64)
dpct = np.zeros((2, mmax, lmax, len(t)), dtype=np.float64)
# fill the derivative terms wrt theta
for l in range(0, lmax):
......
......@@ -34,8 +34,8 @@ import torch
import torch.nn as nn
import torch.fft
from .quadrature import *
from .legendre import *
from torch_harmonics.quadrature import legendre_gauss_weights, lobatto_weights, clenshaw_curtiss_weights
from torch_harmonics.legendre import _precompute_legpoly, _precompute_dlegpoly
class RealSHT(nn.Module):
......@@ -90,7 +90,8 @@ class RealSHT(nn.Module):
# combine quadrature weights with the legendre weights
weights = torch.from_numpy(w)
pct = precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
pct = _precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
pct = torch.from_numpy(pct)
weights = torch.einsum('mlk,k->mlk', pct, weights)
# remember quadrature weights
......@@ -166,7 +167,8 @@ class InverseRealSHT(nn.Module):
# determine the dimensions
self.mmax = mmax or self.nlon // 2 + 1
pct = precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
pct = _precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
pct = torch.from_numpy(pct)
# register buffer
self.register_buffer('pct', pct, persistent=False)
......@@ -245,7 +247,8 @@ class RealVectorSHT(nn.Module):
self.mmax = mmax or self.nlon // 2 + 1
weights = torch.from_numpy(w)
dpct = precompute_dlegpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
dpct = _precompute_dlegpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
dpct = torch.from_numpy(dpct)
# combine integration weights, normalization factor in to one:
l = torch.arange(0, self.lmax)
......@@ -337,7 +340,8 @@ class InverseRealVectorSHT(nn.Module):
# determine the dimensions
self.mmax = mmax or self.nlon // 2 + 1
dpct = precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
dpct = _precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
dpct = torch.from_numpy(dpct)
# register weights
self.register_buffer('dpct', dpct, persistent=False)
......
......@@ -69,14 +69,14 @@ class TestLegendrePolynomials(unittest.TestCase):
def test_legendre(self):
print("Testing computation of associated Legendre polynomials")
from torch_harmonics.legendre import precompute_legpoly
from torch_harmonics.legendre import legpoly
t = np.linspace(0, np.pi, 100)
pct = precompute_legpoly(self.mmax, self.lmax, t)
t = np.linspace(0, 1, 100)
vdm = legpoly(self.mmax, self.lmax, t)
for l in range(self.lmax):
for m in range(l+1):
diff = pct[m, l].numpy() / self.cml(m,l) - self.pml[(m,l)](np.cos(t))
diff = vdm[m, l] / self.cml(m,l) - self.pml[(m,l)](t)
self.assertTrue(diff.max() <= self.tol)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment