Commit 4770621b authored by Andrea Paris, committed by Boris Bonev

cleanup of tests

parent 30d8b2da
...@@ -45,15 +45,6 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -45,15 +45,6 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
"""
Set up the distributed convolution test.
Parameters
----------
cls : TestDistributedDiscreteContinuousConvolution
The test class instance
"""
# set up distributed # set up distributed
cls.world_rank = int(os.getenv("WORLD_RANK", 0)) cls.world_rank = int(os.getenv("WORLD_RANK", 0))
cls.grid_size_h = int(os.getenv("GRID_H", 1)) cls.grid_size_h = int(os.getenv("GRID_H", 1))
...@@ -216,41 +207,6 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -216,41 +207,6 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
def test_distributed_disco_conv( def test_distributed_disco_conv(
self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, kernel_shape, basis_type, basis_norm_mode, groups, grid_in, grid_out, transpose, tol self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, kernel_shape, basis_type, basis_norm_mode, groups, grid_in, grid_out, transpose, tol
): ):
"""
Test the distributed discrete-continuous convolution module.
Parameters
----------
nlat_in : int
Number of latitude points in input
nlon_in : int
Number of longitude points in input
nlat_out : int
Number of latitude points in output
nlon_out : int
Number of longitude points in output
batch_size : int
Batch size
num_chan : int
Number of channels
kernel_shape : tuple
Kernel shape
basis_type : str
Basis type
basis_norm_mode : str
Basis normalization mode
groups : int
Number of groups
grid_in : str
Grid type for input
grid_out : str
Grid type for output
transpose : bool
Whether to transpose the convolution
tol : float
Tolerance for numerical equivalence
"""
B, C, H, W = batch_size, num_chan, nlat_in, nlon_in B, C, H, W = batch_size, num_chan, nlat_in, nlon_in
...@@ -285,9 +241,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -285,9 +241,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
# create tensors # create tensors
inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device) inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)
#############################################################
# local conv # local conv
#############################################################
# FWD pass # FWD pass
inp_full.requires_grad = True inp_full.requires_grad = True
out_full = conv_local(inp_full) out_full = conv_local(inp_full)
...@@ -301,9 +255,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -301,9 +255,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
out_full.backward(ograd_full) out_full.backward(ograd_full)
igrad_full = inp_full.grad.clone() igrad_full = inp_full.grad.clone()
#############################################################
# distributed conv # distributed conv
#############################################################
# FWD pass # FWD pass
inp_local = self._split_helper(inp_full) inp_local = self._split_helper(inp_full)
inp_local.requires_grad = True inp_local.requires_grad = True
...@@ -315,9 +267,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -315,9 +267,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
out_local.backward(ograd_local) out_local.backward(ograd_local)
igrad_local = inp_local.grad.clone() igrad_local = inp_local.grad.clone()
#############################################################
# evaluate FWD pass # evaluate FWD pass
#############################################################
with torch.no_grad(): with torch.no_grad():
out_gather_full = self._gather_helper_fwd(out_local, B, C, conv_dist) out_gather_full = self._gather_helper_fwd(out_local, B, C, conv_dist)
err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2))) err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
...@@ -325,9 +275,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -325,9 +275,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
print(f"final relative error of output: {err.item()}") print(f"final relative error of output: {err.item()}")
self.assertTrue(err.item() <= tol) self.assertTrue(err.item() <= tol)
#############################################################
# evaluate BWD pass # evaluate BWD pass
#############################################################
with torch.no_grad(): with torch.no_grad():
igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, conv_dist) igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, conv_dist)
......
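Side note on the assertion pattern these tests share: the distributed output is gathered back onto the full grid and compared to the single-process result via a mean relative Frobenius-norm error over the spatial dimensions, exactly as in the `err = torch.mean(...)` lines above. A self-contained sketch of that metric (the function name `relative_error` is illustrative, not a helper from this repo):

```python
import torch

def relative_error(ref: torch.Tensor, test: torch.Tensor) -> torch.Tensor:
    # Frobenius norm over the two spatial dims (lat, lon), computed per
    # batch/channel entry, normalized by the reference norm, then
    # averaged into a single scalar.
    num = torch.norm(ref - test, p="fro", dim=(-1, -2))
    den = torch.norm(ref, p="fro", dim=(-1, -2))
    return torch.mean(num / den)

# e.g.: err = relative_error(out_full, out_gather_full); assert err.item() <= tol
```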
...@@ -200,35 +200,7 @@ class TestDistributedResampling(unittest.TestCase): ...@@ -200,35 +200,7 @@ class TestDistributedResampling(unittest.TestCase):
def test_distributed_resampling( def test_distributed_resampling(
self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, grid_in, grid_out, mode, tol, verbose self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, grid_in, grid_out, mode, tol, verbose
): ):
"""
Test the distributed resampling module.
Parameters
----------
nlat_in : int
Number of latitude points in input
nlon_in : int
Number of longitude points in input
nlat_out : int
Number of latitude points in output
nlon_out : int
Number of longitude points in output
batch_size : int
Batch size
num_chan : int
Number of channels
grid_in : str
Grid type for input
grid_out : str
Grid type for output
mode : str
Resampling mode
tol : float
Tolerance for numerical equivalence
verbose : bool
Whether to print verbose output
"""
B, C, H, W = batch_size, num_chan, nlat_in, nlon_in B, C, H, W = batch_size, num_chan, nlat_in, nlon_in
res_args = dict( res_args = dict(
...@@ -248,9 +220,7 @@ class TestDistributedResampling(unittest.TestCase): ...@@ -248,9 +220,7 @@ class TestDistributedResampling(unittest.TestCase):
# create tensors # create tensors
inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device) inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)
#############################################################
# local conv # local conv
#############################################################
# FWD pass # FWD pass
inp_full.requires_grad = True inp_full.requires_grad = True
out_full = res_local(inp_full) out_full = res_local(inp_full)
...@@ -264,9 +234,7 @@ class TestDistributedResampling(unittest.TestCase): ...@@ -264,9 +234,7 @@ class TestDistributedResampling(unittest.TestCase):
out_full.backward(ograd_full) out_full.backward(ograd_full)
igrad_full = inp_full.grad.clone() igrad_full = inp_full.grad.clone()
#############################################################
# distributed conv # distributed conv
#############################################################
# FWD pass # FWD pass
inp_local = self._split_helper(inp_full) inp_local = self._split_helper(inp_full)
inp_local.requires_grad = True inp_local.requires_grad = True
...@@ -278,9 +246,7 @@ class TestDistributedResampling(unittest.TestCase): ...@@ -278,9 +246,7 @@ class TestDistributedResampling(unittest.TestCase):
out_local.backward(ograd_local) out_local.backward(ograd_local)
igrad_local = inp_local.grad.clone() igrad_local = inp_local.grad.clone()
#############################################################
# evaluate FWD pass # evaluate FWD pass
#############################################################
with torch.no_grad(): with torch.no_grad():
out_gather_full = self._gather_helper_fwd(out_local, B, C, res_dist) out_gather_full = self._gather_helper_fwd(out_local, B, C, res_dist)
err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2))) err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
...@@ -288,9 +254,7 @@ class TestDistributedResampling(unittest.TestCase): ...@@ -288,9 +254,7 @@ class TestDistributedResampling(unittest.TestCase):
print(f"final relative error of output: {err.item()}") print(f"final relative error of output: {err.item()}")
self.assertTrue(err.item() <= tol) self.assertTrue(err.item() <= tol)
#############################################################
# evaluate BWD pass # evaluate BWD pass
#############################################################
with torch.no_grad(): with torch.no_grad():
igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, res_dist) igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, res_dist)
......
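All of these tests follow the same split → run → gather flow: `_split_helper` shards the full tensor across the process grid, and `_gather_helper_fwd`/`_gather_helper_bwd` reassemble the distributed result for comparison. Those helpers sit outside this diff; a minimal single-process sketch of the idea for a latitude split (the function names and the `torch.chunk`-based sharding are illustrative assumptions, not the repo's implementation):

```python
import torch

def split_lat(full: torch.Tensor, grid_size_h: int, rank_h: int) -> torch.Tensor:
    # Illustrative stand-in for _split_helper: take this rank's latitude
    # shard of a full (B, C, H, W) tensor.
    return torch.chunk(full, grid_size_h, dim=-2)[rank_h]

def gather_lat(shards):
    # Illustrative stand-in for the gather helpers on a single process:
    # concatenate the latitude shards back into the full tensor.
    return torch.cat(shards, dim=-2)

full = torch.randn(2, 3, 8, 16)
shards = [split_lat(full, grid_size_h=4, rank_h=r) for r in range(4)]
assert torch.equal(gather_lat(shards), full)  # lossless round trip
```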
...@@ -215,27 +215,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase): ...@@ -215,27 +215,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
] ]
) )
def test_distributed_sht(self, nlat, nlon, batch_size, num_chan, grid, vector, tol): def test_distributed_sht(self, nlat, nlon, batch_size, num_chan, grid, vector, tol):
"""
Test the distributed spherical harmonic transform.
Parameters
----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
batch_size : int
Batch size
num_chan : int
Number of channels
grid : str
Grid type
vector : bool
Whether to use vector spherical harmonic transform
tol : float
Tolerance for numerical equivalence
"""
B, C, H, W = batch_size, num_chan, nlat, nlon B, C, H, W = batch_size, num_chan, nlat, nlon
# set up handles # set up handles
...@@ -252,9 +232,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase): ...@@ -252,9 +232,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
else: else:
inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device) inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)
#############################################################
# local transform # local transform
#############################################################
# FWD pass # FWD pass
inp_full.requires_grad = True inp_full.requires_grad = True
out_full = forward_transform_local(inp_full) out_full = forward_transform_local(inp_full)
...@@ -268,9 +246,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase): ...@@ -268,9 +246,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
out_full.backward(ograd_full) out_full.backward(ograd_full)
igrad_full = inp_full.grad.clone() igrad_full = inp_full.grad.clone()
#############################################################
# distributed transform # distributed transform
#############################################################
# FWD pass # FWD pass
inp_local = self._split_helper(inp_full) inp_local = self._split_helper(inp_full)
inp_local.requires_grad = True inp_local.requires_grad = True
...@@ -282,9 +258,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase): ...@@ -282,9 +258,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
out_local.backward(ograd_local) out_local.backward(ograd_local)
igrad_local = inp_local.grad.clone() igrad_local = inp_local.grad.clone()
#############################################################
# evaluate FWD pass # evaluate FWD pass
#############################################################
with torch.no_grad(): with torch.no_grad():
out_gather_full = self._gather_helper_fwd(out_local, B, C, forward_transform_dist, vector) out_gather_full = self._gather_helper_fwd(out_local, B, C, forward_transform_dist, vector)
err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2))) err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
...@@ -292,9 +266,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase): ...@@ -292,9 +266,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
print(f"final relative error of output: {err.item()}") print(f"final relative error of output: {err.item()}")
self.assertTrue(err.item() <= tol) self.assertTrue(err.item() <= tol)
#############################################################
# evaluate BWD pass # evaluate BWD pass
#############################################################
with torch.no_grad(): with torch.no_grad():
igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, forward_transform_dist, vector) igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, forward_transform_dist, vector)
err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2))) err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2)))
...@@ -323,26 +295,6 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase): ...@@ -323,26 +295,6 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
] ]
) )
def test_distributed_isht(self, nlat, nlon, batch_size, num_chan, grid, vector, tol): def test_distributed_isht(self, nlat, nlon, batch_size, num_chan, grid, vector, tol):
"""
Test the distributed inverse spherical harmonic transform.
Parameters
----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
batch_size : int
Batch size
num_chan : int
Number of channels
grid : str
Grid type
vector : bool
Whether to use vector spherical harmonic transform
tol : float
Tolerance for numerical equivalence
"""
B, C, H, W = batch_size, num_chan, nlat, nlon B, C, H, W = batch_size, num_chan, nlat, nlon
...@@ -383,9 +335,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase): ...@@ -383,9 +335,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
out_full.backward(ograd_full) out_full.backward(ograd_full)
igrad_full = inp_full.grad.clone() igrad_full = inp_full.grad.clone()
#############################################################
# distributed transform # distributed transform
#############################################################
# FWD pass # FWD pass
inp_local = self._split_helper(inp_full) inp_local = self._split_helper(inp_full)
inp_local.requires_grad = True inp_local.requires_grad = True
...@@ -397,9 +347,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase): ...@@ -397,9 +347,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
out_local.backward(ograd_local) out_local.backward(ograd_local)
igrad_local = inp_local.grad.clone() igrad_local = inp_local.grad.clone()
#############################################################
# evaluate FWD pass # evaluate FWD pass
#############################################################
with torch.no_grad(): with torch.no_grad():
out_gather_full = self._gather_helper_bwd(out_local, B, C, backward_transform_dist, vector) out_gather_full = self._gather_helper_bwd(out_local, B, C, backward_transform_dist, vector)
err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2))) err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
...@@ -407,9 +355,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase): ...@@ -407,9 +355,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
print(f"final relative error of output: {err.item()}") print(f"final relative error of output: {err.item()}")
self.assertTrue(err.item() <= tol) self.assertTrue(err.item() <= tol)
#############################################################
# evaluate BWD pass # evaluate BWD pass
#############################################################
with torch.no_grad(): with torch.no_grad():
igrad_gather_full = self._gather_helper_fwd(igrad_local, B, C, backward_transform_dist, vector) igrad_gather_full = self._gather_helper_fwd(igrad_local, B, C, backward_transform_dist, vector)
err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2))) err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2)))
......
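For readers unfamiliar with the transforms under test: the distributed SHT/ISHT are validated against their single-process counterparts. A minimal single-process roundtrip smoke test, assuming the public `RealSHT`/`InverseRealSHT` interface of torch-harmonics (the grid sizes and `grid` argument are chosen for illustration):

```python
import torch
from torch_harmonics import RealSHT, InverseRealSHT  # assumed public API

nlat, nlon = 64, 128
sht = RealSHT(nlat, nlon, grid="equiangular")
isht = InverseRealSHT(nlat, nlon, grid="equiangular")

# band-limit a random field first, so the analysis/synthesis roundtrip
# is exact up to floating-point error
x = isht(sht(torch.randn(2, 3, nlat, nlon)))
err = torch.norm(isht(sht(x)) - x) / torch.norm(x)
print(f"roundtrip relative error: {err.item():.3e}")
```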
...@@ -65,14 +65,6 @@ class TestLegendrePolynomials(unittest.TestCase): ...@@ -65,14 +65,6 @@ class TestLegendrePolynomials(unittest.TestCase):
self.tol = 1e-9 self.tol = 1e-9
def test_legendre(self, verbose=False): def test_legendre(self, verbose=False):
"""
Test the computation of associated Legendre polynomials.
Parameters
----------
verbose : bool, optional
Whether to print verbose output, by default False
"""
if verbose: if verbose:
print("Testing computation of associated Legendre polynomials") print("Testing computation of associated Legendre polynomials")
......
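The body of `test_legendre` lies outside this diff; for context, a check of associated Legendre values against an independent reference at the test's `tol = 1e-9` could look like the following sketch, which compares the standard three-term recurrence against `scipy.special.lpmv`. Everything here is illustrative, not the repo's implementation:

```python
import numpy as np
from scipy.special import lpmv

def assoc_legendre(m: int, nmax: int, x: np.ndarray) -> np.ndarray:
    # Unnormalized P_n^m(x) for n = m..nmax via the standard three-term
    # recurrence (Condon-Shortley phase included, matching scipy's lpmv):
    #   P_m^m(x)     = (-1)^m (2m-1)!! (1 - x^2)^(m/2)
    #   P_{m+1}^m(x) = (2m+1) x P_m^m(x)
    #   (n-m) P_n^m  = (2n-1) x P_{n-1}^m - (n+m-1) P_{n-2}^m
    out = np.zeros((nmax - m + 1,) + x.shape)
    pmm = (-1.0) ** m * np.prod(np.arange(1, 2 * m, 2)) * (1 - x**2) ** (m / 2)
    out[0] = pmm
    if nmax > m:
        out[1] = (2 * m + 1) * x * pmm
    for n in range(m + 2, nmax + 1):
        out[n - m] = ((2 * n - 1) * x * out[n - m - 1] - (n + m - 1) * out[n - m - 2]) / (n - m)
    return out

x = np.linspace(-1, 1, 101)
vals = assoc_legendre(2, 6, x)
for i, n in enumerate(range(2, 7)):
    assert np.allclose(vals[i], lpmv(2, n, x), atol=1e-9)
```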