Commit 82e7860d authored by apaaris, committed by Boris Bonev
Browse files

Improved docstrings in tests

parent 290da8e0
......@@ -67,6 +67,32 @@ _perf_test_thresholds = {"fwd_ms": 50, "bwd_ms": 150}
@parameterized_class(("device"), _devices)
class TestNeighborhoodAttentionS2(unittest.TestCase):
"""
Test the neighborhood attention module.
Parameters
----------
batch_size : int
Batch size
channels : int
Number of channels
heads : int
Number of heads
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
grid_in : str
Grid type for input
grid_out : str
Grid type for output
atol : float
Absolute tolerance for numerical equivalence
rtol : float
Relative tolerance for numerical equivalence
verbose : bool, optional
Whether to print verbose output, by default True
"""
def setUp(self):
torch.manual_seed(333)
if self.device.type == "cuda":
......
......@@ -49,6 +49,26 @@ if torch.cuda.is_available():
def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalization=False, basis_norm_mode="none", merge_quadrature=False, eps=1e-9):
"""
Discretely normalizes the convolution tensor.
Parameters
----------
psi : torch.Tensor
Convolution tensor
quad_weights : torch.Tensor
Quadrature weights
transpose_normalization : bool, optional
Whether to transpose the normalization, by default False
basis_norm_mode : str, optional
Basis normalization mode, by default "none"
merge_quadrature : bool, optional
Whether to merge the quadrature, by default False
eps : float, optional
Epsilon for numerical stability, by default 1e-9
Returns
-------
torch.Tensor
Normalized convolution tensor
"""
kernel_size, nlat_out, nlon_out, nlat_in, nlon_in = psi.shape
......@@ -100,6 +120,34 @@ def _precompute_convolution_tensor_dense(
):
"""
Helper routine to compute the convolution Tensor in a dense fashion
Parameters
----------
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
filter_basis : FilterBasis
Filter basis
grid_in : str
Grid type for input
grid_out : str
Grid type for output
theta_cutoff : float, optional
Theta cutoff
theta_eps : float, optional
Theta epsilon
transpose_normalization : bool, optional
Whether to transpose the normalization, by default False
basis_norm_mode : str, optional
Basis normalization mode, by default "none"
merge_quadrature : bool, optional
Whether to merge the quadrature, by default False
Returns
-------
torch.Tensor
Convolution tensor
"""
assert len(in_shape) == 2
......@@ -168,6 +216,39 @@ def _precompute_convolution_tensor_dense(
@parameterized_class(("device"), _devices)
class TestDiscreteContinuousConvolution(unittest.TestCase):
"""
Test the discrete-continuous convolution module.
Parameters
----------
batch_size : int
Batch size
in_channels : int
Number of input channels
out_channels : int
Number of output channels
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
kernel_shape : tuple
Kernel shape
basis_type : str
Basis type
basis_norm_mode : str
Basis normalization mode
grid_in : str
Grid type for input
grid_out : str
Grid type for output
transpose : bool
Whether to transpose the convolution
tol : float
Tolerance for numerical equivalence
verbose : bool, optional
Whether to print verbose output, by default False
"""
def setUp(self):
torch.manual_seed(333)
if self.device.type == "cuda":
......
......@@ -41,9 +41,51 @@ import torch_harmonics.distributed as thd
class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
"""
Test the distributed discrete-continuous convolution module.
Parameters
----------
nlat_in : int
Number of latitude points in input
nlon_in : int
Number of longitude points in input
nlat_out : int
Number of latitude points in output
nlon_out : int
Number of longitude points in output
batch_size : int
Batch size
num_chan : int
Number of channels
kernel_shape : tuple
Kernel shape
basis_type : str
Basis type
basis_norm_mode : str
Basis normalization mode
groups : int
Number of groups
grid_in : str
Grid type for input
grid_out : str
Grid type for output
transpose : bool
Whether to transpose the convolution
tol : float
Tolerance for numerical equivalence
"""
@classmethod
def setUpClass(cls):
"""
Set up the distributed convolution test.
Parameters
----------
cls : TestDistributedDiscreteContinuousConvolution
The test class instance
"""
# set up distributed
cls.world_rank = int(os.getenv("WORLD_RANK", 0))
......@@ -114,10 +156,28 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
@classmethod
def tearDownClass(cls):
    """
    Tear down the distributed test environment after all tests in this class have run.

    Finalizes the torch_harmonics distributed helpers and destroys the default
    torch.distributed process group (presumably the one created in ``setUpClass`` —
    that setup is outside this view, TODO confirm).
    """
    # release torch_harmonics' internal distributed state first, then the process group
    thd.finalize()
    dist.destroy_process_group(None)
def _split_helper(self, tensor):
"""
Split the tensor along the horizontal and vertical dimensions.
Parameters
----------
tensor : torch.Tensor
The tensor to split
"""
with torch.no_grad():
# split in W
tensor_list_local = thd.split_tensor_along_dim(tensor, dim=-1, num_chunks=self.grid_size_w)
......@@ -130,6 +190,21 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
return tensor_local
def _gather_helper_fwd(self, tensor, B, C, convolution_dist):
"""
Gather the tensor along the horizontal and vertical dimensions.
Parameters
----------
tensor : torch.Tensor
The tensor to gather
B : int
Batch size
C : int
Number of channels
convolution_dist : thd.DistributedDiscreteContinuousConvTransposeS2 or thd.DistributedDiscreteContinuousConvS2
The distributed convolution object
"""
# we need the shapes
lat_shapes = convolution_dist.lat_out_shapes
lon_shapes = convolution_dist.lon_out_shapes
......@@ -157,6 +232,20 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
return tensor_gather
def _gather_helper_bwd(self, tensor, B, C, convolution_dist):
"""
Gather the tensor along the horizontal and vertical dimensions.
Parameters
----------
tensor : torch.Tensor
The tensor to gather
B : int
Batch size
C : int
Number of channels
convolution_dist : thd.DistributedDiscreteContinuousConvTransposeS2 or thd.DistributedDiscreteContinuousConvS2
The distributed convolution object
"""
# we need the shapes
lat_shapes = convolution_dist.lat_in_shapes
lon_shapes = convolution_dist.lon_in_shapes
......@@ -205,6 +294,41 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, kernel_shape, basis_type, basis_norm_mode, groups, grid_in, grid_out, transpose, tol
):
"""
Test the distributed discrete-continuous convolution module.
Parameters
----------
nlat_in : int
Number of latitude points in input
nlon_in : int
Number of longitude points in input
nlat_out : int
Number of latitude points in output
nlon_out : int
Number of longitude points in output
batch_size : int
Batch size
num_chan : int
Number of channels
kernel_shape : tuple
Kernel shape
basis_type : str
Basis type
basis_norm_mode : str
Basis normalization mode
groups : int
Number of groups
grid_in : str
Grid type for input
grid_out : str
Grid type for output
transpose : bool
Whether to transpose the convolution
tol : float
Tolerance for numerical equivalence
"""
B, C, H, W = batch_size, num_chan, nlat_in, nlon_in
disco_args = dict(
......
......@@ -41,6 +41,34 @@ import torch_harmonics.distributed as thd
class TestDistributedResampling(unittest.TestCase):
"""
Test the distributed resampling module.
Parameters
----------
nlat_in : int
Number of latitude points in input
nlon_in : int
Number of longitude points in input
nlat_out : int
Number of latitude points in output
nlon_out : int
Number of longitude points in output
batch_size : int
Batch size
num_chan : int
Number of channels
grid_in : str
Grid type for input
grid_out : str
Grid type for output
mode : str
Resampling mode
tol : float
Tolerance for numerical equivalence
verbose : bool
Whether to print verbose output
"""
@classmethod
def setUpClass(cls):
......@@ -118,6 +146,19 @@ class TestDistributedResampling(unittest.TestCase):
dist.destroy_process_group(None)
def _split_helper(self, tensor):
"""
Split the tensor along the last dimension into chunks along the W dimension, and then along the H dimension.
Parameters
----------
tensor : torch.Tensor
The tensor to split
Returns
-------
torch.Tensor
The split tensor
"""
with torch.no_grad():
# split in W
tensor_list_local = thd.split_tensor_along_dim(tensor, dim=-1, num_chunks=self.grid_size_w)
......@@ -130,6 +171,25 @@ class TestDistributedResampling(unittest.TestCase):
return tensor_local
def _gather_helper_fwd(self, tensor, B, C, convolution_dist):
"""
Gather the tensor along the W and H dimensions.
Parameters
----------
tensor : torch.Tensor
The tensor to gather
B : int
Batch size
C : int
Number of channels
convolution_dist : thd.DistributedResampleS2
The distributed resampling object
Returns
-------
torch.Tensor
The gathered tensor
"""
# we need the shapes
lat_shapes = convolution_dist.lat_out_shapes
lon_shapes = convolution_dist.lon_out_shapes
......@@ -157,6 +217,26 @@ class TestDistributedResampling(unittest.TestCase):
return tensor_gather
def _gather_helper_bwd(self, tensor, B, C, resampling_dist):
"""
Gather the tensor along the W and H dimensions.
Parameters
----------
tensor : torch.Tensor
The tensor to gather
B : int
Batch size
C : int
Number of channels
resampling_dist : thd.DistributedResampleS2
The distributed resampling object
Returns
-------
torch.Tensor
The gathered tensor
"""
# we need the shapes
lat_shapes = resampling_dist.lat_in_shapes
lon_shapes = resampling_dist.lon_in_shapes
......@@ -196,6 +276,34 @@ class TestDistributedResampling(unittest.TestCase):
def test_distributed_resampling(
self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, grid_in, grid_out, mode, tol, verbose
):
"""
Test the distributed resampling module.
Parameters
----------
nlat_in : int
Number of latitude points in input
nlon_in : int
Number of longitude points in input
nlat_out : int
Number of latitude points in output
nlon_out : int
Number of longitude points in output
batch_size : int
Batch size
num_chan : int
Number of channels
grid_in : str
Grid type for input
grid_out : str
Grid type for output
mode : str
Resampling mode
tol : float
Tolerance for numerical equivalence
verbose : bool
Whether to print verbose output
"""
B, C, H, W = batch_size, num_chan, nlat_in, nlon_in
......
......@@ -41,10 +41,29 @@ import torch_harmonics.distributed as thd
class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
"""
Test the distributed spherical harmonic transform module.
Parameters
----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
batch_size : int
Batch size
num_chan : int
Number of channels
grid : str
Grid type
vector : bool
Whether to use vector spherical harmonic transform
tol : float
Tolerance for numerical equivalence
"""
@classmethod
def setUpClass(cls):
# set up distributed
cls.world_rank = int(os.getenv("WORLD_RANK", 0))
cls.grid_size_h = int(os.getenv("GRID_H", 1))
......@@ -120,6 +139,19 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
dist.destroy_process_group(None)
def _split_helper(self, tensor):
"""
Split the tensor along the W and H dimensions.
Parameters
----------
tensor : torch.Tensor
The tensor to split
Returns
-------
torch.Tensor
The split tensor
"""
with torch.no_grad():
# split in W
tensor_list_local = thd.split_tensor_along_dim(tensor, dim=-1, num_chunks=self.grid_size_w)
......@@ -132,6 +164,27 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
return tensor_local
def _gather_helper_fwd(self, tensor, B, C, transform_dist, vector):
"""
Gather the tensor along the W and H dimensions.
Parameters
----------
tensor : torch.Tensor
The tensor to gather
B : int
Batch size
C : int
Number of channels
transform_dist : thd.DistributedRealSHT or thd.DistributedRealVectorSHT
The distributed transform
vector : bool
Whether to use vector spherical harmonic transform
Returns
-------
torch.Tensor
The gathered tensor
"""
# we need the shapes
l_shapes = transform_dist.l_shapes
m_shapes = transform_dist.m_shapes
......@@ -163,6 +216,28 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
return tensor_gather
def _gather_helper_bwd(self, tensor, B, C, transform_dist, vector):
"""
Gather the tensor along the W and H dimensions.
Parameters
----------
tensor : torch.Tensor
The tensor to gather
B : int
Batch size
C : int
Number of channels
transform_dist : thd.DistributedRealSHT or thd.DistributedRealVectorSHT
The distributed transform
vector : bool
Whether to use vector spherical harmonic transform
Returns
-------
torch.Tensor
The gathered tensor
"""
# we need the shapes
lat_shapes = transform_dist.lat_shapes
lon_shapes = transform_dist.lon_shapes
......@@ -214,6 +289,27 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
]
)
def test_distributed_sht(self, nlat, nlon, batch_size, num_chan, grid, vector, tol):
"""
Test the distributed spherical harmonic transform.
Parameters
----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
batch_size : int
Batch size
num_chan : int
Number of channels
grid : str
Grid type
vector : bool
Whether to use vector spherical harmonic transform
tol : float
Tolerance for numerical equivalence
"""
B, C, H, W = batch_size, num_chan, nlat, nlon
# set up handles
......@@ -301,6 +397,27 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
]
)
def test_distributed_isht(self, nlat, nlon, batch_size, num_chan, grid, vector, tol):
"""
Test the distributed inverse spherical harmonic transform.
Parameters
----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
batch_size : int
Batch size
num_chan : int
Number of channels
grid : str
Grid type
vector : bool
Whether to use vector spherical harmonic transform
tol : float
Tolerance for numerical equivalence
"""
B, C, H, W = batch_size, num_chan, nlat, nlon
if vector:
......
......@@ -42,7 +42,14 @@ if torch.cuda.is_available():
class TestLegendrePolynomials(unittest.TestCase):
"""
Test the associated Legendre polynomials.
Parameters
----------
verbose : bool, optional
Whether to print verbose output, by default False
"""
def setUp(self):
self.cml = lambda m, l: math.sqrt((2 * l + 1) / 4 / math.pi) * math.sqrt(math.factorial(l - m) / math.factorial(l + m))
self.pml = dict()
......@@ -65,6 +72,14 @@ class TestLegendrePolynomials(unittest.TestCase):
self.tol = 1e-9
def test_legendre(self, verbose=False):
"""
Test the computation of associated Legendre polynomials.
Parameters
----------
verbose : bool, optional
Whether to print verbose output, by default False
"""
if verbose:
print("Testing computation of associated Legendre polynomials")
......@@ -79,11 +94,19 @@ class TestLegendrePolynomials(unittest.TestCase):
@parameterized_class(("device"), _devices)
class TestSphericalHarmonicTransform(unittest.TestCase):
"""
Test the spherical harmonic transform.
Parameters
----------
verbose : bool, optional
Whether to print verbose output, by default False
"""
def setUp(self):
    """
    Seed the RNGs and select the compute device before each test.

    Uses a fixed seed so test results are reproducible across runs.
    """
    torch.manual_seed(333)
    if self.device.type == "cuda":
        # CUDA has its own RNG stream that must be seeded separately
        torch.cuda.manual_seed(333)
    # NOTE(review): this overwrites the device injected by @parameterized_class,
    # forcing CUDA whenever it is available — confirm this is intentional
    if torch.cuda.is_available():
        self.device = torch.device("cuda")
    else:
        self.device = torch.device("cpu")
@parameterized.expand(
[
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment