Commit c44d4b23 authored by Andrea Paris, committed by Boris Bonev

simplified docstrings for test classes

parent ca46b9d2
@@ -67,31 +67,7 @@ _perf_test_thresholds = {"fwd_ms": 50, "bwd_ms": 150}
 @parameterized_class(("device"), _devices)
 class TestNeighborhoodAttentionS2(unittest.TestCase):
-    """
-    Test the neighborhood attention module.
-    Parameters
-    ----------
-    batch_size : int
-        Batch size
-    channels : int
-        Number of channels
-    heads : int
-        Number of heads
-    in_shape : tuple
-        Input shape (height, width)
-    out_shape : tuple
-        Output shape (height, width)
-    grid_in : str
-        Grid type for input
-    grid_out : str
-        Grid type for output
-    atol : float
-        Absolute tolerance for numerical equivalence
-    rtol : float
-        Relative tolerance for numerical equivalence
-    verbose : bool, optional
-        Whether to print verbose output, by default True
-    """
+    """Test the neighborhood attention module (CPU/CUDA if available)."""
     def setUp(self):
         torch.manual_seed(333)
......
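The `@parameterized_class(("device"), _devices)` decorator above is what makes the new one-line docstrings accurate: the decorated test class is instantiated once per entry in `_devices`. Below is a minimal sketch of how such a device list is typically assembled; the exact construction of `_devices` in this repository is an assumption here, not taken from the diff.

```python
import unittest

import torch
from parameterized import parameterized_class

# Assumed construction: CPU is always tested, CUDA only when available.
_devices = [(torch.device("cpu"),)]
if torch.cuda.is_available():
    _devices.append((torch.device("cuda"),))


@parameterized_class(("device"), _devices)
class TestOnEachDevice(unittest.TestCase):
    """Runs once per device; `self.device` is injected by the decorator."""

    def test_tensor_lands_on_device(self):
        x = torch.randn(4, device=self.device)
        self.assertEqual(x.device.type, self.device.type)
```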
@@ -36,19 +36,7 @@ import torch
 class TestCacheConsistency(unittest.TestCase):
     def test_consistency(self, verbose=False):
-        """
-        Test that cached values are not modified externally.
-        This test verifies that the LRU cache decorator properly handles
-        deep copying to prevent unintended modifications to cached objects.
-        Parameters
-        -----------
-        verbose : bool, optional
-            Whether to print verbose output, by default False
-        """
         if verbose:
             print("Testing that cache values does not get modified externally")
         from torch_harmonics.legendre import _precompute_legpoly
......
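The removed docstring is the only place the diff spells out what this test guards: values handed out by the LRU cache must be deep copies, so a caller mutating its result cannot corrupt the cached original. The sketch below illustrates that property only; `copying_lru_cache` is a hypothetical name, not the actual decorator used by `torch_harmonics`.

```python
import copy
import functools


def copying_lru_cache(maxsize=128):
    """Hypothetical sketch: functools.lru_cache that returns deep copies."""
    def decorator(fn):
        cached = functools.lru_cache(maxsize=maxsize)(fn)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # hand out a copy so in-place edits never reach the cache
            return copy.deepcopy(cached(*args, **kwargs))

        return wrapper
    return decorator


@copying_lru_cache(maxsize=8)
def make_table(n):
    return [float(i) for i in range(n)]


first = make_table(4)
first[0] = -1.0                   # mutate the caller's copy
assert make_table(4)[0] == 0.0    # the cached original is untouched
```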
@@ -161,38 +161,7 @@ def _precompute_convolution_tensor_dense(
 @parameterized_class(("device"), _devices)
 class TestDiscreteContinuousConvolution(unittest.TestCase):
-    """
-    Test the discrete-continuous convolution module.
-    Parameters
-    ----------
-    batch_size : int
-        Batch size
-    in_channels : int
-        Number of input channels
-    out_channels : int
-        Number of output channels
-    in_shape : tuple
-        Input shape (height, width)
-    out_shape : tuple
-        Output shape (height, width)
-    kernel_shape : tuple
-        Kernel shape
-    basis_type : str
-        Basis type
-    basis_norm_mode : str
-        Basis normalization mode
-    grid_in : str
-        Grid type for input
-    grid_out : str
-        Grid type for output
-    transpose : bool
-        Whether to transpose the convolution
-    tol : float
-        Tolerance for numerical equivalence
-    verbose : bool, optional
-        Whether to print verbose output, by default False
-    """
+    """Test the discrete-continuous convolution module (CPU/CUDA if available)."""
     def setUp(self):
         torch.manual_seed(333)
......
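The parameters enumerated in the removed docstring are the knobs of the test parameterization and, for the most part, of the module under test. A hedged usage sketch follows; the constructor of `DiscreteContinuousConvS2` is assumed to mirror those parameter names, and the shapes and values are illustrative, not taken from the test suite.

```python
import torch
from torch_harmonics import DiscreteContinuousConvS2

# Parameter names assumed to mirror the removed docstring; exact
# signature and defaults may differ in the installed version.
conv = DiscreteContinuousConvS2(
    in_channels=4,
    out_channels=8,
    in_shape=(33, 64),        # (nlat, nlon) of the input grid
    out_shape=(17, 32),       # (nlat, nlon) of the output grid
    kernel_shape=[3],         # isotropic kernel with 3 rings
    grid_in="equiangular",
    grid_out="equiangular",
)

x = torch.randn(2, 4, 33, 64)  # (batch, channels, nlat, nlon)
y = conv(x)                    # -> (2, 8, 17, 32)
```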
@@ -41,40 +41,7 @@ import torch_harmonics.distributed as thd
 class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
-    """
-    Test the distributed discrete-continuous convolution module.
-    Parameters
-    ----------
-    nlat_in : int
-        Number of latitude points in input
-    nlon_in : int
-        Number of longitude points in input
-    nlat_out : int
-        Number of latitude points in output
-    nlon_out : int
-        Number of longitude points in output
-    batch_size : int
-        Batch size
-    num_chan : int
-        Number of channels
-    kernel_shape : tuple
-        Kernel shape
-    basis_type : str
-        Basis type
-    basis_norm_mode : str
-        Basis normalization mode
-    groups : int
-        Number of groups
-    grid_in : str
-        Grid type for input
-    grid_out : str
-        Grid type for output
-    transpose : bool
-        Whether to transpose the convolution
-    tol : float
-        Tolerance for numerical equivalence
-    """
+    """Test the distributed discrete-continuous convolution module."""
     @classmethod
     def setUpClass(cls):
......
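The distributed test classes all keep their `setUpClass` hooks. The sketch below shows the generic pattern such a hook follows (a single-process world on the `gloo` backend); it is not the repository's actual implementation, which builds its process grid via `torch_harmonics.distributed`.

```python
import os
import unittest

import torch.distributed as dist


class DistributedTestBase(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Illustrative single-process "world"; real runs launch via
        # torchrun with one rank per GPU and a larger world size.
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group(backend="gloo", rank=0, world_size=1)

    @classmethod
    def tearDownClass(cls):
        dist.destroy_process_group()
```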
@@ -41,34 +41,7 @@ import torch_harmonics.distributed as thd
 class TestDistributedResampling(unittest.TestCase):
-    """
-    Test the distributed resampling module.
-    Parameters
-    ----------
-    nlat_in : int
-        Number of latitude points in input
-    nlon_in : int
-        Number of longitude points in input
-    nlat_out : int
-        Number of latitude points in output
-    nlon_out : int
-        Number of longitude points in output
-    batch_size : int
-        Batch size
-    num_chan : int
-        Number of channels
-    grid_in : str
-        Grid type for input
-    grid_out : str
-        Grid type for output
-    mode : str
-        Resampling mode
-    tol : float
-        Tolerance for numerical equivalence
-    verbose : bool
-        Whether to print verbose output
-    """
+    """Test the distributed resampling module (CPU/CUDA if available)."""
     @classmethod
     def setUpClass(cls):
......
@@ -41,26 +41,7 @@ import torch_harmonics.distributed as thd
 class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
-    """
-    Test the distributed spherical harmonic transform module.
-    Parameters
-    ----------
-    nlat : int
-        Number of latitude points
-    nlon : int
-        Number of longitude points
-    batch_size : int
-        Batch size
-    num_chan : int
-        Number of channels
-    grid : str
-        Grid type
-    vector : bool
-        Whether to use vector spherical harmonic transform
-    tol : float
-        Tolerance for numerical equivalence
-    """
+    """Test the distributed spherical harmonic transform module (CPU/CUDA if available)."""
     @classmethod
     def setUpClass(cls):
......
@@ -42,14 +42,7 @@ if torch.cuda.is_available():
 class TestLegendrePolynomials(unittest.TestCase):
-    """
-    Test the associated Legendre polynomials.
-    Parameters
-    ----------
-    verbose : bool, optional
-        Whether to print verbose output, by default False
-    """
+    """Test the associated Legendre polynomials (CPU/CUDA if available)."""
     def setUp(self):
         self.cml = lambda m, l: math.sqrt((2 * l + 1) / 4 / math.pi) * math.sqrt(math.factorial(l - m) / math.factorial(l + m))
         self.pml = dict()
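For reference, the `self.cml` lambda kept in the context above is the standard orthonormalization constant of the associated Legendre functions:

```math
c_l^m = \sqrt{\frac{2l+1}{4\pi}}\,\sqrt{\frac{(l-m)!}{(l+m)!}}
```

This is the factor that makes the spherical harmonics `Y_l^m = c_l^m P_l^m(cos θ) e^{imφ}` orthonormal over the sphere, which is what the assertions in this test class check against.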
@@ -94,14 +87,7 @@ class TestLegendrePolynomials(unittest.TestCase):
 @parameterized_class(("device"), _devices)
 class TestSphericalHarmonicTransform(unittest.TestCase):
-    """
-    Test the spherical harmonic transform.
-    Parameters
-    ----------
-    verbose : bool, optional
-        Whether to print verbose output, by default False
-    """
+    """Test the spherical harmonic transform (CPU/CUDA if available)."""
     def setUp(self):
         if torch.cuda.is_available():
             self.device = torch.device("cuda")
......
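What the last test class exercises is easiest to see as a transform round trip. Below is a minimal sketch using the public `RealSHT`/`InverseRealSHT` pair; the grid sizes and tolerances are illustrative, not the ones used by the suite.

```python
import torch
import torch_harmonics as th

nlat, nlon = 64, 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

sht = th.RealSHT(nlat, nlon, grid="equiangular").to(device)
isht = th.InverseRealSHT(nlat, nlon, grid="equiangular").to(device)

# project random data once so the test signal is band-limited
signal = isht(sht(torch.randn(2, 3, nlat, nlon, device=device)))

# a band-limited signal should survive the round trip to within tolerance
recovered = isht(sht(signal))
torch.testing.assert_close(recovered, signal, atol=1e-4, rtol=1e-4)
```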