OpenDAS / torch-harmonics · Commits

Commit c44d4b23
Authored Jul 17, 2025 by Andrea Paris
Committed by Boris Bonev, Jul 21, 2025

simplified docstrings for test classes

Parent: ca46b9d2
Changes: 7 changed files with 7 additions and 167 deletions (+7 -167)
tests/test_attention.py                  +1  -25
tests/test_cache.py                      +0  -12
tests/test_convolution.py                +1  -32
tests/test_distributed_convolution.py    +1  -34
tests/test_distributed_resample.py       +1  -28
tests/test_distributed_sht.py            +1  -20
tests/test_sht.py                        +2  -16
tests/test_attention.py

@@ -67,31 +67,7 @@ _perf_test_thresholds = {"fwd_ms": 50, "bwd_ms": 150}

 @parameterized_class(("device"), _devices)
 class TestNeighborhoodAttentionS2(unittest.TestCase):
-    """
-    Test the neighborhood attention module.
-    Parameters
-    ----------
-    batch_size : int
-        Batch size
-    channels : int
-        Number of channels
-    heads : int
-        Number of heads
-    in_shape : tuple
-        Input shape (height, width)
-    out_shape : tuple
-        Output shape (height, width)
-    grid_in : str
-        Grid type for input
-    grid_out : str
-        Grid type for output
-    atol : float
-        Absolute tolerance for numerical equivalence
-    rtol : float
-        Relative tolerance for numerical equivalence
-    verbose : bool, optional
-        Whether to print verbose output, by default True
-    """
+    """Test the neighborhood attention module (CPU/CUDA if available)."""

     def setUp(self):
         torch.manual_seed(333)
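The @parameterized_class(("device"), _devices) decorator (from the parameterized package) instantiates the test class once per entry in _devices, so the same checks run on CPU and, when present, on CUDA. A minimal sketch of the pattern; the _devices list and the ExampleDeviceTest class here are assumptions for illustration, not the repository's actual definitions:

    import unittest

    import torch
    from parameterized import parameterized_class

    # Hypothetical device list: always test on CPU, add CUDA when a GPU is present.
    _devices = [(torch.device("cpu"),)]
    if torch.cuda.is_available():
        _devices.append((torch.device("cuda"),))


    @parameterized_class(("device",), _devices)
    class ExampleDeviceTest(unittest.TestCase):
        def test_tensor_lands_on_device(self):
            # "device" is injected as a class attribute by parameterized_class.
            x = torch.randn(4, device=self.device)
            self.assertEqual(x.device.type, self.device.type)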
tests/test_cache.py

@@ -36,19 +36,7 @@ import torch

 class TestCacheConsistency(unittest.TestCase):

     def test_consistency(self, verbose=False):
-        """
-        Test that cached values are not modified externally.
-
-        This test verifies that the LRU cache decorator properly handles
-        deep copying to prevent unintended modifications to cached objects.
-
-        Parameters
-        -----------
-        verbose : bool, optional
-            Whether to print verbose output, by default False
-        """
         if verbose:
             print("Testing that cache values does not get modified externally")

         from torch_harmonics.legendre import _precompute_legpoly
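The removed docstring summarized the behaviour under test: values returned from the cached Legendre precomputation should be defensive copies, so mutating a returned object must not corrupt the cached entry. A minimal sketch of a cache wrapper with that property; copying_lru_cache and expensive are hypothetical names, not the decorator actually used in torch_harmonics:

    import copy
    import functools

    def copying_lru_cache(maxsize=128):
        """LRU cache that hands out deep copies so callers cannot mutate cached values."""
        def decorator(func):
            cached = functools.lru_cache(maxsize=maxsize)(func)

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                # Deep-copy on the way out: the object stored in the cache stays untouched.
                return copy.deepcopy(cached(*args, **kwargs))

            return wrapper
        return decorator

    @copying_lru_cache(maxsize=8)
    def expensive(n):
        return [i * i for i in range(n)]

    values = expensive(4)
    values.append(999)                    # mutate the returned copy only
    assert expensive(4) == [0, 1, 4, 9]   # the cached entry is unaffected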
tests/test_convolution.py

@@ -161,38 +161,7 @@ def _precompute_convolution_tensor_dense(

 @parameterized_class(("device"), _devices)
 class TestDiscreteContinuousConvolution(unittest.TestCase):
-    """
-    Test the discrete-continuous convolution module.
-    Parameters
-    ----------
-    batch_size : int
-        Batch size
-    in_channels : int
-        Number of input channels
-    out_channels : int
-        Number of output channels
-    in_shape : tuple
-        Input shape (height, width)
-    out_shape : tuple
-        Output shape (height, width)
-    kernel_shape : tuple
-        Kernel shape
-    basis_type : str
-        Basis type
-    basis_norm_mode : str
-        Basis normalization mode
-    grid_in : str
-        Grid type for input
-    grid_out : str
-        Grid type for output
-    transpose : bool
-        Whether to transpose the convolution
-    tol : float
-        Tolerance for numerical equivalence
-    verbose : bool, optional
-        Whether to print verbose output, by default False
-    """
+    """Test the discrete-continuous convolution module (CPU/CUDA if available)."""

     def setUp(self):
         torch.manual_seed(333)
tests/test_distributed_convolution.py

@@ -41,40 +41,7 @@ import torch_harmonics.distributed as thd

 class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
-    """
-    Test the distributed discrete-continuous convolution module.
-    Parameters
-    ----------
-    nlat_in : int
-        Number of latitude points in input
-    nlon_in : int
-        Number of longitude points in input
-    nlat_out : int
-        Number of latitude points in output
-    nlon_out : int
-        Number of longitude points in output
-    batch_size : int
-        Batch size
-    num_chan : int
-        Number of channels
-    kernel_shape : tuple
-        Kernel shape
-    basis_type : str
-        Basis type
-    basis_norm_mode : str
-        Basis normalization mode
-    groups : int
-        Number of groups
-    grid_in : str
-        Grid type for input
-    grid_out : str
-        Grid type for output
-    transpose : bool
-        Whether to transpose the convolution
-    tol : float
-        Tolerance for numerical equivalence
-    """
+    """Test the distributed discrete-continuous convolution module."""

     @classmethod
     def setUpClass(cls):
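The distributed test classes set up their per-process state in setUpClass, whose body is collapsed in this diff. As a rough, assumed sketch of what such a hook commonly looks like in PyTorch (ExampleDistributedTest is hypothetical and this is not the actual torch_harmonics.distributed initialization):

    import os
    import unittest

    import torch
    import torch.distributed as dist

    class ExampleDistributedTest(unittest.TestCase):
        @classmethod
        def setUpClass(cls):
            # Rank/world size are assumed to be injected by the launcher (e.g. torchrun),
            # which also sets MASTER_ADDR/MASTER_PORT for the default env:// rendezvous.
            cls.world_rank = int(os.environ.get("RANK", 0))
            cls.world_size = int(os.environ.get("WORLD_SIZE", 1))

            if torch.cuda.is_available():
                cls.device = torch.device("cuda", cls.world_rank % torch.cuda.device_count())
                backend = "nccl"
            else:
                cls.device = torch.device("cpu")
                backend = "gloo"

            if not dist.is_initialized():
                dist.init_process_group(backend=backend, rank=cls.world_rank, world_size=cls.world_size)

        @classmethod
        def tearDownClass(cls):
            # Tear the process group down so subsequent test classes can re-initialize cleanly.
            if dist.is_initialized():
                dist.destroy_process_group()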
tests/test_distributed_resample.py

@@ -41,34 +41,7 @@ import torch_harmonics.distributed as thd

 class TestDistributedResampling(unittest.TestCase):
-    """
-    Test the distributed resampling module.
-    Parameters
-    ----------
-    nlat_in : int
-        Number of latitude points in input
-    nlon_in : int
-        Number of longitude points in input
-    nlat_out : int
-        Number of latitude points in output
-    nlon_out : int
-        Number of longitude points in output
-    batch_size : int
-        Batch size
-    num_chan : int
-        Number of channels
-    grid_in : str
-        Grid type for input
-    grid_out : str
-        Grid type for output
-    mode : str
-        Resampling mode
-    tol : float
-        Tolerance for numerical equivalence
-    verbose : bool
-        Whether to print verbose output
-    """
+    """Test the distributed resampling module (CPU/CUDA if available)."""

     @classmethod
     def setUpClass(cls):
tests/test_distributed_sht.py

@@ -41,26 +41,7 @@ import torch_harmonics.distributed as thd

 class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
-    """
-    Test the distributed spherical harmonic transform module.
-    Parameters
-    ----------
-    nlat : int
-        Number of latitude points
-    nlon : int
-        Number of longitude points
-    batch_size : int
-        Batch size
-    num_chan : int
-        Number of channels
-    grid : str
-        Grid type
-    vector : bool
-        Whether to use vector spherical harmonic transform
-    tol : float
-        Tolerance for numerical equivalence
-    """
+    """Test the distributed spherical harmonic transform module (CPU/CUDA if available)."""

     @classmethod
     def setUpClass(cls):
tests/test_sht.py

@@ -42,14 +42,7 @@ if torch.cuda.is_available():

 class TestLegendrePolynomials(unittest.TestCase):
-    """
-    Test the associated Legendre polynomials.
-    Parameters
-    ----------
-    verbose : bool, optional
-        Whether to print verbose output, by default False
-    """
+    """Test the associated Legendre polynomials (CPU/CUDA if available)."""

     def setUp(self):
         self.cml = lambda m, l: math.sqrt((2 * l + 1) / 4 / math.pi) * math.sqrt(math.factorial(l - m) / math.factorial(l + m))
         self.pml = dict()
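The setUp retained above defines cml, the standard orthonormalization factor for the associated Legendre functions entering the spherical harmonics. Written out as math (a restatement of what the lambda computes):

    C_\ell^m = \sqrt{\frac{2\ell + 1}{4\pi}\,\frac{(\ell - m)!}{(\ell + m)!}},
    \qquad
    Y_\ell^m(\theta, \varphi) = C_\ell^m \, P_\ell^m(\cos\theta)\, e^{i m \varphi},

so that the Y_\ell^m are orthonormal on the unit sphere.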
@@ -94,14 +87,7 @@ class TestLegendrePolynomials(unittest.TestCase):

 @parameterized_class(("device"), _devices)
 class TestSphericalHarmonicTransform(unittest.TestCase):
-    """
-    Test the spherical harmonic transform.
-    Parameters
-    ----------
-    verbose : bool, optional
-        Whether to print verbose output, by default False
-    """
+    """Test the spherical harmonic transform (CPU/CUDA if available)."""

     def setUp(self):
         if torch.cuda.is_available():
             self.device = torch.device("cuda")