Unverified Commit 4fea88bf authored by Boris Bonev, committed by GitHub

Bbonev/gradient test fix (#53)

* added analytic gradients to the gradient_analysis notebook

* fixing sht unittest to check the sht and isht gradients individually instead of the roundtrip gradient (a standalone sketch of this check follows below)

* Updated changelog
parent 60b3b5a2
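The core of the fix (visible in the test diff below) is to run torch.autograd.gradcheck on the forward and inverse transforms separately rather than on the isht(sht(x)) roundtrip. A minimal, self-contained sketch of that idea follows; the grid size, tolerances, and the RealSHT/InverseRealSHT construction here are illustrative assumptions, not lines taken from the PR:

import torch
from torch.autograd import gradcheck
from torch_harmonics import RealSHT, InverseRealSHT

# small illustrative problem size; gradcheck needs double precision
nlat, nlon, batch_size = 12, 24, 2

# forward (grid -> spectral) and inverse (spectral -> grid) transforms;
# the precomputed buffers are assumed to already be float64 (call .double() otherwise)
sht = RealSHT(nlat, nlon, grid="equiangular", norm="ortho")
isht = InverseRealSHT(nlat, nlon, grid="equiangular", norm="ortho")

# infer the spectral coefficient shape from a dummy forward pass
dummy = torch.zeros(batch_size, nlat, nlon, dtype=torch.float64)
coeffs_shape = sht(dummy).shape

# gradient of the forward transform, checked on its own;
# view_as_real makes the complex output explicit for gradcheck
x = torch.randn(batch_size, nlat, nlon, dtype=torch.float64, requires_grad=True)
assert gradcheck(lambda x: torch.view_as_real(sht(x)), (x,), eps=1e-6, atol=1e-5)

# gradient of the inverse transform, checked independently of the forward one
c = torch.randn(coeffs_shape, dtype=torch.complex128, requires_grad=True)
assert gradcheck(lambda c: isht(c), (c,), eps=1e-6, atol=1e-5)

Checking the two transforms separately localizes any gradient failure to either sht or isht, which is the behavior change described in the changelog entry below.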
@@ -6,6 +6,7 @@
* Added resampling modules for convenience
* Changing behavior of distributed SHT to use `dim=-3` as channel dimension
* Fixing SHT unittests to test SHT and ISHT individually, rather than the roundtrip

### v0.7.1
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
@@ -37,24 +37,25 @@ import torch
from torch.autograd import gradcheck
from torch_harmonics import *


class TestLegendrePolynomials(unittest.TestCase):
    def setUp(self):
        self.cml = lambda m, l: np.sqrt((2 * l + 1) / 4 / np.pi) * np.sqrt(math.factorial(l - m) / math.factorial(l + m))
        self.pml = dict()

        # preparing associated Legendre Polynomials (These include the Condon-Shortley phase)
        # for reference see e.g. https://en.wikipedia.org/wiki/Associated_Legendre_polynomials
        self.pml[(0, 0)] = lambda x: np.ones_like(x)
        self.pml[(0, 1)] = lambda x: x
        self.pml[(1, 1)] = lambda x: -np.sqrt(1.0 - x**2)
        self.pml[(0, 2)] = lambda x: 0.5 * (3 * x**2 - 1)
        self.pml[(1, 2)] = lambda x: -3 * x * np.sqrt(1.0 - x**2)
        self.pml[(2, 2)] = lambda x: 3 * (1 - x**2)
        self.pml[(0, 3)] = lambda x: 0.5 * (5 * x**3 - 3 * x)
        self.pml[(1, 3)] = lambda x: 1.5 * (1 - 5 * x**2) * np.sqrt(1.0 - x**2)
        self.pml[(2, 3)] = lambda x: 15 * x * (1 - x**2)
        self.pml[(3, 3)] = lambda x: -15 * np.sqrt(1.0 - x**2) ** 3

        self.lmax = self.mmax = 4
@@ -68,8 +69,8 @@ class TestLegendrePolynomials(unittest.TestCase):
        vdm = legpoly(self.mmax, self.lmax, t)

        for l in range(self.lmax):
            for m in range(l + 1):
                diff = vdm[m, l] / self.cml(m, l) - self.pml[(m, l)](t)
                self.assertTrue(diff.max() <= self.tol)
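As an aside on the Legendre test above: the prefactor computed by self.cml is the standard orthonormalization constant for the associated Legendre functions (this just restates what the code computes as an equation),

\[
\bar{P}_\ell^m(x) = \sqrt{\frac{2\ell + 1}{4\pi}\,\frac{(\ell - m)!}{(\ell + m)!}}\; P_\ell^m(x),
\]

so dividing vdm[m, l] by self.cml(m, l) recovers the unnormalized P_\ell^m(x) tabulated in self.pml, and the assertion checks that the difference stays below self.tol.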
@@ -79,19 +80,21 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
        if torch.cuda.is_available():
            print("Running test on GPU")
            self.device = torch.device("cuda")
        else:
            print("Running test on CPU")
            self.device = torch.device("cpu")

    @parameterized.expand(
        [
            [256, 512, 32, "ortho", "equiangular", 1e-9],
            [256, 512, 32, "ortho", "legendre-gauss", 1e-9],
            [256, 512, 32, "four-pi", "equiangular", 1e-9],
            [256, 512, 32, "four-pi", "legendre-gauss", 1e-9],
            [256, 512, 32, "schmidt", "equiangular", 1e-9],
            [256, 512, 32, "schmidt", "legendre-gauss", 1e-9],
        ]
    )
    def test_sht(self, nlat, nlon, batch_size, norm, grid, tol):
        print(f"Testing real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization")
@@ -109,30 +112,38 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
        coeffs = torch.zeros(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
        coeffs[:, :lmax, :mmax] = torch.randn(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
        signal = isht(coeffs)

        # testing error accumulation
        for iter in testiters:
            with self.subTest(i=iter):
                print(f"{iter} iterations of batchsize {batch_size}:")
                base = signal
                for _ in range(iter):
                    base = isht(sht(base))
                err = torch.mean(torch.norm(base - signal, p="fro", dim=(-1, -2)) / torch.norm(signal, p="fro", dim=(-1, -2)))
                print(f"final relative error: {err.item()}")
                self.assertTrue(err.item() <= tol)

    @parameterized.expand(
        [
            [12, 24, 2, "ortho", "equiangular", 1e-5],
            [12, 24, 2, "ortho", "legendre-gauss", 1e-5],
            [12, 24, 2, "four-pi", "equiangular", 1e-5],
            [12, 24, 2, "four-pi", "legendre-gauss", 1e-5],
            [12, 24, 2, "schmidt", "equiangular", 1e-5],
            [12, 24, 2, "schmidt", "legendre-gauss", 1e-5],
            [15, 30, 2, "ortho", "equiangular", 1e-5],
            [15, 30, 2, "ortho", "legendre-gauss", 1e-5],
            [15, 30, 2, "four-pi", "equiangular", 1e-5],
            [15, 30, 2, "four-pi", "legendre-gauss", 1e-5],
            [15, 30, 2, "schmidt", "equiangular", 1e-5],
            [15, 30, 2, "schmidt", "legendre-gauss", 1e-5],
        ]
    )
    def test_sht_grads(self, nlat, nlon, batch_size, norm, grid, tol):
        print(f"Testing gradients of real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization")

        if grid == "equiangular":
@@ -148,12 +159,19 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
        coeffs = torch.zeros(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
        coeffs[:, :lmax, :mmax] = torch.randn(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
        signal = isht(coeffs)

        # test the sht
        grad_input = torch.randn_like(signal, requires_grad=True)
        err_handle = lambda x: torch.mean(torch.norm(sht(x) - coeffs, p="fro", dim=(-1, -2)) / torch.norm(coeffs, p="fro", dim=(-1, -2)))
        test_result = gradcheck(err_handle, grad_input, eps=1e-6, atol=tol)
        self.assertTrue(test_result)

        # test the isht
        grad_input = torch.randn_like(coeffs, requires_grad=True)
        err_handle = lambda x: torch.mean(torch.norm(isht(x) - signal, p="fro", dim=(-1, -2)) / torch.norm(signal, p="fro", dim=(-1, -2)))
        test_result = gradcheck(err_handle, grad_input, eps=1e-6, atol=tol)
        self.assertTrue(test_result)


if __name__ == "__main__":
    unittest.main()
\ No newline at end of file