Commit 9c26a6d8 authored by Andrea Paris's avatar Andrea Paris Committed by Boris Bonev
Browse files

removed docstrings from internal functions

parent 328200ab
...@@ -47,29 +47,6 @@ if torch.cuda.is_available(): ...@@ -47,29 +47,6 @@ if torch.cuda.is_available():
def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalization=False, basis_norm_mode="none", merge_quadrature=False, eps=1e-9): def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalization=False, basis_norm_mode="none", merge_quadrature=False, eps=1e-9):
"""
Discretely normalizes the convolution tensor.
Parameters
----------
psi : torch.Tensor
Convolution tensor
quad_weights : torch.Tensor
Quadrature weights
transpose_normalization : bool, optional
Whether to transpose the normalization, by default False
basis_norm_mode : str, optional
Basis normalization mode, by default "none"
merge_quadrature : bool, optional
Whether to merge the quadrature, by default False
eps : float, optional
Epsilon for numerical stability, by default 1e-9
Returns
-------
torch.Tensor
Normalized convolution tensor
"""
kernel_size, nlat_out, nlon_out, nlat_in, nlon_in = psi.shape kernel_size, nlat_out, nlon_out, nlat_in, nlon_in = psi.shape
correction_factor = nlon_out / nlon_in correction_factor = nlon_out / nlon_in
...@@ -118,38 +95,6 @@ def _precompute_convolution_tensor_dense( ...@@ -118,38 +95,6 @@ def _precompute_convolution_tensor_dense(
basis_norm_mode="none", basis_norm_mode="none",
merge_quadrature=False, merge_quadrature=False,
): ):
"""
Helper routine to compute the convolution Tensor in a dense fashion
Parameters
----------
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
filter_basis : FilterBasis
Filter basis
grid_in : str
Grid type for input
grid_out : str
Grid type for output
theta_cutoff : float, optional
Theta cutoff
theta_eps : float, optional
Theta epsilon
transpose_normalization : bool, optional
Whether to transpose the normalization, by default False
basis_norm_mode : str, optional
Basis normalization mode, by default "none"
merge_quadrature : bool, optional
Whether to merge the quadrature, by default False
Returns
-------
torch.Tensor
Convolution tensor
"""
assert len(in_shape) == 2 assert len(in_shape) == 2
assert len(out_shape) == 2 assert len(out_shape) == 2
......
...@@ -156,27 +156,10 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -156,27 +156,10 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
"""
Tear down the distributed convolution test.
Parameters
----------
cls : TestDistributedDiscreteContinuousConvolution
The test class instance
"""
thd.finalize() thd.finalize()
dist.destroy_process_group(None) dist.destroy_process_group(None)
def _split_helper(self, tensor): def _split_helper(self, tensor):
"""
Split the tensor along the horizontal and vertical dimensions.
Parameters
----------
tensor : torch.Tensor
The tensor to split
"""
with torch.no_grad(): with torch.no_grad():
# split in W # split in W
...@@ -190,20 +173,6 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -190,20 +173,6 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
return tensor_local return tensor_local
def _gather_helper_fwd(self, tensor, B, C, convolution_dist): def _gather_helper_fwd(self, tensor, B, C, convolution_dist):
"""
Gather the tensor along the horizontal and vertical dimensions.
Parameters
----------
tensor : torch.Tensor
The tensor to gather
B : int
Batch size
C : int
Number of channels
convolution_dist : thd.DistributedDiscreteContinuousConvTransposeS2 or thd.DistributedDiscreteContinuousConvS2
The distributed convolution object
"""
# we need the shapes # we need the shapes
lat_shapes = convolution_dist.lat_out_shapes lat_shapes = convolution_dist.lat_out_shapes
...@@ -232,20 +201,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -232,20 +201,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
return tensor_gather return tensor_gather
def _gather_helper_bwd(self, tensor, B, C, convolution_dist): def _gather_helper_bwd(self, tensor, B, C, convolution_dist):
"""
Gather the tensor along the horizontal and vertical dimensions.
Parameters
----------
tensor : torch.Tensor
The tensor to gather
B : int
Batch size
C : int
Number of channels
convolution_dist : thd.DistributedDiscreteContinuousConvTransposeS2 or thd.DistributedDiscreteContinuousConvS2
The distributed convolution object
"""
# we need the shapes # we need the shapes
lat_shapes = convolution_dist.lat_in_shapes lat_shapes = convolution_dist.lat_in_shapes
lon_shapes = convolution_dist.lon_in_shapes lon_shapes = convolution_dist.lon_in_shapes
......
...@@ -146,19 +146,7 @@ class TestDistributedResampling(unittest.TestCase): ...@@ -146,19 +146,7 @@ class TestDistributedResampling(unittest.TestCase):
dist.destroy_process_group(None) dist.destroy_process_group(None)
def _split_helper(self, tensor): def _split_helper(self, tensor):
"""
Split the tensor along the last dimension into chunks along the W dimension, and then along the H dimension.
Parameters
----------
tensor : torch.Tensor
The tensor to split
Returns
-------
torch.Tensor
The split tensor
"""
with torch.no_grad(): with torch.no_grad():
# split in W # split in W
tensor_list_local = thd.split_tensor_along_dim(tensor, dim=-1, num_chunks=self.grid_size_w) tensor_list_local = thd.split_tensor_along_dim(tensor, dim=-1, num_chunks=self.grid_size_w)
...@@ -171,25 +159,7 @@ class TestDistributedResampling(unittest.TestCase): ...@@ -171,25 +159,7 @@ class TestDistributedResampling(unittest.TestCase):
return tensor_local return tensor_local
def _gather_helper_fwd(self, tensor, B, C, convolution_dist): def _gather_helper_fwd(self, tensor, B, C, convolution_dist):
"""
Gather the tensor along the W and H dimensions.
Parameters
----------
tensor : torch.Tensor
The tensor to gather
B : int
Batch size
C : int
Number of channels
convolution_dist : thd.DistributedResampleS2
The distributed resampling object
Returns
-------
torch.Tensor
The gathered tensor
"""
# we need the shapes # we need the shapes
lat_shapes = convolution_dist.lat_out_shapes lat_shapes = convolution_dist.lat_out_shapes
lon_shapes = convolution_dist.lon_out_shapes lon_shapes = convolution_dist.lon_out_shapes
...@@ -217,25 +187,6 @@ class TestDistributedResampling(unittest.TestCase): ...@@ -217,25 +187,6 @@ class TestDistributedResampling(unittest.TestCase):
return tensor_gather return tensor_gather
def _gather_helper_bwd(self, tensor, B, C, resampling_dist): def _gather_helper_bwd(self, tensor, B, C, resampling_dist):
"""
Gather the tensor along the W and H dimensions.
Parameters
----------
tensor : torch.Tensor
The tensor to gather
B : int
Batch size
C : int
Number of channels
resampling_dist : thd.DistributedResampleS2
The distributed resampling object
Returns
-------
torch.Tensor
The gathered tensor
"""
# we need the shapes # we need the shapes
lat_shapes = resampling_dist.lat_in_shapes lat_shapes = resampling_dist.lat_in_shapes
......
...@@ -139,19 +139,6 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase): ...@@ -139,19 +139,6 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
dist.destroy_process_group(None) dist.destroy_process_group(None)
def _split_helper(self, tensor): def _split_helper(self, tensor):
"""
Split the tensor along the W and H dimensions.
Parameters
----------
tensor : torch.Tensor
The tensor to split
Returns
-------
torch.Tensor
The split tensor
"""
with torch.no_grad(): with torch.no_grad():
# split in W # split in W
tensor_list_local = thd.split_tensor_along_dim(tensor, dim=-1, num_chunks=self.grid_size_w) tensor_list_local = thd.split_tensor_along_dim(tensor, dim=-1, num_chunks=self.grid_size_w)
...@@ -164,27 +151,6 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase): ...@@ -164,27 +151,6 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
return tensor_local return tensor_local
def _gather_helper_fwd(self, tensor, B, C, transform_dist, vector): def _gather_helper_fwd(self, tensor, B, C, transform_dist, vector):
"""
Gather the tensor along the W and H dimensions.
Parameters
----------
tensor : torch.Tensor
The tensor to gather
B : int
Batch size
C : int
Number of channels
transform_dist : thd.DistributedRealSHT or thd.DistributedRealVectorSHT
The distributed transform
vector : bool
Whether to use vector spherical harmonic transform
Returns
-------
torch.Tensor
The gathered tensor
"""
# we need the shapes # we need the shapes
l_shapes = transform_dist.l_shapes l_shapes = transform_dist.l_shapes
m_shapes = transform_dist.m_shapes m_shapes = transform_dist.m_shapes
...@@ -216,27 +182,6 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase): ...@@ -216,27 +182,6 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
return tensor_gather return tensor_gather
def _gather_helper_bwd(self, tensor, B, C, transform_dist, vector): def _gather_helper_bwd(self, tensor, B, C, transform_dist, vector):
"""
Gather the tensor along the W and H dimensions.
Parameters
----------
tensor : torch.Tensor
The tensor to gather
B : int
Batch size
C : int
Number of channels
transform_dist : thd.DistributedRealSHT or thd.DistributedRealVectorSHT
The distributed transform
vector : bool
Whether to use vector spherical harmonic transform
Returns
-------
torch.Tensor
The gathered tensor
"""
# we need the shapes # we need the shapes
lat_shapes = transform_dist.lat_shapes lat_shapes = transform_dist.lat_shapes
......
...@@ -42,7 +42,35 @@ except ImportError as err: ...@@ -42,7 +42,35 @@ except ImportError as err:
# some helper functions # some helper functions
def _get_psi(kernel_size: int, psi_idx: torch.Tensor, psi_vals: torch.Tensor, nlat_in: int, nlon_in: int, nlat_out: int, nlon_out: int, nlat_in_local: Optional[int] = None, nlat_out_local: Optional[int] = None, semi_transposed: Optional[bool] = False): def _get_psi(kernel_size: int, psi_idx: torch.Tensor, psi_vals: torch.Tensor, nlat_in: int, nlon_in: int, nlat_out: int, nlon_out: int, nlat_in_local: Optional[int] = None, nlat_out_local: Optional[int] = None, semi_transposed: Optional[bool] = False):
"""Creates a sparse tensor for spherical harmonic convolution operations.
This function constructs a sparse COO tensor from indices and values, with optional
semi-transposition for computational efficiency in spherical harmonic convolutions.
Args:
kernel_size: Number of kernel elements.
psi_idx: Tensor of shape (3, n_nonzero) containing the indices for the sparse tensor.
The three dimensions represent [kernel_idx, lat_idx, combined_lat_lon_idx].
psi_vals: Tensor of shape (n_nonzero,) containing the values for the sparse tensor.
nlat_in: Number of input latitude points.
nlon_in: Number of input longitude points.
nlat_out: Number of output latitude points.
nlon_out: Number of output longitude points.
nlat_in_local: Local number of input latitude points. If None, defaults to nlat_in.
nlat_out_local: Local number of output latitude points. If None, defaults to nlat_out.
semi_transposed: If True, performs a semi-transposition to facilitate computation
by flipping the longitude axis and reorganizing indices.
Returns:
torch.Tensor: A sparse COO tensor of shape (kernel_size, nlat_out_local, nlat_in_local * nlon)
where nlon is either nlon_in or nlon_out depending on semi_transposed flag.
The tensor is coalesced to remove duplicate indices.
Note:
When semi_transposed=True, the function performs a partial transpose operation
that flips the longitude axis and reorganizes the indices to facilitate
efficient spherical harmonic convolution computations.
"""
nlat_in_local = nlat_in_local if nlat_in_local is not None else nlat_in nlat_in_local = nlat_in_local if nlat_in_local is not None else nlat_in
nlat_out_local = nlat_out_local if nlat_out_local is not None else nlat_out nlat_out_local = nlat_out_local if nlat_out_local is not None else nlat_out
...@@ -141,25 +169,7 @@ def _disco_s2_transpose_contraction_cuda(x: torch.Tensor, roff_idx: torch.Tensor ...@@ -141,25 +169,7 @@ def _disco_s2_transpose_contraction_cuda(x: torch.Tensor, roff_idx: torch.Tensor
def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int): def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
"""
Reference implementation of the custom contraction as described in [1]. This requires repeated
shifting of the input tensor, which can potentially be costly. For an efficient implementation
on GPU, make sure to use the custom kernel written in CUDA.
Parameters
-----------
x: torch.Tensor
Input tensor
psi: torch.Tensor
Kernel tensor
nlon_out: int
Number of output longitude points
Returns
--------
y: torch.Tensor
Output tensor
"""
assert len(psi.shape) == 3 assert len(psi.shape) == 3
assert len(x.shape) == 4 assert len(x.shape) == 4
psi = psi.to(x.device) psi = psi.to(x.device)
...@@ -191,25 +201,6 @@ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: in ...@@ -191,25 +201,6 @@ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: in
def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int): def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
"""
Reference implementation of the custom contraction as described in [1]. This requires repeated
shifting of the input tensor, which can potentially be costly. For an efficient implementation
on GPU, make sure to use the custom kernel written in CUDA.
Parameters
-----------
x: torch.Tensor
Input tensor
psi: torch.Tensor
Kernel tensor
nlon_out: int
Number of output longitude points
Returns
--------
y: torch.Tensor
Output tensor
"""
assert len(psi.shape) == 3 assert len(psi.shape) == 3
assert len(x.shape) == 5 assert len(x.shape) == 5
psi = psi.to(x.device) psi = psi.to(x.device)
......
...@@ -50,41 +50,6 @@ except ImportError as err: ...@@ -50,41 +50,6 @@ except ImportError as err:
def _neighborhood_attention_s2_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, def _neighborhood_attention_s2_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor,
quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
"""
Forward pass implementation of neighborhood attention on the sphere (S2).
This function computes the neighborhood attention operation using sparse tensor
operations. It implements the attention mechanism with softmax normalization
and quadrature weights for spherical integration.
Parameters
-----------
kx : torch.Tensor
Key tensor with shape (B, C, Hi, Wi) where B is batch size, C is channels,
Hi is input height (latitude), Wi is input width (longitude)
vx : torch.Tensor
Value tensor with shape (B, C, Hi, Wi)
qy : torch.Tensor
Query tensor with shape (B, C, Ho, Wo) where Ho is output height, Wo is output width
quad_weights : torch.Tensor
Quadrature weights for spherical integration with shape (Hi,)
col_idx : torch.Tensor
Column indices for sparse computation
row_off : torch.Tensor
Row offsets for sparse computation
nlon_in : int
Number of input longitude points
nlat_out : int
Number of output latitude points
nlon_out : int
Number of output longitude points
Returns
-------
torch.Tensor
Output tensor with shape (B, C, Ho, Wo) after neighborhood attention computation
"""
# prepare result tensor # prepare result tensor
y = torch.zeros_like(qy) y = torch.zeros_like(qy)
...@@ -135,41 +100,6 @@ def _neighborhood_attention_s2_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: ...@@ -135,41 +100,6 @@ def _neighborhood_attention_s2_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy:
def _neighborhood_attention_s2_bwd_dv_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor, def _neighborhood_attention_s2_bwd_dv_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
nlon_in: int, nlat_out: int, nlon_out: int): nlon_in: int, nlat_out: int, nlon_out: int):
"""
Backward pass implementation for value gradients in neighborhood attention on S2.
This function computes the gradient of the output with respect to the value tensor (vx).
It implements the backward pass for the neighborhood attention operation using
sparse tensor operations and quadrature weights for spherical integration.
Parameters
-----------
kx : torch.Tensor
Key tensor with shape (B, C, Hi, Wi)
vx : torch.Tensor
Value tensor with shape (B, C, Hi, Wi)
qy : torch.Tensor
Query tensor with shape (B, C, Ho, Wo)
dy : torch.Tensor
Gradient of the output with shape (B, C, Ho, Wo)
quad_weights : torch.Tensor
Quadrature weights for spherical integration with shape (Hi,)
col_idx : torch.Tensor
Column indices for sparse computation
row_off : torch.Tensor
Row offsets for sparse computation
nlon_in : int
Number of input longitude points
nlat_out : int
Number of output latitude points
nlon_out : int
Number of output longitude points
Returns
-------
torch.Tensor
Gradient of the value tensor with shape (B, C, Hi, Wi)
"""
# shapes: # shapes:
# input # input
...@@ -238,42 +168,6 @@ def _neighborhood_attention_s2_bwd_dv_torch(kx: torch.Tensor, vx: torch.Tensor, ...@@ -238,42 +168,6 @@ def _neighborhood_attention_s2_bwd_dv_torch(kx: torch.Tensor, vx: torch.Tensor,
def _neighborhood_attention_s2_bwd_dk_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor, def _neighborhood_attention_s2_bwd_dk_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
nlon_in: int, nlat_out: int, nlon_out: int): nlon_in: int, nlat_out: int, nlon_out: int):
"""
Backward pass implementation for key gradients in neighborhood attention on S2.
This function computes the gradient of the output with respect to the key tensor (kx).
It implements the backward pass for the neighborhood attention operation using
sparse tensor operations and quadrature weights for spherical integration.
Parameters
-----------
kx : torch.Tensor
Key tensor with shape (B, C, Hi, Wi)
vx : torch.Tensor
Value tensor with shape (B, C, Hi, Wi)
qy : torch.Tensor
Query tensor with shape (B, C, Ho, Wo)
dy : torch.Tensor
Gradient of the output with shape (B, C, Ho, Wo)
quad_weights : torch.Tensor
Quadrature weights for spherical integration with shape (Hi,)
col_idx : torch.Tensor
Column indices for sparse computation
row_off : torch.Tensor
Row offsets for sparse computation
nlon_in : int
Number of input longitude points
nlat_out : int
Number of output latitude points
nlon_out : int
Number of output longitude points
Returns
-------
torch.Tensor
Gradient of the key tensor with shape (B, C, Hi, Wi)
"""
# shapes: # shapes:
# input # input
# kx: B, C, Hi, Wi # kx: B, C, Hi, Wi
...@@ -354,41 +248,7 @@ def _neighborhood_attention_s2_bwd_dk_torch(kx: torch.Tensor, vx: torch.Tensor, ...@@ -354,41 +248,7 @@ def _neighborhood_attention_s2_bwd_dk_torch(kx: torch.Tensor, vx: torch.Tensor,
def _neighborhood_attention_s2_bwd_dq_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor, def _neighborhood_attention_s2_bwd_dq_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
nlon_in: int, nlat_out: int, nlon_out: int): nlon_in: int, nlat_out: int, nlon_out: int):
"""
Backward pass implementation for query gradients in neighborhood attention on S2.
This function computes the gradient of the output with respect to the query tensor (qy).
It implements the backward pass for the neighborhood attention operation using
sparse tensor operations and quadrature weights for spherical integration.
Parameters
-----------
kx : torch.Tensor
Key tensor with shape (B, C, Hi, Wi)
vx : torch.Tensor
Value tensor with shape (B, C, Hi, Wi)
qy : torch.Tensor
Query tensor with shape (B, C, Ho, Wo)
dy : torch.Tensor
Gradient of the output with shape (B, C, Ho, Wo)
quad_weights : torch.Tensor
Quadrature weights for spherical integration with shape (Hi,)
col_idx : torch.Tensor
Column indices for sparse computation
row_off : torch.Tensor
Row offsets for sparse computation
nlon_in : int
Number of input longitude points
nlat_out : int
Number of output latitude points
nlon_out : int
Number of output longitude points
Returns
-------
torch.Tensor
Gradient of the query tensor with shape (B, C, Ho, Wo)
"""
# shapes: # shapes:
# input # input
...@@ -581,52 +441,7 @@ def _neighborhood_attention_s2_torch(k: torch.Tensor, v: torch.Tensor, q: torch. ...@@ -581,52 +441,7 @@ def _neighborhood_attention_s2_torch(k: torch.Tensor, v: torch.Tensor, q: torch.
bq: Union[torch.Tensor, None], quad_weights: torch.Tensor, bq: Union[torch.Tensor, None], quad_weights: torch.Tensor,
col_idx: torch.Tensor, row_off: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
"""
Torch implementation of neighborhood attention on the sphere (S2).
This function provides a wrapper around the CPU autograd function for
neighborhood attention operations using sparse tensor computations.
Parameters
-----------
k : torch.Tensor
Key tensor
v : torch.Tensor
Value tensor
q : torch.Tensor
Query tensor
wk : torch.Tensor
Key weight tensor
wv : torch.Tensor
Value weight tensor
wq : torch.Tensor
Query weight tensor
bk : torch.Tensor or None
Key bias tensor (optional)
bv : torch.Tensor or None
Value bias tensor (optional)
bq : torch.Tensor or None
Query bias tensor (optional)
quad_weights : torch.Tensor
Quadrature weights for spherical integration
col_idx : torch.Tensor
Column indices for sparse computation
row_off : torch.Tensor
Row offsets for sparse computation
nh : int
Number of attention heads
nlon_in : int
Number of input longitude points
nlat_out : int
Number of output latitude points
nlon_out : int
Number of output longitude points
Returns
-------
torch.Tensor
Output tensor after neighborhood attention computation
"""
return _NeighborhoodAttentionS2.apply(k, v, q, wk, wv, wq, bk, bv, bq, return _NeighborhoodAttentionS2.apply(k, v, q, wk, wv, wq, bk, bv, bq,
quad_weights, col_idx, row_off, quad_weights, col_idx, row_off,
nh, nlon_in, nlat_out, nlon_out) nh, nlon_in, nlat_out, nlon_out)
...@@ -768,54 +583,7 @@ def _neighborhood_attention_s2_cuda(k: torch.Tensor, v: torch.Tensor, q: torch.T ...@@ -768,54 +583,7 @@ def _neighborhood_attention_s2_cuda(k: torch.Tensor, v: torch.Tensor, q: torch.T
bq: Union[torch.Tensor, None], quad_weights: torch.Tensor, bq: Union[torch.Tensor, None], quad_weights: torch.Tensor,
col_idx: torch.Tensor, row_off: torch.Tensor, max_psi_nnz: int, col_idx: torch.Tensor, row_off: torch.Tensor, max_psi_nnz: int,
nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
"""
CUDA implementation of neighborhood attention on the sphere (S2).
This function provides a wrapper around the CUDA autograd function for
neighborhood attention operations using custom CUDA kernels for efficient GPU computation.
Parameters
-----------
k : torch.Tensor
Key tensor
v : torch.Tensor
Value tensor
q : torch.Tensor
Query tensor
wk : torch.Tensor
Key weight tensor
wv : torch.Tensor
Value weight tensor
wq : torch.Tensor
Query weight tensor
bk : torch.Tensor or None
Key bias tensor (optional)
bv : torch.Tensor or None
Value bias tensor (optional)
bq : torch.Tensor or None
Query bias tensor (optional)
quad_weights : torch.Tensor
Quadrature weights for spherical integration
col_idx : torch.Tensor
Column indices for sparse computation
row_off : torch.Tensor
Row offsets for sparse computation
max_psi_nnz : int
Maximum number of non-zero elements in sparse tensor
nh : int
Number of attention heads
nlon_in : int
Number of input longitude points
nlat_out : int
Number of output latitude points
nlon_out : int
Number of output longitude points
Returns
-------
torch.Tensor
Output tensor after neighborhood attention computation
"""
return _NeighborhoodAttentionS2Cuda.apply(k, v, q, wk, wv, wq, bk, bv, bq, return _NeighborhoodAttentionS2Cuda.apply(k, v, q, wk, wv, wq, bk, bv, bq,
quad_weights, col_idx, row_off, max_psi_nnz, quad_weights, col_idx, row_off, max_psi_nnz,
nh, nlon_in, nlat_out, nlon_out) nh, nlon_in, nlat_out, nlon_out)
...@@ -60,39 +60,30 @@ except ImportError as err: ...@@ -60,39 +60,30 @@ except ImportError as err:
def _normalize_convolution_tensor_s2( def _normalize_convolution_tensor_s2(
psi_idx, psi_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=False, basis_norm_mode="mean", merge_quadrature=False, eps=1e-9 psi_idx, psi_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=False, basis_norm_mode="mean", merge_quadrature=False, eps=1e-9
): ):
""" """Normalizes convolution tensor values based on specified normalization mode.
Discretely normalizes the convolution tensor and pre-applies quadrature weights. Supports the following three normalization modes:
- "none": No normalization is applied. This function applies different normalization strategies to the convolution tensor
- "individual": for each output latitude and filter basis function the filter is numerically integrated over the sphere and normalized so that it yields 1. values based on the basis_norm_mode parameter. It can normalize individual basis
- "mean": the norm is computed for each output latitude and then averaged over the output latitudes. Each basis function is then normalized by this mean. functions, compute mean normalization across all basis functions, or use support
weights. The function also optionally merges quadrature weights into the tensor.
Parameters
----------- Args:
psi_idx: torch.Tensor psi_idx: Index tensor for the sparse convolution tensor.
Index tensor of the convolution tensor psi_vals: Value tensor for the sparse convolution tensor.
psi_vals: torch.Tensor in_shape: Tuple of (nlat_in, nlon_in) representing input grid dimensions.
Values tensor of the convolution tensor out_shape: Tuple of (nlat_out, nlon_out) representing output grid dimensions.
in_shape: Tuple[int] kernel_size: Number of kernel basis functions.
Input shape of the convolution tensor quad_weights: Quadrature weights for numerical integration.
out_shape: Tuple[int] transpose_normalization: If True, applies normalization in transpose direction.
Output shape of the convolution tensor basis_norm_mode: Normalization mode, one of ["none", "individual", "mean", "support"].
kernel_size: int merge_quadrature: If True, multiplies values by quadrature weights.
Size of the kernel eps: Small epsilon value to prevent division by zero.
quad_weights: torch.Tensor
Quadrature weights Returns:
transpose_normalization: bool torch.Tensor: Normalized convolution tensor values.
Whether to normalize the convolution tensor in the transpose direction
basis_norm_mode: str Raises:
Mode for basis normalization ValueError: If basis_norm_mode is not one of the supported modes.
merge_quadrature: bool
Whether to merge the quadrature weights into the convolution tensor
eps: float
Small epsilon to avoid division by zero
Returns
-------
psi_vals: torch.Tensor
Normalized convolution tensor
""" """
# exit here if no normalization is needed # exit here if no normalization is needed
......
...@@ -39,17 +39,14 @@ from torch_harmonics.cache import lru_cache ...@@ -39,17 +39,14 @@ from torch_harmonics.cache import lru_cache
def _circle_dist(x1: torch.Tensor, x2: torch.Tensor): def _circle_dist(x1: torch.Tensor, x2: torch.Tensor):
"""Helper function to compute the distance on a circle"""
return torch.minimum(torch.abs(x1 - x2), torch.abs(2 * math.pi - torch.abs(x1 - x2))) return torch.minimum(torch.abs(x1 - x2), torch.abs(2 * math.pi - torch.abs(x1 - x2)))
def _log_factorial(x: torch.Tensor): def _log_factorial(x: torch.Tensor):
"""Helper function to compute the log factorial on a torch tensor"""
return torch.lgamma(x + 1) return torch.lgamma(x + 1)
def _factorial(x: torch.Tensor): def _factorial(x: torch.Tensor):
"""Helper function to compute the factorial on a torch tensor"""
return torch.exp(_log_factorial(x)) return torch.exp(_log_factorial(x))
...@@ -62,27 +59,13 @@ class FilterBasis(metaclass=abc.ABCMeta): ...@@ -62,27 +59,13 @@ class FilterBasis(metaclass=abc.ABCMeta):
self, self,
kernel_shape: Union[int, Tuple[int], Tuple[int, int]], kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
): ):
"""
Initialize the filter basis.
Parameters
-----------
kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
Shape of the kernel, can be an integer or tuple of integers
"""
self.kernel_shape = kernel_shape self.kernel_shape = kernel_shape
@property @property
@abc.abstractmethod @abc.abstractmethod
def kernel_size(self): def kernel_size(self):
"""
Abstract property that should return the size of the kernel.
Returns
-------
kernel_size: int
The size of the kernel
"""
raise NotImplementedError raise NotImplementedError
# @abc.abstractmethod # @abc.abstractmethod
...@@ -94,10 +77,7 @@ class FilterBasis(metaclass=abc.ABCMeta): ...@@ -94,10 +77,7 @@ class FilterBasis(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float): def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float):
"""
Computes the index set that falls into the kernel's support and returns both indices and values.
This routine is designed for sparse evaluations of the filter basis.
"""
raise NotImplementedError raise NotImplementedError
...@@ -124,12 +104,7 @@ class PiecewiseLinearFilterBasis(FilterBasis): ...@@ -124,12 +104,7 @@ class PiecewiseLinearFilterBasis(FilterBasis):
self, self,
kernel_shape: Union[int, Tuple[int], Tuple[int, int]], kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
): ):
"""
Initialize the piecewise linear filter basis.
Parameters:
kernel_shape: shape of the kernel, can be an integer or tuple of integers
"""
if isinstance(kernel_shape, int): if isinstance(kernel_shape, int):
kernel_shape = [kernel_shape] kernel_shape = [kernel_shape]
if len(kernel_shape) == 1: if len(kernel_shape) == 1:
...@@ -152,9 +127,6 @@ class PiecewiseLinearFilterBasis(FilterBasis): ...@@ -152,9 +127,6 @@ class PiecewiseLinearFilterBasis(FilterBasis):
return (self.kernel_shape[0] // 2) * self.kernel_shape[1] + self.kernel_shape[0] % 2 return (self.kernel_shape[0] // 2) * self.kernel_shape[1] + self.kernel_shape[0] % 2
def _compute_support_vals_isotropic(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float): def _compute_support_vals_isotropic(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float):
"""
Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
"""
# enumerator for basis function # enumerator for basis function
ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1) ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1)
...@@ -176,9 +148,6 @@ class PiecewiseLinearFilterBasis(FilterBasis): ...@@ -176,9 +148,6 @@ class PiecewiseLinearFilterBasis(FilterBasis):
return iidx, vals return iidx, vals
def _compute_support_vals_anisotropic(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float): def _compute_support_vals_anisotropic(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float):
"""
Computes the index set that falls into the anisotropic kernel's support and returns both indices and values.
"""
# enumerator for basis function # enumerator for basis function
ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1) ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1)
...@@ -253,14 +222,7 @@ class MorletFilterBasis(FilterBasis): ...@@ -253,14 +222,7 @@ class MorletFilterBasis(FilterBasis):
self, self,
kernel_shape: Union[int, Tuple[int], Tuple[int, int]], kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
): ):
"""
Initialize the Morlet filter basis.
Parameters
-----------
kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
Shape of the kernel, can be an integer or tuple of integers
"""
if isinstance(kernel_shape, int): if isinstance(kernel_shape, int):
kernel_shape = [kernel_shape, kernel_shape] kernel_shape = [kernel_shape, kernel_shape]
if len(kernel_shape) != 2: if len(kernel_shape) != 2:
...@@ -270,56 +232,18 @@ class MorletFilterBasis(FilterBasis): ...@@ -270,56 +232,18 @@ class MorletFilterBasis(FilterBasis):
@property
def kernel_size(self):
    """Total number of Morlet basis functions: product of the two kernel dimensions."""
    n_radial, n_angular = self.kernel_shape[0], self.kernel_shape[1]
    return n_radial * n_angular
def gaussian_window(self, r: torch.Tensor, width: float = 1.0):
    """Isotropic 2D Gaussian window of the given width, evaluated at radius r."""
    # 2D Gaussian normalization constant: 1 / (2 * pi * width^2)
    norm = 1 / (2 * math.pi * width**2)
    return norm * torch.exp(-0.5 * r**2 / (width**2))
def hann_window(self, r: torch.Tensor, width: float = 1.0):
    """Hann (raised-cosine) window of the given width, evaluated at radius r."""
    angle = 0.5 * torch.pi * r / width
    return torch.cos(angle) ** 2
def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float, width: float = 1.0): def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float, width: float = 1.0):
"""
Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
"""
# enumerator for basis function # enumerator for basis function
ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1) ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1)
...@@ -355,14 +279,7 @@ class ZernikeFilterBasis(FilterBasis): ...@@ -355,14 +279,7 @@ class ZernikeFilterBasis(FilterBasis):
self, self,
kernel_shape: Union[int, Tuple[int]], kernel_shape: Union[int, Tuple[int]],
): ):
"""
Initialize the Zernike filter basis.
Parameters
-----------
kernel_shape: Union[int, Tuple[int]]
Shape of the kernel, can be an integer or tuple of integers
"""
if isinstance(kernel_shape, tuple) or isinstance(kernel_shape, list): if isinstance(kernel_shape, tuple) or isinstance(kernel_shape, list):
kernel_shape = kernel_shape[0] kernel_shape = kernel_shape[0]
if not isinstance(kernel_shape, int): if not isinstance(kernel_shape, int):
...@@ -372,34 +289,11 @@ class ZernikeFilterBasis(FilterBasis): ...@@ -372,34 +289,11 @@ class ZernikeFilterBasis(FilterBasis):
@property
def kernel_size(self):
    """Number of Zernike basis functions: the triangular number n*(n+1)/2."""
    n = self.kernel_shape
    return n * (n + 1) // 2
def zernikeradial(self, r: torch.Tensor, n: torch.Tensor, m: torch.Tensor): def zernikeradial(self, r: torch.Tensor, n: torch.Tensor, m: torch.Tensor):
"""
Compute radial Zernike polynomials.
Parameters
-----------
r: torch.Tensor
Radial distance tensor
n: torch.Tensor
Principal quantum number
m: torch.Tensor
Azimuthal quantum number
Returns
-------
out: torch.Tensor
Radial Zernike polynomial values
"""
out = torch.zeros_like(r) out = torch.zeros_like(r)
bound = (n - m) // 2 + 1 bound = (n - m) // 2 + 1
max_bound = bound.max().item() max_bound = bound.max().item()
...@@ -412,32 +306,11 @@ class ZernikeFilterBasis(FilterBasis): ...@@ -412,32 +306,11 @@ class ZernikeFilterBasis(FilterBasis):
return out return out
def zernikepoly(self, r: torch.Tensor, phi: torch.Tensor, n: torch.Tensor, l: torch.Tensor):
    """Evaluate the full Zernike polynomial: radial part times sin/cos angular factor.

    Negative azimuthal order m selects the sine branch, non-negative the cosine branch.
    """
    m = 2 * l - n
    # torch.where evaluates both branches; the radial part is called with |m|
    sin_branch = self.zernikeradial(r, n, -m) * torch.sin(m * phi)
    cos_branch = self.zernikeradial(r, n, m) * torch.cos(m * phi)
    return torch.where(m < 0, sin_branch, cos_branch)
def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float, width: float = 0.25): def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float, width: float = 0.25):
"""
Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
"""
# enumerator for basis function # enumerator for basis function
ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1) ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1)
......
...@@ -83,19 +83,6 @@ def _precompute_grid(n: int, grid: Optional[str]="equidistant", a: Optional[floa ...@@ -83,19 +83,6 @@ def _precompute_grid(n: int, grid: Optional[str]="equidistant", a: Optional[floa
@lru_cache(typed=True, copy=True)
def _precompute_longitudes(nlon: int):
    """Return the nlon equispaced longitudes in [0, 2*pi) as a float64 tensor."""
    # sample nlon+1 points including the periodic endpoint, then drop the duplicate
    grid = torch.linspace(0, 2 * math.pi, nlon + 1, dtype=torch.float64, requires_grad=False)
    return grid[:-1]
...@@ -103,23 +90,6 @@ def _precompute_longitudes(nlon: int): ...@@ -103,23 +90,6 @@ def _precompute_longitudes(nlon: int):
@lru_cache(typed=True, copy=True) @lru_cache(typed=True, copy=True)
def _precompute_latitudes(nlat: int, grid: Optional[str]="equiangular") -> Tuple[torch.Tensor, torch.Tensor]: def _precompute_latitudes(nlat: int, grid: Optional[str]="equiangular") -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Convenience routine to precompute latitudes
Parameters
-----------
nlat: int
Number of latitude points
grid: Optional[str]
Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
Returns
-------
lats: torch.Tensor
Tensor of latitude points
wlg: torch.Tensor
Tensor of quadrature weights
"""
# compute coordinates in the cosine theta domain # compute coordinates in the cosine theta domain
xlg, wlg = _precompute_grid(nlat, grid=grid, a=-1.0, b=1.0, periodic=False) xlg, wlg = _precompute_grid(nlat, grid=grid, a=-1.0, b=1.0, periodic=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment