Commit 4770621b authored by Andrea Paris, committed by Boris Bonev

cleanup of tests: remove redundant docstrings and banner comments from the test modules

parent 30d8b2da
@@ -45,15 +45,6 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
@classmethod
def setUpClass(cls):
"""
Set up the distributed convolution test.
Parameters
----------
cls : TestDistributedDiscreteContinuousConvolution
The test class instance
"""
# set up distributed
cls.world_rank = int(os.getenv("WORLD_RANK", 0))
cls.grid_size_h = int(os.getenv("GRID_H", 1))
@@ -217,41 +208,6 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, kernel_shape, basis_type, basis_norm_mode, groups, grid_in, grid_out, transpose, tol
):
"""
Test the distributed discrete-continuous convolution module.
Parameters
----------
nlat_in : int
Number of latitude points in input
nlon_in : int
Number of longitude points in input
nlat_out : int
Number of latitude points in output
nlon_out : int
Number of longitude points in output
batch_size : int
Batch size
num_chan : int
Number of channels
kernel_shape : tuple
Kernel shape
basis_type : str
Basis type
basis_norm_mode : str
Basis normalization mode
groups : int
Number of groups
grid_in : str
Grid type for input
grid_out : str
Grid type for output
transpose : bool
Whether to transpose the convolution
tol : float
Tolerance for numerical equivalence
"""
B, C, H, W = batch_size, num_chan, nlat_in, nlon_in
disco_args = dict(
@@ -285,9 +241,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
# create tensors
inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)
#############################################################
# local conv
#############################################################
# FWD pass
inp_full.requires_grad = True
out_full = conv_local(inp_full)
@@ -301,9 +255,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
out_full.backward(ograd_full)
igrad_full = inp_full.grad.clone()
#############################################################
# distributed conv
#############################################################
# FWD pass
inp_local = self._split_helper(inp_full)
inp_local.requires_grad = True
@@ -315,9 +267,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
out_local.backward(ograd_local)
igrad_local = inp_local.grad.clone()
#############################################################
# evaluate FWD pass
#############################################################
with torch.no_grad():
out_gather_full = self._gather_helper_fwd(out_local, B, C, conv_dist)
err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
@@ -325,9 +275,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
print(f"final relative error of output: {err.item()}")
self.assertTrue(err.item() <= tol)
#############################################################
# evaluate BWD pass
#############################################################
with torch.no_grad():
igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, conv_dist)
......
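Aside: the equivalence metric used throughout these tests is a mean relative Frobenius error over the last two (spatial) dimensions. Below is a minimal self-contained sketch of that check; the tensor shapes and tolerance are illustrative assumptions, not values from the suite.

import torch

def relative_error(reference: torch.Tensor, candidate: torch.Tensor) -> float:
    # Per-sample relative Frobenius error over the spatial dims,
    # averaged over the batch and channel dimensions.
    num = torch.norm(reference - candidate, p="fro", dim=(-1, -2))
    den = torch.norm(reference, p="fro", dim=(-1, -2))
    return torch.mean(num / den).item()

out_full = torch.randn(2, 3, 8, 16)
out_gather_full = out_full + 1e-7 * torch.randn_like(out_full)
assert relative_error(out_full, out_gather_full) <= 1e-5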
@@ -200,34 +200,6 @@ class TestDistributedResampling(unittest.TestCase):
def test_distributed_resampling(
self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, grid_in, grid_out, mode, tol, verbose
):
"""
Test the distributed resampling module.
Parameters
----------
nlat_in : int
Number of latitude points in input
nlon_in : int
Number of longitude points in input
nlat_out : int
Number of latitude points in output
nlon_out : int
Number of longitude points in output
batch_size : int
Batch size
num_chan : int
Number of channels
grid_in : str
Grid type for input
grid_out : str
Grid type for output
mode : str
Resampling mode
tol : float
Tolerance for numerical equivalence
verbose : bool
Whether to print verbose output
"""
B, C, H, W = batch_size, num_chan, nlat_in, nlon_in
@@ -248,9 +220,7 @@ class TestDistributedResampling(unittest.TestCase):
# create tensors
inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)
#############################################################
# local conv
#############################################################
# FWD pass
inp_full.requires_grad = True
out_full = res_local(inp_full)
@@ -264,9 +234,7 @@ class TestDistributedResampling(unittest.TestCase):
out_full.backward(ograd_full)
igrad_full = inp_full.grad.clone()
#############################################################
# distributed conv
#############################################################
# FWD pass
inp_local = self._split_helper(inp_full)
inp_local.requires_grad = True
@@ -278,9 +246,7 @@ class TestDistributedResampling(unittest.TestCase):
out_local.backward(ograd_local)
igrad_local = inp_local.grad.clone()
#############################################################
# evaluate FWD pass
#############################################################
with torch.no_grad():
out_gather_full = self._gather_helper_fwd(out_local, B, C, res_dist)
err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
@@ -288,9 +254,7 @@ class TestDistributedResampling(unittest.TestCase):
print(f"final relative error of output: {err.item()}")
self.assertTrue(err.item() <= tol)
#############################################################
# evaluate BWD pass
#############################################################
with torch.no_grad():
igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, res_dist)
......
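The _split_helper and _gather_helper_* methods used above are not part of this diff. A hypothetical single-process stand-in for the split/gather round trip they perform follows; the split axis, rank count, and function names are assumptions for illustration only.

import torch

def split_helper(full: torch.Tensor, num_ranks: int, rank: int, dim: int = -2) -> torch.Tensor:
    # Hand each rank a contiguous chunk along one spatial dimension.
    return torch.chunk(full, num_ranks, dim=dim)[rank].clone()

def gather_helper(chunks: list, dim: int = -2) -> torch.Tensor:
    # Reassemble the full tensor from the per-rank chunks.
    return torch.cat(chunks, dim=dim)

inp_full = torch.randn(2, 3, 8, 16)
chunks = [split_helper(inp_full, 4, r) for r in range(4)]
assert torch.equal(gather_helper(chunks), inp_full)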
@@ -215,26 +215,6 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
]
)
def test_distributed_sht(self, nlat, nlon, batch_size, num_chan, grid, vector, tol):
"""
Test the distributed spherical harmonic transform.
Parameters
----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
batch_size : int
Batch size
num_chan : int
Number of channels
grid : str
Grid type
vector : bool
Whether to use vector spherical harmonic transform
tol : float
Tolerance for numerical equivalence
"""
B, C, H, W = batch_size, num_chan, nlat, nlon
@@ -252,9 +232,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
else:
inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)
#############################################################
# local transform
#############################################################
# FWD pass
inp_full.requires_grad = True
out_full = forward_transform_local(inp_full)
@@ -268,9 +246,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
out_full.backward(ograd_full)
igrad_full = inp_full.grad.clone()
#############################################################
# distributed transform
#############################################################
# FWD pass
inp_local = self._split_helper(inp_full)
inp_local.requires_grad = True
@@ -282,9 +258,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
out_local.backward(ograd_local)
igrad_local = inp_local.grad.clone()
#############################################################
# evaluate FWD pass
#############################################################
with torch.no_grad():
out_gather_full = self._gather_helper_fwd(out_local, B, C, forward_transform_dist, vector)
err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
@@ -292,9 +266,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
print(f"final relative error of output: {err.item()}")
self.assertTrue(err.item() <= tol)
#############################################################
# evaluate BWD pass
#############################################################
with torch.no_grad():
igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, forward_transform_dist, vector)
err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2)))
@@ -323,26 +295,6 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
]
)
def test_distributed_isht(self, nlat, nlon, batch_size, num_chan, grid, vector, tol):
"""
Test the distributed inverse spherical harmonic transform.
Parameters
----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
batch_size : int
Batch size
num_chan : int
Number of channels
grid : str
Grid type
vector : bool
Whether to use vector spherical harmonic transform
tol : float
Tolerance for numerical equivalence
"""
B, C, H, W = batch_size, num_chan, nlat, nlon
@@ -383,9 +335,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
out_full.backward(ograd_full)
igrad_full = inp_full.grad.clone()
#############################################################
# distributed transform
#############################################################
# FWD pass
inp_local = self._split_helper(inp_full)
inp_local.requires_grad = True
@@ -397,9 +347,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
out_local.backward(ograd_local)
igrad_local = inp_local.grad.clone()
#############################################################
# evaluate FWD pass
#############################################################
with torch.no_grad():
out_gather_full = self._gather_helper_bwd(out_local, B, C, backward_transform_dist, vector)
err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
@@ -407,9 +355,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
print(f"final relative error of output: {err.item()}")
self.assertTrue(err.item() <= tol)
#############################################################
# evaluate BWD pass
#############################################################
with torch.no_grad():
igrad_gather_full = self._gather_helper_fwd(igrad_local, B, C, backward_transform_dist, vector)
err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2)))
......
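The bare ] and ) context lines in the hunks above are the tail of a parameter-list decorator feeding test_distributed_sht and test_distributed_isht. A minimal sketch of that parameterized test style follows; the parameterized library and the concrete parameter tuples are assumptions, not values from this repository.

import unittest
from parameterized import parameterized

class TestSketch(unittest.TestCase):
    @parameterized.expand([
        # (nlat, nlon, batch_size, num_chan, grid, vector, tol)
        (32, 64, 2, 4, "equiangular", False, 1e-5),
        (32, 64, 2, 4, "legendre-gauss", True, 1e-5),
    ])
    def test_distributed_sht(self, nlat, nlon, batch_size, num_chan, grid, vector, tol):
        # Placeholder body; the real test builds local and distributed
        # transforms and compares them as in the hunks above.
        self.assertGreater(tol, 0)

if __name__ == "__main__":
    unittest.main()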
@@ -65,14 +65,6 @@ class TestLegendrePolynomials(unittest.TestCase):
self.tol = 1e-9
def test_legendre(self, verbose=False):
"""
Test the computation of associated Legendre polynomials.
Parameters
----------
verbose : bool, optional
Whether to print verbose output, by default False
"""
if verbose:
print("Testing computation of associated Legendre polynomials")
......
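For reference, low-degree associated Legendre polynomials have simple closed forms that a test like test_legendre can check against. A small sketch using SciPy's lpmv as the reference follows; the choice of reference routine and the tolerance are assumptions.

import numpy as np
from scipy.special import lpmv

x = np.linspace(-1.0, 1.0, 101)

# Standard closed forms: P_1^0(x) = x and P_2^0(x) = (3x^2 - 1) / 2.
np.testing.assert_allclose(lpmv(0, 1, x), x, atol=1e-12)
np.testing.assert_allclose(lpmv(0, 2, x), 0.5 * (3.0 * x**2 - 1.0), atol=1e-12)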