Unverified commit cec07d7a, authored by Boris Bonev, committed by GitHub

Bbonev/gradcheck (#9)

* Added gradient check to the test suite
* Reduced the size of the unit tests
* Switched to `parameterized` for unit tests
parent 17eefa53
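
For context, the new test relies on `torch.autograd.gradcheck`, which compares analytical gradients against finite-difference estimates. A minimal standalone sketch of the same idea (sizes and tolerances here are illustrative, not taken from the commit; the actual test lives in torch_harmonics/tests.py below):

```python
import torch
from torch.autograd import gradcheck
from torch_harmonics import RealSHT, InverseRealSHT

# a small grid keeps the finite-difference Jacobian in gradcheck cheap
nlat, nlon = 12, 24
sht = RealSHT(nlat, nlon, grid="equiangular")
isht = InverseRealSHT(nlat, nlon, grid="equiangular")

# gradcheck requires double-precision inputs with requires_grad=True
x = torch.randn(2, nlat, nlon, dtype=torch.float64, requires_grad=True)

# verifies the backward pass of the SHT round trip against finite differences
assert gradcheck(lambda t: isht(sht(t)).sum(), x, eps=1e-6, atol=1e-5)
```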
CI workflow:

-name: Unit tests
+name: tests
 on: [push]
@@ -22,5 +22,5 @@ jobs:
         python -m pip install -e .
     - name: Test with pytest
       run: |
-        pip install pytest pytest-cov
-        pytest ./torch_harmonics/tests.py
+        python -m pip install pytest pytest-cov parameterized
+        python -m pytest ./torch_harmonics/tests.py
\ No newline at end of file
Changelog:

@@ -2,6 +2,10 @@
 ## Versioning

+### v0.6.3
+
+* Adding gradient check in unit tests
+
 ### v0.6.2
 * Adding github CI
...
README:

@@ -43,11 +43,11 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 <!-- ## What is torch-harmonics? -->
-`torch-harmonics` is a differentiable implementation of the Spherical Harmonic transform in PyTorch. It was originally implemented to enable Spherical Fourier Neural Operators (SFNO). It uses quadrature rules to compute the projection onto the associated Legendre polynomials and FFTs for the projection onto the harmonic basis. This algorithm tends to outperform others with better asymptotic scaling for most practical purposes.
+torch-harmonics is a differentiable implementation of the Spherical Harmonic transform in PyTorch. It was originally implemented to enable Spherical Fourier Neural Operators (SFNO). It uses quadrature rules to compute the projection onto the associated Legendre polynomials and FFTs for the projection onto the harmonic basis. This algorithm tends to outperform others with better asymptotic scaling for most practical purposes.
-`torch-harmonics` uses PyTorch primitives to implement these operations, making it fully differentiable. Moreover, the quadrature can be distributed onto multiple ranks making it spatially distributed.
+torch-harmonics uses PyTorch primitives to implement these operations, making it fully differentiable. Moreover, the quadrature can be distributed onto multiple ranks making it spatially distributed.
-`torch-harmonics` has been used to implement a variety of differentiable PDE solvers which generated the animations below. Moreover, it has enabled the development of Spherical Fourier Neural Operators (SFNOs) [1].
+torch-harmonics has been used to implement a variety of differentiable PDE solvers which generated the animations below. Moreover, it has enabled the development of Spherical Fourier Neural Operators (SFNOs) [1].
 <table border="0" cellspacing="0" cellpadding="0">
...
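
To make the README's description concrete, here is a minimal round-trip usage sketch (illustrative only, not part of this commit; grid size and argument values are assumptions):

```python
import torch
from torch_harmonics import RealSHT, InverseRealSHT

nlat, nlon = 128, 256  # equiangular latitude-longitude grid
sht = RealSHT(nlat, nlon, grid="equiangular")          # grid -> spectral
isht = InverseRealSHT(nlat, nlon, grid="equiangular")  # spectral -> grid

signal = torch.randn(nlat, nlon, dtype=torch.float64)
coeffs = sht(signal)      # complex spherical harmonic coefficients
roundtrip = isht(coeffs)  # back to the spatial grid, fully differentiable
```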
torch_harmonics/__init__.py:

@@ -29,7 +29,7 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #
-__version__ = '0.6.2'
+__version__ = '0.6.3'
 from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
 from . import quadrature
...
torch_harmonics/tests.py:

@@ -30,8 +30,11 @@
 #
 import unittest
+from parameterized import parameterized
+import math
 import numpy as np
 import torch
+from torch.autograd import gradcheck
 from torch_harmonics import *

 # try:
@@ -44,7 +47,7 @@ tqdm = lambda x : x
 class TestLegendrePolynomials(unittest.TestCase):
     def setUp(self):
-        self.cml = lambda m, l : np.sqrt((2*l + 1) / 4 / np.pi) * np.sqrt(np.math.factorial(l-m) / np.math.factorial(l+m))
+        self.cml = lambda m, l : np.sqrt((2*l + 1) / 4 / np.pi) * np.sqrt(math.factorial(l-m) / math.factorial(l+m))
         self.pml = dict()

         # preparing associated Legendre Polynomials (These include the Condon-Shortley phase)
@@ -62,28 +65,23 @@ class TestLegendrePolynomials(unittest.TestCase):
         self.lmax = self.mmax = 4
+        self.tol = 1e-9

     def test_legendre(self):
         print("Testing computation of associated Legendre polynomials")

         from torch_harmonics.legendre import precompute_legpoly

-        TOL = 1e-9
         t = np.linspace(0, np.pi, 100)
         pct = precompute_legpoly(self.mmax, self.lmax, t)

         for l in range(self.lmax):
             for m in range(l+1):
                 diff = pct[m, l].numpy() / self.cml(m,l) - self.pml[(m,l)](np.cos(t))
-                self.assertTrue(diff.max() <= TOL)
+                self.assertTrue(diff.max() <= self.tol)
-        print("done.")

 class TestSphericalHarmonicTransform(unittest.TestCase):
-    def __init__(self, testname, norm="ortho"):
-        super(TestSphericalHarmonicTransform, self).__init__(testname) # calling the super class init varies for different python versions. This works for 2.7
-        self.norm = norm

     def setUp(self):
         if torch.cuda.is_available():
@@ -93,76 +91,76 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
             print("Running test on CPU")
             self.device = torch.device('cpu')

-        self.batch_size = 128
-        self.nlat = 256
-        self.nlon = 2*self.nlat
-
-    def test_sht_leggauss(self):
-        print(f"Testing real-valued SHT on Legendre-Gauss grid with {self.norm} normalization")
-
-        TOL = 1e-9
+    @parameterized.expand([
+        [256, 512, 32, "ortho",   "equiangular",    1e-9],
+        [256, 512, 32, "ortho",   "legendre-gauss", 1e-9],
+        [256, 512, 32, "four-pi", "equiangular",    1e-9],
+        [256, 512, 32, "four-pi", "legendre-gauss", 1e-9],
+        [256, 512, 32, "schmidt", "equiangular",    1e-9],
+        [256, 512, 32, "schmidt", "legendre-gauss", 1e-9],
+    ])
+    def test_sht(self, nlat, nlon, batch_size, norm, grid, tol):
+        print(f"Testing real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization")
+
         testiters = [1, 2, 4, 8, 16]
-        mmax = self.nlat
+        if grid == "equiangular":
+            mmax = nlat // 2
+        else:
+            mmax = nlat
         lmax = mmax

-        sht = RealSHT(self.nlat, self.nlon, mmax=mmax, lmax=lmax, grid="legendre-gauss", norm=self.norm).to(self.device)
-        isht = InverseRealSHT(self.nlat, self.nlon, mmax=mmax, lmax=lmax, grid="legendre-gauss", norm=self.norm).to(self.device)
+        sht = RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
+        isht = InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)

-        coeffs = torch.zeros(self.batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
-        coeffs[:, :lmax, :mmax] = torch.randn(self.batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
-        signal = isht(coeffs)
+        with torch.no_grad():
+            coeffs = torch.zeros(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
+            coeffs[:, :lmax, :mmax] = torch.randn(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
+            signal = isht(coeffs)

+        # testing error accumulation
         for iter in testiters:
             with self.subTest(i = iter):
-                print(f"{iter} iterations of batchsize {self.batch_size}:")
+                print(f"{iter} iterations of batchsize {batch_size}:")

                 base = signal

                 for _ in tqdm(range(iter)):
                     base = isht(sht(base))

-                # err = ( torch.norm(base-self.signal, p='fro') / torch.norm(self.signal, p='fro') ).item()
-                err = torch.mean(torch.norm(base-signal, p='fro', dim=(-1,-2)) / torch.norm(signal, p='fro', dim=(-1,-2)) ).item()
-                print(f"final relative error: {err}")
-                self.assertTrue(err <= TOL)
+                err = torch.mean(torch.norm(base-signal, p='fro', dim=(-1,-2)) / torch.norm(signal, p='fro', dim=(-1,-2)) )
+                print(f"final relative error: {err.item()}")
+                self.assertTrue(err.item() <= tol)

-    def test_sht_equiangular(self):
-        print(f"Testing real-valued SHT on equiangular grid with {self.norm} normalization")
-
-        TOL = 1e-1
-        testiters = [1, 2, 4, 8]
-        mmax = self.nlat // 2
-        lmax = mmax
-
-        sht = RealSHT(self.nlat, self.nlon, mmax=mmax, lmax=lmax, grid="equiangular", norm=self.norm).to(self.device)
-        isht = InverseRealSHT(self.nlat, self.nlon, mmax=mmax, lmax=lmax, grid="equiangular", norm=self.norm).to(self.device)
-
-        coeffs = torch.zeros(self.batch_size, sht.lmax, sht.mmax, device=self.device, dtype=torch.complex128)
-        coeffs[:, :lmax, :mmax] = torch.randn(self.batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
-        signal = isht(coeffs)
-
-        for iter in testiters:
-            with self.subTest(i = iter):
-                print(f"{iter} iterations of batchsize {self.batch_size}:")
-
-                base = signal
-
-                for _ in tqdm(range(iter)):
-                    base = isht(sht(base))
-
-                # err = ( torch.norm(base-self.signal, p='fro') / torch.norm(self.signal, p='fro') ).item()
-                err = torch.mean(torch.norm(base-signal, p='fro', dim=(-1,-2)) / torch.norm(signal, p='fro', dim=(-1,-2)) ).item()
-                print(f"final relative error: {err}")
-                self.assertTrue(err <= TOL)
+    @parameterized.expand([
+        [12, 24, 2, "ortho",   "equiangular",    1e-5],
+        [12, 24, 2, "ortho",   "legendre-gauss", 1e-5],
+        [12, 24, 2, "four-pi", "equiangular",    1e-5],
+        [12, 24, 2, "four-pi", "legendre-gauss", 1e-5],
+        [12, 24, 2, "schmidt", "equiangular",    1e-5],
+        [12, 24, 2, "schmidt", "legendre-gauss", 1e-5],
+    ])
+    def test_sht_grad(self, nlat, nlon, batch_size, norm, grid, tol):
+        print(f"Testing gradients of real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization")
+
+        if grid == "equiangular":
+            mmax = nlat // 2
+        else:
+            mmax = nlat
+        lmax = mmax
+
+        sht = RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
+        isht = InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
+
+        with torch.no_grad():
+            coeffs = torch.zeros(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
+            coeffs[:, :lmax, :mmax] = torch.randn(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
+            signal = isht(coeffs)
+
+        input = torch.randn_like(signal, requires_grad=True)
+        err_handle = lambda x : torch.mean(torch.norm( isht(sht(x)) - signal , p='fro', dim=(-1,-2)) / torch.norm(signal, p='fro', dim=(-1,-2)) )
+        test_result = gradcheck(err_handle, input, eps=1e-6, atol=tol)
+        self.assertTrue(test_result)

 if __name__ == '__main__':
-    sht_test_suite = unittest.TestSuite()
-    sht_test_suite.addTest(TestLegendrePolynomials('test_legendre'))
-    sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht_leggauss', norm="ortho"))
-    sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht_equiangular', norm="ortho"))
-    sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht_leggauss', norm="four-pi"))
-    sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht_equiangular', norm="four-pi"))
-    sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht_leggauss', norm="schmidt"))
-    sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht_equiangular', norm="schmidt"))
-    unittest.TextTestRunner(verbosity=2).run(sht_test_suite)
+    unittest.main()
\ No newline at end of file
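
For readers unfamiliar with `parameterized`, the `expand` decorator turns each parameter row into its own generated test method, which is why the hand-built `TestSuite` in `__main__` could be replaced by plain `unittest.main()`. A minimal sketch of the mechanism, with hypothetical names:

```python
import unittest
from parameterized import parameterized

class Demo(unittest.TestCase):
    # expand() generates one separately named method per row, so any test
    # runner (unittest, pytest) discovers every parameter combination
    @parameterized.expand([
        ["equiangular", 1e-5],
        ["legendre-gauss", 1e-5],
    ])
    def test_grid(self, grid, tol):
        self.assertGreater(tol, 0.0)  # stand-in for a real assertion

if __name__ == "__main__":
    unittest.main()
```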