Unverified Commit acd53a97 authored by Boris Bonev's avatar Boris Bonev Committed by GitHub
Browse files

Bbonev/readability improvements (#13)

* Updated Changelog

* improved readability of legendre.py

* bugfix in computation of dlegpoly

* Moving conversion from numpy to torch from the legendre module to the sht module

* renaming to vdm in tests.py
parent 7e30c071
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
### v0.6.3 ### v0.6.3
* Adding gradient check in unit tests * Adding gradient check in unit tests
* Temporary work-around for NCCL contiguous issues with distributed SHT
* Updated SFNO example * Updated SFNO example
### v0.6.2 ### v0.6.2
......
...@@ -36,8 +36,8 @@ import torch.nn as nn ...@@ -36,8 +36,8 @@ import torch.nn as nn
import torch.fft import torch.fft
import torch.nn.functional as F import torch.nn.functional as F
from torch_harmonics.quadrature import * from torch_harmonics.quadrature import legendre_gauss_weights, lobatto_weights, clenshaw_curtiss_weights
from torch_harmonics.legendre import * from torch_harmonics.legendre import _precompute_legpoly, _precompute_dlegpoly
from torch_harmonics.distributed import polar_group_size, azimuth_group_size, distributed_transpose_azimuth, distributed_transpose_polar from torch_harmonics.distributed import polar_group_size, azimuth_group_size, distributed_transpose_azimuth, distributed_transpose_polar
from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank
...@@ -112,7 +112,8 @@ class DistributedRealSHT(nn.Module): ...@@ -112,7 +112,8 @@ class DistributedRealSHT(nn.Module):
# combine quadrature weights with the legendre weights # combine quadrature weights with the legendre weights
weights = torch.from_numpy(w) weights = torch.from_numpy(w)
pct = precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase) pct = _precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
pct = torch.from_numpy(pct)
weights = torch.einsum('mlk,k->mlk', pct, weights) weights = torch.einsum('mlk,k->mlk', pct, weights)
# we need to split in m, pad before: # we need to split in m, pad before:
...@@ -255,7 +256,8 @@ class DistributedInverseRealSHT(nn.Module): ...@@ -255,7 +256,8 @@ class DistributedInverseRealSHT(nn.Module):
self.mpad = mdist * self.comm_size_azimuth - self.mmax self.mpad = mdist * self.comm_size_azimuth - self.mmax
# compute legende polynomials # compute legende polynomials
pct = precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase) pct = _precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
pct = torch.from_numpy(pct)
# split in m # split in m
pct = F.pad(pct, [0, 0, 0, 0, 0, self.mpad], mode="constant") pct = F.pad(pct, [0, 0, 0, 0, 0, self.mpad], mode="constant")
...@@ -404,7 +406,8 @@ class DistributedRealVectorSHT(nn.Module): ...@@ -404,7 +406,8 @@ class DistributedRealVectorSHT(nn.Module):
self.mpad = mdist * self.comm_size_azimuth - self.mmax self.mpad = mdist * self.comm_size_azimuth - self.mmax
weights = torch.from_numpy(w) weights = torch.from_numpy(w)
dpct = precompute_dlegpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase) dpct = _precompute_dlegpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
dpct = torch.from_numpy(dpct)
# combine integration weights, normalization factor in to one: # combine integration weights, normalization factor in to one:
l = torch.arange(0, self.lmax) l = torch.arange(0, self.lmax)
...@@ -566,11 +569,12 @@ class DistributedInverseRealVectorSHT(nn.Module): ...@@ -566,11 +569,12 @@ class DistributedInverseRealVectorSHT(nn.Module):
self.mpad = mdist * self.comm_size_azimuth - self.mmax self.mpad = mdist * self.comm_size_azimuth - self.mmax
# compute legende polynomials # compute legende polynomials
dpct = precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase) dpct = _precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
dpct = torch.from_numpy(dpct)
# split in m # split in m
pct = F.pad(pct, [0, 0, 0, 0, 0, self.mpad], mode="constant") dpct = F.pad(dpct, [0, 0, 0, 0, 0, self.mpad], mode="constant")
pct = torch.split(pct, (self.mmax+self.mpad) // self.comm_size_azimuth, dim=0)[self.comm_rank_azimuth] dpct = torch.split(dpct, (self.mmax+self.mpad) // self.comm_size_azimuth, dim=0)[self.comm_rank_azimuth]
# register buffer # register buffer
self.register_buffer('dpct', dpct, persistent=False) self.register_buffer('dpct', dpct, persistent=False)
......
...@@ -30,7 +30,6 @@ ...@@ -30,7 +30,6 @@
# #
import numpy as np import numpy as np
import torch
def clm(l, m): def clm(l, m):
""" """
...@@ -38,12 +37,11 @@ def clm(l, m): ...@@ -38,12 +37,11 @@ def clm(l, m):
""" """
return np.sqrt((2*l + 1) / 4 / np.pi) * np.sqrt(np.math.factorial(l-m) / np.math.factorial(l+m)) return np.sqrt((2*l + 1) / 4 / np.pi) * np.sqrt(np.math.factorial(l-m) / np.math.factorial(l+m))
def legpoly(mmax, lmax, x, norm="ortho", inverse=False, csphase=True):
def precompute_legpoly(mmax, lmax, t, norm="ortho", inverse=False, csphase=True):
r""" r"""
Computes the values of (-1)^m c^l_m P^l_m(\cos \theta) at the positions specified by x (theta) Computes the values of (-1)^m c^l_m P^l_m(x) at the positions specified by x.
The resulting tensor has shape (mmax, lmax, len(x)). The resulting tensor has shape (mmax, lmax, len(x)). The Condon-Shortley Phase (-1)^m
The Condon-Shortley Phase (-1)^m can be turned off optionally can be turned off optionally.
method of computation follows method of computation follows
[1] Schaeffer, N.; Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems. [1] Schaeffer, N.; Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
...@@ -54,57 +52,69 @@ def precompute_legpoly(mmax, lmax, t, norm="ortho", inverse=False, csphase=True) ...@@ -54,57 +52,69 @@ def precompute_legpoly(mmax, lmax, t, norm="ortho", inverse=False, csphase=True)
# compute the tensor P^m_n: # compute the tensor P^m_n:
nmax = max(mmax,lmax) nmax = max(mmax,lmax)
pct = np.zeros((nmax, nmax, len(t)), dtype=np.float64) vdm = np.zeros((nmax, nmax, len(x)), dtype=np.float64)
sint = np.sin(t)
cost = np.cos(t)
norm_factor = 1. if norm == "ortho" else np.sqrt(4 * np.pi) norm_factor = 1. if norm == "ortho" else np.sqrt(4 * np.pi)
norm_factor = 1. / norm_factor if inverse else norm_factor norm_factor = 1. / norm_factor if inverse else norm_factor
# initial values to start the recursion # initial values to start the recursion
pct[0,0,:] = norm_factor / np.sqrt(4 * np.pi) vdm[0,0,:] = norm_factor / np.sqrt(4 * np.pi)
# fill the diagonal and the lower diagonal # fill the diagonal and the lower diagonal
for l in range(1, nmax): for l in range(1, nmax):
pct[l-1, l, :] = np.sqrt(2*l + 1) * cost * pct[l-1, l-1, :] vdm[l-1, l, :] = np.sqrt(2*l + 1) * x * vdm[l-1, l-1, :]
pct[l, l, :] = np.sqrt( (2*l + 1) * (1 + cost) * (1 - cost) / 2 / l ) * pct[l-1, l-1, :] vdm[l, l, :] = np.sqrt( (2*l + 1) * (1 + x) * (1 - x) / 2 / l ) * vdm[l-1, l-1, :]
# fill the remaining values on the upper triangle and multiply b # fill the remaining values on the upper triangle and multiply b
for l in range(2, nmax): for l in range(2, nmax):
for m in range(0, l-1): for m in range(0, l-1):
pct[m, l, :] = cost * np.sqrt((2*l - 1) / (l - m) * (2*l + 1) / (l + m)) * pct[m, l-1, :] \ vdm[m, l, :] = x * np.sqrt((2*l - 1) / (l - m) * (2*l + 1) / (l + m)) * vdm[m, l-1, :] \
- np.sqrt((l + m - 1) / (l - m) * (2*l + 1) / (2*l - 3) * (l - m - 1) / (l + m)) * pct[m, l-2, :] - np.sqrt((l + m - 1) / (l - m) * (2*l + 1) / (2*l - 3) * (l - m - 1) / (l + m)) * vdm[m, l-2, :]
if norm == "schmidt": if norm == "schmidt":
for l in range(0, nmax): for l in range(0, nmax):
if inverse: if inverse:
pct[:, l, : ] = pct[:, l, : ] * np.sqrt(2*l + 1) vdm[:, l, : ] = vdm[:, l, : ] * np.sqrt(2*l + 1)
else: else:
pct[:, l, : ] = pct[:, l, : ] / np.sqrt(2*l + 1) vdm[:, l, : ] = vdm[:, l, : ] / np.sqrt(2*l + 1)
pct = pct[:mmax, :lmax] vdm = vdm[:mmax, :lmax]
if csphase: if csphase:
for m in range(1, mmax, 2): for m in range(1, mmax, 2):
pct[m] *= -1 vdm[m] *= -1
return vdm
def _precompute_legpoly(mmax, lmax, t, norm="ortho", inverse=False, csphase=True):
r"""
Computes the values of (-1)^m c^l_m P^l_m(\cos \theta) at the positions specified by t (theta).
The resulting tensor has shape (mmax, lmax, len(x)). The Condon-Shortley Phase (-1)^m
can be turned off optionally.
method of computation follows
[1] Schaeffer, N.; Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
[2] Rapp, R.H.; A Fortran Program for the Computation of Gravimetric Quantities from High Degree Spherical Harmonic Expansions, Ohio State University Columbus; report; 1982;
https://apps.dtic.mil/sti/citations/ADA123406
[3] Schrama, E.; Orbit integration based upon interpolated gravitational gradients
"""
return torch.from_numpy(pct) return legpoly(mmax, lmax, np.cos(t), norm=norm, inverse=inverse, csphase=csphase)
def precompute_dlegpoly(mmax, lmax, x, norm="ortho", inverse=False, csphase=True): def _precompute_dlegpoly(mmax, lmax, t, norm="ortho", inverse=False, csphase=True):
r""" r"""
Computes the values of the derivatives $\frac{d}{d \theta} P^m_l(\cos \theta)$ Computes the values of the derivatives $\frac{d}{d \theta} P^m_l(\cos \theta)$
at the positions specified by x (theta), as well as $\frac{1}{\sin \theta} P^m_l(\cos \theta)$, at the positions specified by t (theta), as well as $\frac{1}{\sin \theta} P^m_l(\cos \theta)$,
needed for the computation of the vector spherical harmonics. The resulting tensor has shape needed for the computation of the vector spherical harmonics. The resulting tensor has shape
(2, mmax, lmax, len(x)). (2, mmax, lmax, len(t)).
computation follows computation follows
[2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math. [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
""" """
pct = precompute_legpoly(mmax+1, lmax+1, x, norm=norm, inverse=inverse, csphase=False) pct = _precompute_legpoly(mmax+1, lmax+1, t, norm=norm, inverse=inverse, csphase=False)
dpct = torch.zeros((2, mmax, lmax, len(x)), dtype=torch.float64) dpct = np.zeros((2, mmax, lmax, len(t)), dtype=np.float64)
# fill the derivative terms wrt theta # fill the derivative terms wrt theta
for l in range(0, lmax): for l in range(0, lmax):
......
...@@ -34,8 +34,8 @@ import torch ...@@ -34,8 +34,8 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.fft import torch.fft
from .quadrature import * from torch_harmonics.quadrature import legendre_gauss_weights, lobatto_weights, clenshaw_curtiss_weights
from .legendre import * from torch_harmonics.legendre import _precompute_legpoly, _precompute_dlegpoly
class RealSHT(nn.Module): class RealSHT(nn.Module):
...@@ -90,7 +90,8 @@ class RealSHT(nn.Module): ...@@ -90,7 +90,8 @@ class RealSHT(nn.Module):
# combine quadrature weights with the legendre weights # combine quadrature weights with the legendre weights
weights = torch.from_numpy(w) weights = torch.from_numpy(w)
pct = precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase) pct = _precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
pct = torch.from_numpy(pct)
weights = torch.einsum('mlk,k->mlk', pct, weights) weights = torch.einsum('mlk,k->mlk', pct, weights)
# remember quadrature weights # remember quadrature weights
...@@ -166,7 +167,8 @@ class InverseRealSHT(nn.Module): ...@@ -166,7 +167,8 @@ class InverseRealSHT(nn.Module):
# determine the dimensions # determine the dimensions
self.mmax = mmax or self.nlon // 2 + 1 self.mmax = mmax or self.nlon // 2 + 1
pct = precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase) pct = _precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
pct = torch.from_numpy(pct)
# register buffer # register buffer
self.register_buffer('pct', pct, persistent=False) self.register_buffer('pct', pct, persistent=False)
...@@ -245,7 +247,8 @@ class RealVectorSHT(nn.Module): ...@@ -245,7 +247,8 @@ class RealVectorSHT(nn.Module):
self.mmax = mmax or self.nlon // 2 + 1 self.mmax = mmax or self.nlon // 2 + 1
weights = torch.from_numpy(w) weights = torch.from_numpy(w)
dpct = precompute_dlegpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase) dpct = _precompute_dlegpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
dpct = torch.from_numpy(dpct)
# combine integration weights, normalization factor in to one: # combine integration weights, normalization factor in to one:
l = torch.arange(0, self.lmax) l = torch.arange(0, self.lmax)
...@@ -337,7 +340,8 @@ class InverseRealVectorSHT(nn.Module): ...@@ -337,7 +340,8 @@ class InverseRealVectorSHT(nn.Module):
# determine the dimensions # determine the dimensions
self.mmax = mmax or self.nlon // 2 + 1 self.mmax = mmax or self.nlon // 2 + 1
dpct = precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase) dpct = _precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
dpct = torch.from_numpy(dpct)
# register weights # register weights
self.register_buffer('dpct', dpct, persistent=False) self.register_buffer('dpct', dpct, persistent=False)
......
...@@ -69,14 +69,14 @@ class TestLegendrePolynomials(unittest.TestCase): ...@@ -69,14 +69,14 @@ class TestLegendrePolynomials(unittest.TestCase):
def test_legendre(self): def test_legendre(self):
print("Testing computation of associated Legendre polynomials") print("Testing computation of associated Legendre polynomials")
from torch_harmonics.legendre import precompute_legpoly from torch_harmonics.legendre import legpoly
t = np.linspace(0, np.pi, 100) t = np.linspace(0, 1, 100)
pct = precompute_legpoly(self.mmax, self.lmax, t) vdm = legpoly(self.mmax, self.lmax, t)
for l in range(self.lmax): for l in range(self.lmax):
for m in range(l+1): for m in range(l+1):
diff = pct[m, l].numpy() / self.cml(m,l) - self.pml[(m,l)](np.cos(t)) diff = vdm[m, l] / self.cml(m,l) - self.pml[(m,l)](t)
self.assertTrue(diff.max() <= self.tol) self.assertTrue(diff.max() <= self.tol)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment