Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
82e7860d
Commit
82e7860d
authored
Jun 30, 2025
by
apaaris
Committed by
Boris Bonev
Jul 21, 2025
Browse files
Improved docstrings in tests
parent
290da8e0
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
485 additions
and
6 deletions
+485
-6
tests/test_attention.py
tests/test_attention.py
+26
-0
tests/test_convolution.py
tests/test_convolution.py
+81
-0
tests/test_distributed_convolution.py
tests/test_distributed_convolution.py
+124
-0
tests/test_distributed_resample.py
tests/test_distributed_resample.py
+108
-0
tests/test_distributed_sht.py
tests/test_distributed_sht.py
+118
-1
tests/test_sht.py
tests/test_sht.py
+28
-5
No files found.
tests/test_attention.py
View file @
82e7860d
...
...
@@ -67,6 +67,32 @@ _perf_test_thresholds = {"fwd_ms": 50, "bwd_ms": 150}
@
parameterized_class
((
"device"
),
_devices
)
class
TestNeighborhoodAttentionS2
(
unittest
.
TestCase
):
"""
Test the neighborhood attention module.
Parameters
----------
batch_size : int
Batch size
channels : int
Number of channels
heads : int
Number of heads
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
grid_in : str
Grid type for input
grid_out : str
Grid type for output
atol : float
Absolute tolerance for numerical equivalence
rtol : float
Relative tolerance for numerical equivalence
verbose : bool, optional
Whether to print verbose output, by default True
"""
def
setUp
(
self
):
torch
.
manual_seed
(
333
)
if
self
.
device
.
type
==
"cuda"
:
...
...
tests/test_convolution.py
View file @
82e7860d
...
...
@@ -49,6 +49,26 @@ if torch.cuda.is_available():
def
_normalize_convolution_tensor_dense
(
psi
,
quad_weights
,
transpose_normalization
=
False
,
basis_norm_mode
=
"none"
,
merge_quadrature
=
False
,
eps
=
1e-9
):
"""
Discretely normalizes the convolution tensor.
Parameters
----------
psi : torch.Tensor
Convolution tensor
quad_weights : torch.Tensor
Quadrature weights
transpose_normalization : bool, optional
Whether to transpose the normalization, by default False
basis_norm_mode : str, optional
Basis normalization mode, by default "none"
merge_quadrature : bool, optional
Whether to merge the quadrature, by default False
eps : float, optional
Epsilon for numerical stability, by default 1e-9
Returns
-------
torch.Tensor
Normalized convolution tensor
"""
kernel_size
,
nlat_out
,
nlon_out
,
nlat_in
,
nlon_in
=
psi
.
shape
...
...
@@ -100,6 +120,34 @@ def _precompute_convolution_tensor_dense(
):
"""
Helper routine to compute the convolution Tensor in a dense fashion
Parameters
----------
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
filter_basis : FilterBasis
Filter basis
grid_in : str
Grid type for input
grid_out : str
Grid type for output
theta_cutoff : float, optional
Theta cutoff
theta_eps : float, optional
Theta epsilon
transpose_normalization : bool, optional
Whether to transpose the normalization, by default False
basis_norm_mode : str, optional
Basis normalization mode, by default "none"
merge_quadrature : bool, optional
Whether to merge the quadrature, by default False
Returns
-------
torch.Tensor
Convolution tensor
"""
assert
len
(
in_shape
)
==
2
...
...
@@ -168,6 +216,39 @@ def _precompute_convolution_tensor_dense(
@
parameterized_class
((
"device"
),
_devices
)
class
TestDiscreteContinuousConvolution
(
unittest
.
TestCase
):
"""
Test the discrete-continuous convolution module.
Parameters
----------
batch_size : int
Batch size
in_channels : int
Number of input channels
out_channels : int
Number of output channels
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
kernel_shape : tuple
Kernel shape
basis_type : str
Basis type
basis_norm_mode : str
Basis normalization mode
grid_in : str
Grid type for input
grid_out : str
Grid type for output
transpose : bool
Whether to transpose the convolution
tol : float
Tolerance for numerical equivalence
verbose : bool, optional
Whether to print verbose output, by default False
"""
def
setUp
(
self
):
torch
.
manual_seed
(
333
)
if
self
.
device
.
type
==
"cuda"
:
...
...
tests/test_distributed_convolution.py
View file @
82e7860d
...
...
@@ -41,9 +41,51 @@ import torch_harmonics.distributed as thd
class
TestDistributedDiscreteContinuousConvolution
(
unittest
.
TestCase
):
"""
Test the distributed discrete-continuous convolution module.
Parameters
----------
nlat_in : int
Number of latitude points in input
nlon_in : int
Number of longitude points in input
nlat_out : int
Number of latitude points in output
nlon_out : int
Number of longitude points in output
batch_size : int
Batch size
num_chan : int
Number of channels
kernel_shape : tuple
Kernel shape
basis_type : str
Basis type
basis_norm_mode : str
Basis normalization mode
groups : int
Number of groups
grid_in : str
Grid type for input
grid_out : str
Grid type for output
transpose : bool
Whether to transpose the convolution
tol : float
Tolerance for numerical equivalence
"""
@
classmethod
def
setUpClass
(
cls
):
"""
Set up the distributed convolution test.
Parameters
----------
cls : TestDistributedDiscreteContinuousConvolution
The test class instance
"""
# set up distributed
cls
.
world_rank
=
int
(
os
.
getenv
(
"WORLD_RANK"
,
0
))
...
...
@@ -114,10 +156,28 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
@
classmethod
def
tearDownClass
(
cls
):
"""
Tear down the distributed convolution test.
Parameters
----------
cls : TestDistributedDiscreteContinuousConvolution
The test class instance
"""
thd
.
finalize
()
dist
.
destroy_process_group
(
None
)
def
_split_helper
(
self
,
tensor
):
"""
Split the tensor along the horizontal and vertical dimensions.
Parameters
----------
tensor : torch.Tensor
The tensor to split
"""
with
torch
.
no_grad
():
# split in W
tensor_list_local
=
thd
.
split_tensor_along_dim
(
tensor
,
dim
=-
1
,
num_chunks
=
self
.
grid_size_w
)
...
...
@@ -130,6 +190,21 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
return
tensor_local
def
_gather_helper_fwd
(
self
,
tensor
,
B
,
C
,
convolution_dist
):
"""
Gather the tensor along the horizontal and vertical dimensions.
Parameters
----------
tensor : torch.Tensor
The tensor to gather
B : int
Batch size
C : int
Number of channels
convolution_dist : thd.DistributedDiscreteContinuousConvTransposeS2 or thd.DistributedDiscreteContinuousConvS2
The distributed convolution object
"""
# we need the shapes
lat_shapes
=
convolution_dist
.
lat_out_shapes
lon_shapes
=
convolution_dist
.
lon_out_shapes
...
...
@@ -157,6 +232,20 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
return
tensor_gather
def
_gather_helper_bwd
(
self
,
tensor
,
B
,
C
,
convolution_dist
):
"""
Gather the tensor along the horizontal and vertical dimensions.
Parameters
----------
tensor : torch.Tensor
The tensor to gather
B : int
Batch size
C : int
Number of channels
convolution_dist : thd.DistributedDiscreteContinuousConvTransposeS2 or thd.DistributedDiscreteContinuousConvS2
The distributed convolution object
"""
# we need the shapes
lat_shapes
=
convolution_dist
.
lat_in_shapes
lon_shapes
=
convolution_dist
.
lon_in_shapes
...
...
@@ -205,6 +294,41 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
self
,
nlat_in
,
nlon_in
,
nlat_out
,
nlon_out
,
batch_size
,
num_chan
,
kernel_shape
,
basis_type
,
basis_norm_mode
,
groups
,
grid_in
,
grid_out
,
transpose
,
tol
):
"""
Test the distributed discrete-continuous convolution module.
Parameters
----------
nlat_in : int
Number of latitude points in input
nlon_in : int
Number of longitude points in input
nlat_out : int
Number of latitude points in output
nlon_out : int
Number of longitude points in output
batch_size : int
Batch size
num_chan : int
Number of channels
kernel_shape : tuple
Kernel shape
basis_type : str
Basis type
basis_norm_mode : str
Basis normalization mode
groups : int
Number of groups
grid_in : str
Grid type for input
grid_out : str
Grid type for output
transpose : bool
Whether to transpose the convolution
tol : float
Tolerance for numerical equivalence
"""
B
,
C
,
H
,
W
=
batch_size
,
num_chan
,
nlat_in
,
nlon_in
disco_args
=
dict
(
...
...
tests/test_distributed_resample.py
View file @
82e7860d
...
...
@@ -41,6 +41,34 @@ import torch_harmonics.distributed as thd
class
TestDistributedResampling
(
unittest
.
TestCase
):
"""
Test the distributed resampling module.
Parameters
----------
nlat_in : int
Number of latitude points in input
nlon_in : int
Number of longitude points in input
nlat_out : int
Number of latitude points in output
nlon_out : int
Number of longitude points in output
batch_size : int
Batch size
num_chan : int
Number of channels
grid_in : str
Grid type for input
grid_out : str
Grid type for output
mode : str
Resampling mode
tol : float
Tolerance for numerical equivalence
verbose : bool
Whether to print verbose output
"""
@
classmethod
def
setUpClass
(
cls
):
...
...
@@ -118,6 +146,19 @@ class TestDistributedResampling(unittest.TestCase):
dist
.
destroy_process_group
(
None
)
def
_split_helper
(
self
,
tensor
):
"""
Split the tensor along the last dimension into chunks along the W dimension, and then along the H dimension.
Parameters
----------
tensor : torch.Tensor
The tensor to split
Returns
-------
torch.Tensor
The split tensor
"""
with
torch
.
no_grad
():
# split in W
tensor_list_local
=
thd
.
split_tensor_along_dim
(
tensor
,
dim
=-
1
,
num_chunks
=
self
.
grid_size_w
)
...
...
@@ -130,6 +171,25 @@ class TestDistributedResampling(unittest.TestCase):
return
tensor_local
def
_gather_helper_fwd
(
self
,
tensor
,
B
,
C
,
convolution_dist
):
"""
Gather the tensor along the W and H dimensions.
Parameters
----------
tensor : torch.Tensor
The tensor to gather
B : int
Batch size
C : int
Number of channels
convolution_dist : thd.DistributedResampleS2
The distributed resampling object
Returns
-------
torch.Tensor
The gathered tensor
"""
# we need the shapes
lat_shapes
=
convolution_dist
.
lat_out_shapes
lon_shapes
=
convolution_dist
.
lon_out_shapes
...
...
@@ -157,6 +217,26 @@ class TestDistributedResampling(unittest.TestCase):
return
tensor_gather
def
_gather_helper_bwd
(
self
,
tensor
,
B
,
C
,
resampling_dist
):
"""
Gather the tensor along the W and H dimensions.
Parameters
----------
tensor : torch.Tensor
The tensor to gather
B : int
Batch size
C : int
Number of channels
resampling_dist : thd.DistributedResampleS2
The distributed resampling object
Returns
-------
torch.Tensor
The gathered tensor
"""
# we need the shapes
lat_shapes
=
resampling_dist
.
lat_in_shapes
lon_shapes
=
resampling_dist
.
lon_in_shapes
...
...
@@ -196,6 +276,34 @@ class TestDistributedResampling(unittest.TestCase):
def
test_distributed_resampling
(
self
,
nlat_in
,
nlon_in
,
nlat_out
,
nlon_out
,
batch_size
,
num_chan
,
grid_in
,
grid_out
,
mode
,
tol
,
verbose
):
"""
Test the distributed resampling module.
Parameters
----------
nlat_in : int
Number of latitude points in input
nlon_in : int
Number of longitude points in input
nlat_out : int
Number of latitude points in output
nlon_out : int
Number of longitude points in output
batch_size : int
Batch size
num_chan : int
Number of channels
grid_in : str
Grid type for input
grid_out : str
Grid type for output
mode : str
Resampling mode
tol : float
Tolerance for numerical equivalence
verbose : bool
Whether to print verbose output
"""
B
,
C
,
H
,
W
=
batch_size
,
num_chan
,
nlat_in
,
nlon_in
...
...
tests/test_distributed_sht.py
View file @
82e7860d
...
...
@@ -41,10 +41,29 @@ import torch_harmonics.distributed as thd
class
TestDistributedSphericalHarmonicTransform
(
unittest
.
TestCase
):
"""
Test the distributed spherical harmonic transform module.
Parameters
----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
batch_size : int
Batch size
num_chan : int
Number of channels
grid : str
Grid type
vector : bool
Whether to use vector spherical harmonic transform
tol : float
Tolerance for numerical equivalence
"""
@
classmethod
def
setUpClass
(
cls
):
# set up distributed
cls
.
world_rank
=
int
(
os
.
getenv
(
"WORLD_RANK"
,
0
))
cls
.
grid_size_h
=
int
(
os
.
getenv
(
"GRID_H"
,
1
))
...
...
@@ -120,6 +139,19 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
dist
.
destroy_process_group
(
None
)
def
_split_helper
(
self
,
tensor
):
"""
Split the tensor along the W and H dimensions.
Parameters
----------
tensor : torch.Tensor
The tensor to split
Returns
-------
torch.Tensor
The split tensor
"""
with
torch
.
no_grad
():
# split in W
tensor_list_local
=
thd
.
split_tensor_along_dim
(
tensor
,
dim
=-
1
,
num_chunks
=
self
.
grid_size_w
)
...
...
@@ -132,6 +164,27 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
return
tensor_local
def
_gather_helper_fwd
(
self
,
tensor
,
B
,
C
,
transform_dist
,
vector
):
"""
Gather the tensor along the W and H dimensions.
Parameters
----------
tensor : torch.Tensor
The tensor to gather
B : int
Batch size
C : int
Number of channels
transform_dist : thd.DistributedRealSHT or thd.DistributedRealVectorSHT
The distributed transform
vector : bool
Whether to use vector spherical harmonic transform
Returns
-------
torch.Tensor
The gathered tensor
"""
# we need the shapes
l_shapes
=
transform_dist
.
l_shapes
m_shapes
=
transform_dist
.
m_shapes
...
...
@@ -163,6 +216,28 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
return
tensor_gather
def
_gather_helper_bwd
(
self
,
tensor
,
B
,
C
,
transform_dist
,
vector
):
"""
Gather the tensor along the W and H dimensions.
Parameters
----------
tensor : torch.Tensor
The tensor to gather
B : int
Batch size
C : int
Number of channels
transform_dist : thd.DistributedRealSHT or thd.DistributedRealVectorSHT
The distributed transform
vector : bool
Whether to use vector spherical harmonic transform
Returns
-------
torch.Tensor
The gathered tensor
"""
# we need the shapes
lat_shapes
=
transform_dist
.
lat_shapes
lon_shapes
=
transform_dist
.
lon_shapes
...
...
@@ -214,6 +289,27 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
]
)
def
test_distributed_sht
(
self
,
nlat
,
nlon
,
batch_size
,
num_chan
,
grid
,
vector
,
tol
):
"""
Test the distributed spherical harmonic transform.
Parameters
----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
batch_size : int
Batch size
num_chan : int
Number of channels
grid : str
Grid type
vector : bool
Whether to use vector spherical harmonic transform
tol : float
Tolerance for numerical equivalence
"""
B
,
C
,
H
,
W
=
batch_size
,
num_chan
,
nlat
,
nlon
# set up handles
...
...
@@ -301,6 +397,27 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
]
)
def
test_distributed_isht
(
self
,
nlat
,
nlon
,
batch_size
,
num_chan
,
grid
,
vector
,
tol
):
"""
Test the distributed inverse spherical harmonic transform.
Parameters
----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
batch_size : int
Batch size
num_chan : int
Number of channels
grid : str
Grid type
vector : bool
Whether to use vector spherical harmonic transform
tol : float
Tolerance for numerical equivalence
"""
B
,
C
,
H
,
W
=
batch_size
,
num_chan
,
nlat
,
nlon
if
vector
:
...
...
tests/test_sht.py
View file @
82e7860d
...
...
@@ -42,7 +42,14 @@ if torch.cuda.is_available():
class
TestLegendrePolynomials
(
unittest
.
TestCase
):
"""
Test the associated Legendre polynomials.
Parameters
----------
verbose : bool, optional
Whether to print verbose output, by default False
"""
def
setUp
(
self
):
self
.
cml
=
lambda
m
,
l
:
math
.
sqrt
((
2
*
l
+
1
)
/
4
/
math
.
pi
)
*
math
.
sqrt
(
math
.
factorial
(
l
-
m
)
/
math
.
factorial
(
l
+
m
))
self
.
pml
=
dict
()
...
...
@@ -65,6 +72,14 @@ class TestLegendrePolynomials(unittest.TestCase):
self
.
tol
=
1e-9
def
test_legendre
(
self
,
verbose
=
False
):
"""
Test the computation of associated Legendre polynomials.
Parameters
----------
verbose : bool, optional
Whether to print verbose output, by default False
"""
if
verbose
:
print
(
"Testing computation of associated Legendre polynomials"
)
...
...
@@ -79,11 +94,19 @@ class TestLegendrePolynomials(unittest.TestCase):
@
parameterized_class
((
"device"
),
_devices
)
class
TestSphericalHarmonicTransform
(
unittest
.
TestCase
):
"""
Test the spherical harmonic transform.
Parameters
----------
verbose : bool, optional
Whether to print verbose output, by default False
"""
def
setUp
(
self
):
torch
.
manual_seed
(
333
)
if
self
.
device
.
type
==
"cuda"
:
torch
.
cuda
.
manual_seed
(
333
)
if
torch
.
cuda
.
is_available
():
self
.
device
=
torch
.
device
(
"cuda"
)
else
:
self
.
device
=
torch
.
device
(
"cpu"
)
@
parameterized
.
expand
(
[
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment