You need to sign in or sign up before continuing.
Commit 9c26a6d8 authored by Andrea Paris, committed by Boris Bonev
Browse files

removed docstrings from internal functions

parent 328200ab
......@@ -47,30 +47,7 @@ if torch.cuda.is_available():
def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalization=False, basis_norm_mode="none", merge_quadrature=False, eps=1e-9):
"""
Discretely normalizes the convolution tensor.
Parameters
----------
psi : torch.Tensor
Convolution tensor
quad_weights : torch.Tensor
Quadrature weights
transpose_normalization : bool, optional
Whether to transpose the normalization, by default False
basis_norm_mode : str, optional
Basis normalization mode, by default "none"
merge_quadrature : bool, optional
Whether to merge the quadrature, by default False
eps : float, optional
Epsilon for numerical stability, by default 1e-9
Returns
-------
torch.Tensor
Normalized convolution tensor
"""
kernel_size, nlat_out, nlon_out, nlat_in, nlon_in = psi.shape
correction_factor = nlon_out / nlon_in
......@@ -118,38 +95,6 @@ def _precompute_convolution_tensor_dense(
basis_norm_mode="none",
merge_quadrature=False,
):
"""
Helper routine to compute the convolution Tensor in a dense fashion
Parameters
----------
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
filter_basis : FilterBasis
Filter basis
grid_in : str
Grid type for input
grid_out : str
Grid type for output
theta_cutoff : float, optional
Theta cutoff
theta_eps : float, optional
Theta epsilon
transpose_normalization : bool, optional
Whether to transpose the normalization, by default False
basis_norm_mode : str, optional
Basis normalization mode, by default "none"
merge_quadrature : bool, optional
Whether to merge the quadrature, by default False
Returns
-------
torch.Tensor
Convolution tensor
"""
assert len(in_shape) == 2
assert len(out_shape) == 2
......
......@@ -156,27 +156,10 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
@classmethod
def tearDownClass(cls):
    """
    Tear down the distributed convolution test class.

    Finalizes the torch-harmonics distributed helpers (thd) first, then
    destroys the default torch.distributed process group. The order matters:
    thd may still use collectives during finalization.
    """
    thd.finalize()
    dist.destroy_process_group(None)
def _split_helper(self, tensor):
"""
Split the tensor along the horizontal and vertical dimensions.
Parameters
----------
tensor : torch.Tensor
The tensor to split
"""
with torch.no_grad():
# split in W
......@@ -190,20 +173,6 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
return tensor_local
def _gather_helper_fwd(self, tensor, B, C, convolution_dist):
"""
Gather the tensor along the horizontal and vertical dimensions.
Parameters
----------
tensor : torch.Tensor
The tensor to gather
B : int
Batch size
C : int
Number of channels
convolution_dist : thd.DistributedDiscreteContinuousConvTransposeS2 or thd.DistributedDiscreteContinuousConvS2
The distributed convolution object
"""
# we need the shapes
lat_shapes = convolution_dist.lat_out_shapes
......@@ -232,20 +201,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
return tensor_gather
def _gather_helper_bwd(self, tensor, B, C, convolution_dist):
"""
Gather the tensor along the horizontal and vertical dimensions.
Parameters
----------
tensor : torch.Tensor
The tensor to gather
B : int
Batch size
C : int
Number of channels
convolution_dist : thd.DistributedDiscreteContinuousConvTransposeS2 or thd.DistributedDiscreteContinuousConvS2
The distributed convolution object
"""
# we need the shapes
lat_shapes = convolution_dist.lat_in_shapes
lon_shapes = convolution_dist.lon_in_shapes
......
......@@ -146,19 +146,7 @@ class TestDistributedResampling(unittest.TestCase):
dist.destroy_process_group(None)
def _split_helper(self, tensor):
"""
Split the tensor along the last dimension into chunks along the W dimension, and then along the H dimension.
Parameters
----------
tensor : torch.Tensor
The tensor to split
Returns
-------
torch.Tensor
The split tensor
"""
with torch.no_grad():
# split in W
tensor_list_local = thd.split_tensor_along_dim(tensor, dim=-1, num_chunks=self.grid_size_w)
......@@ -171,25 +159,7 @@ class TestDistributedResampling(unittest.TestCase):
return tensor_local
def _gather_helper_fwd(self, tensor, B, C, convolution_dist):
"""
Gather the tensor along the W and H dimensions.
Parameters
----------
tensor : torch.Tensor
The tensor to gather
B : int
Batch size
C : int
Number of channels
convolution_dist : thd.DistributedResampleS2
The distributed resampling object
Returns
-------
torch.Tensor
The gathered tensor
"""
# we need the shapes
lat_shapes = convolution_dist.lat_out_shapes
lon_shapes = convolution_dist.lon_out_shapes
......@@ -217,25 +187,6 @@ class TestDistributedResampling(unittest.TestCase):
return tensor_gather
def _gather_helper_bwd(self, tensor, B, C, resampling_dist):
"""
Gather the tensor along the W and H dimensions.
Parameters
----------
tensor : torch.Tensor
The tensor to gather
B : int
Batch size
C : int
Number of channels
resampling_dist : thd.DistributedResampleS2
The distributed resampling object
Returns
-------
torch.Tensor
The gathered tensor
"""
# we need the shapes
lat_shapes = resampling_dist.lat_in_shapes
......
......@@ -139,19 +139,6 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
dist.destroy_process_group(None)
def _split_helper(self, tensor):
"""
Split the tensor along the W and H dimensions.
Parameters
----------
tensor : torch.Tensor
The tensor to split
Returns
-------
torch.Tensor
The split tensor
"""
with torch.no_grad():
# split in W
tensor_list_local = thd.split_tensor_along_dim(tensor, dim=-1, num_chunks=self.grid_size_w)
......@@ -164,27 +151,6 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
return tensor_local
def _gather_helper_fwd(self, tensor, B, C, transform_dist, vector):
"""
Gather the tensor along the W and H dimensions.
Parameters
----------
tensor : torch.Tensor
The tensor to gather
B : int
Batch size
C : int
Number of channels
transform_dist : thd.DistributedRealSHT or thd.DistributedRealVectorSHT
The distributed transform
vector : bool
Whether to use vector spherical harmonic transform
Returns
-------
torch.Tensor
The gathered tensor
"""
# we need the shapes
l_shapes = transform_dist.l_shapes
m_shapes = transform_dist.m_shapes
......@@ -216,27 +182,6 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
return tensor_gather
def _gather_helper_bwd(self, tensor, B, C, transform_dist, vector):
"""
Gather the tensor along the W and H dimensions.
Parameters
----------
tensor : torch.Tensor
The tensor to gather
B : int
Batch size
C : int
Number of channels
transform_dist : thd.DistributedRealSHT or thd.DistributedRealVectorSHT
The distributed transform
vector : bool
Whether to use vector spherical harmonic transform
Returns
-------
torch.Tensor
The gathered tensor
"""
# we need the shapes
lat_shapes = transform_dist.lat_shapes
......
......@@ -42,7 +42,35 @@ except ImportError as err:
# some helper functions
def _get_psi(kernel_size: int, psi_idx: torch.Tensor, psi_vals: torch.Tensor, nlat_in: int, nlon_in: int, nlat_out: int, nlon_out: int, nlat_in_local: Optional[int] = None, nlat_out_local: Optional[int] = None, semi_transposed: Optional[bool] = False):
"""Creates a sparse tensor for spherical harmonic convolution operations.
This function constructs a sparse COO tensor from indices and values, with optional
semi-transposition for computational efficiency in spherical harmonic convolutions.
Args:
kernel_size: Number of kernel elements.
psi_idx: Tensor of shape (3, n_nonzero) containing the indices for the sparse tensor.
The three dimensions represent [kernel_idx, lat_idx, combined_lat_lon_idx].
psi_vals: Tensor of shape (n_nonzero,) containing the values for the sparse tensor.
nlat_in: Number of input latitude points.
nlon_in: Number of input longitude points.
nlat_out: Number of output latitude points.
nlon_out: Number of output longitude points.
nlat_in_local: Local number of input latitude points. If None, defaults to nlat_in.
nlat_out_local: Local number of output latitude points. If None, defaults to nlat_out.
semi_transposed: If True, performs a semi-transposition to facilitate computation
by flipping the longitude axis and reorganizing indices.
Returns:
torch.Tensor: A sparse COO tensor of shape (kernel_size, nlat_out_local, nlat_in_local * nlon)
where nlon is either nlon_in or nlon_out depending on semi_transposed flag.
The tensor is coalesced to remove duplicate indices.
Note:
When semi_transposed=True, the function performs a partial transpose operation
that flips the longitude axis and reorganizes the indices to facilitate
efficient spherical harmonic convolution computations.
"""
nlat_in_local = nlat_in_local if nlat_in_local is not None else nlat_in
nlat_out_local = nlat_out_local if nlat_out_local is not None else nlat_out
......@@ -141,25 +169,7 @@ def _disco_s2_transpose_contraction_cuda(x: torch.Tensor, roff_idx: torch.Tensor
def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
"""
Reference implementation of the custom contraction as described in [1]. This requires repeated
shifting of the input tensor, which can potentially be costly. For an efficient implementation
on GPU, make sure to use the custom kernel written in CUDA.
Parameters
-----------
x: torch.Tensor
Input tensor
psi: torch.Tensor
Kernel tensor
nlon_out: int
Number of output longitude points
Returns
--------
y: torch.Tensor
Output tensor
"""
assert len(psi.shape) == 3
assert len(x.shape) == 4
psi = psi.to(x.device)
......@@ -191,25 +201,6 @@ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: in
def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
"""
Reference implementation of the custom contraction as described in [1]. This requires repeated
shifting of the input tensor, which can potentially be costly. For an efficient implementation
on GPU, make sure to use the custom kernel written in CUDA.
Parameters
-----------
x: torch.Tensor
Input tensor
psi: torch.Tensor
Kernel tensor
nlon_out: int
Number of output longitude points
Returns
--------
y: torch.Tensor
Output tensor
"""
assert len(psi.shape) == 3
assert len(x.shape) == 5
psi = psi.to(x.device)
......
......@@ -50,41 +50,6 @@ except ImportError as err:
def _neighborhood_attention_s2_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor,
quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
"""
Forward pass implementation of neighborhood attention on the sphere (S2).
This function computes the neighborhood attention operation using sparse tensor
operations. It implements the attention mechanism with softmax normalization
and quadrature weights for spherical integration.
Parameters
-----------
kx : torch.Tensor
Key tensor with shape (B, C, Hi, Wi) where B is batch size, C is channels,
Hi is input height (latitude), Wi is input width (longitude)
vx : torch.Tensor
Value tensor with shape (B, C, Hi, Wi)
qy : torch.Tensor
Query tensor with shape (B, C, Ho, Wo) where Ho is output height, Wo is output width
quad_weights : torch.Tensor
Quadrature weights for spherical integration with shape (Hi,)
col_idx : torch.Tensor
Column indices for sparse computation
row_off : torch.Tensor
Row offsets for sparse computation
nlon_in : int
Number of input longitude points
nlat_out : int
Number of output latitude points
nlon_out : int
Number of output longitude points
Returns
-------
torch.Tensor
Output tensor with shape (B, C, Ho, Wo) after neighborhood attention computation
"""
# prepare result tensor
y = torch.zeros_like(qy)
......@@ -135,41 +100,6 @@ def _neighborhood_attention_s2_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy:
def _neighborhood_attention_s2_bwd_dv_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
nlon_in: int, nlat_out: int, nlon_out: int):
"""
Backward pass implementation for value gradients in neighborhood attention on S2.
This function computes the gradient of the output with respect to the value tensor (vx).
It implements the backward pass for the neighborhood attention operation using
sparse tensor operations and quadrature weights for spherical integration.
Parameters
-----------
kx : torch.Tensor
Key tensor with shape (B, C, Hi, Wi)
vx : torch.Tensor
Value tensor with shape (B, C, Hi, Wi)
qy : torch.Tensor
Query tensor with shape (B, C, Ho, Wo)
dy : torch.Tensor
Gradient of the output with shape (B, C, Ho, Wo)
quad_weights : torch.Tensor
Quadrature weights for spherical integration with shape (Hi,)
col_idx : torch.Tensor
Column indices for sparse computation
row_off : torch.Tensor
Row offsets for sparse computation
nlon_in : int
Number of input longitude points
nlat_out : int
Number of output latitude points
nlon_out : int
Number of output longitude points
Returns
-------
torch.Tensor
Gradient of the value tensor with shape (B, C, Hi, Wi)
"""
# shapes:
# input
......@@ -238,42 +168,6 @@ def _neighborhood_attention_s2_bwd_dv_torch(kx: torch.Tensor, vx: torch.Tensor,
def _neighborhood_attention_s2_bwd_dk_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
nlon_in: int, nlat_out: int, nlon_out: int):
"""
Backward pass implementation for key gradients in neighborhood attention on S2.
This function computes the gradient of the output with respect to the key tensor (kx).
It implements the backward pass for the neighborhood attention operation using
sparse tensor operations and quadrature weights for spherical integration.
Parameters
-----------
kx : torch.Tensor
Key tensor with shape (B, C, Hi, Wi)
vx : torch.Tensor
Value tensor with shape (B, C, Hi, Wi)
qy : torch.Tensor
Query tensor with shape (B, C, Ho, Wo)
dy : torch.Tensor
Gradient of the output with shape (B, C, Ho, Wo)
quad_weights : torch.Tensor
Quadrature weights for spherical integration with shape (Hi,)
col_idx : torch.Tensor
Column indices for sparse computation
row_off : torch.Tensor
Row offsets for sparse computation
nlon_in : int
Number of input longitude points
nlat_out : int
Number of output latitude points
nlon_out : int
Number of output longitude points
Returns
-------
torch.Tensor
Gradient of the key tensor with shape (B, C, Hi, Wi)
"""
# shapes:
# input
# kx: B, C, Hi, Wi
......@@ -354,41 +248,7 @@ def _neighborhood_attention_s2_bwd_dk_torch(kx: torch.Tensor, vx: torch.Tensor,
def _neighborhood_attention_s2_bwd_dq_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
nlon_in: int, nlat_out: int, nlon_out: int):
"""
Backward pass implementation for query gradients in neighborhood attention on S2.
This function computes the gradient of the output with respect to the query tensor (qy).
It implements the backward pass for the neighborhood attention operation using
sparse tensor operations and quadrature weights for spherical integration.
Parameters
-----------
kx : torch.Tensor
Key tensor with shape (B, C, Hi, Wi)
vx : torch.Tensor
Value tensor with shape (B, C, Hi, Wi)
qy : torch.Tensor
Query tensor with shape (B, C, Ho, Wo)
dy : torch.Tensor
Gradient of the output with shape (B, C, Ho, Wo)
quad_weights : torch.Tensor
Quadrature weights for spherical integration with shape (Hi,)
col_idx : torch.Tensor
Column indices for sparse computation
row_off : torch.Tensor
Row offsets for sparse computation
nlon_in : int
Number of input longitude points
nlat_out : int
Number of output latitude points
nlon_out : int
Number of output longitude points
Returns
-------
torch.Tensor
Gradient of the query tensor with shape (B, C, Ho, Wo)
"""
# shapes:
# input
......@@ -581,52 +441,7 @@ def _neighborhood_attention_s2_torch(k: torch.Tensor, v: torch.Tensor, q: torch.
bq: Union[torch.Tensor, None], quad_weights: torch.Tensor,
col_idx: torch.Tensor, row_off: torch.Tensor,
nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
"""
Torch implementation of neighborhood attention on the sphere (S2).
This function provides a wrapper around the CPU autograd function for
neighborhood attention operations using sparse tensor computations.
Parameters
-----------
k : torch.Tensor
Key tensor
v : torch.Tensor
Value tensor
q : torch.Tensor
Query tensor
wk : torch.Tensor
Key weight tensor
wv : torch.Tensor
Value weight tensor
wq : torch.Tensor
Query weight tensor
bk : torch.Tensor or None
Key bias tensor (optional)
bv : torch.Tensor or None
Value bias tensor (optional)
bq : torch.Tensor or None
Query bias tensor (optional)
quad_weights : torch.Tensor
Quadrature weights for spherical integration
col_idx : torch.Tensor
Column indices for sparse computation
row_off : torch.Tensor
Row offsets for sparse computation
nh : int
Number of attention heads
nlon_in : int
Number of input longitude points
nlat_out : int
Number of output latitude points
nlon_out : int
Number of output longitude points
Returns
-------
torch.Tensor
Output tensor after neighborhood attention computation
"""
return _NeighborhoodAttentionS2.apply(k, v, q, wk, wv, wq, bk, bv, bq,
quad_weights, col_idx, row_off,
nh, nlon_in, nlat_out, nlon_out)
......@@ -768,54 +583,7 @@ def _neighborhood_attention_s2_cuda(k: torch.Tensor, v: torch.Tensor, q: torch.T
bq: Union[torch.Tensor, None], quad_weights: torch.Tensor,
col_idx: torch.Tensor, row_off: torch.Tensor, max_psi_nnz: int,
nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
"""
CUDA implementation of neighborhood attention on the sphere (S2).
This function provides a wrapper around the CUDA autograd function for
neighborhood attention operations using custom CUDA kernels for efficient GPU computation.
Parameters
-----------
k : torch.Tensor
Key tensor
v : torch.Tensor
Value tensor
q : torch.Tensor
Query tensor
wk : torch.Tensor
Key weight tensor
wv : torch.Tensor
Value weight tensor
wq : torch.Tensor
Query weight tensor
bk : torch.Tensor or None
Key bias tensor (optional)
bv : torch.Tensor or None
Value bias tensor (optional)
bq : torch.Tensor or None
Query bias tensor (optional)
quad_weights : torch.Tensor
Quadrature weights for spherical integration
col_idx : torch.Tensor
Column indices for sparse computation
row_off : torch.Tensor
Row offsets for sparse computation
max_psi_nnz : int
Maximum number of non-zero elements in sparse tensor
nh : int
Number of attention heads
nlon_in : int
Number of input longitude points
nlat_out : int
Number of output latitude points
nlon_out : int
Number of output longitude points
Returns
-------
torch.Tensor
Output tensor after neighborhood attention computation
"""
return _NeighborhoodAttentionS2Cuda.apply(k, v, q, wk, wv, wq, bk, bv, bq,
quad_weights, col_idx, row_off, max_psi_nnz,
nh, nlon_in, nlat_out, nlon_out)
......@@ -60,39 +60,30 @@ except ImportError as err:
def _normalize_convolution_tensor_s2(
psi_idx, psi_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=False, basis_norm_mode="mean", merge_quadrature=False, eps=1e-9
):
"""
Discretely normalizes the convolution tensor and pre-applies quadrature weights. Supports the following three normalization modes:
- "none": No normalization is applied.
- "individual": for each output latitude and filter basis function the filter is numerically integrated over the sphere and normalized so that it yields 1.
- "mean": the norm is computed for each output latitude and then averaged over the output latitudes. Each basis function is then normalized by this mean.
Parameters
-----------
psi_idx: torch.Tensor
Index tensor of the convolution tensor
psi_vals: torch.Tensor
Values tensor of the convolution tensor
in_shape: Tuple[int]
Input shape of the convolution tensor
out_shape: Tuple[int]
Output shape of the convolution tensor
kernel_size: int
Size of the kernel
quad_weights: torch.Tensor
Quadrature weights
transpose_normalization: bool
Whether to normalize the convolution tensor in the transpose direction
basis_norm_mode: str
Mode for basis normalization
merge_quadrature: bool
Whether to merge the quadrature weights into the convolution tensor
eps: float
Small epsilon to avoid division by zero
Returns
-------
psi_vals: torch.Tensor
Normalized convolution tensor
"""Normalizes convolution tensor values based on specified normalization mode.
This function applies different normalization strategies to the convolution tensor
values based on the basis_norm_mode parameter. It can normalize individual basis
functions, compute mean normalization across all basis functions, or use support
weights. The function also optionally merges quadrature weights into the tensor.
Args:
psi_idx: Index tensor for the sparse convolution tensor.
psi_vals: Value tensor for the sparse convolution tensor.
in_shape: Tuple of (nlat_in, nlon_in) representing input grid dimensions.
out_shape: Tuple of (nlat_out, nlon_out) representing output grid dimensions.
kernel_size: Number of kernel basis functions.
quad_weights: Quadrature weights for numerical integration.
transpose_normalization: If True, applies normalization in transpose direction.
basis_norm_mode: Normalization mode, one of ["none", "individual", "mean", "support"].
merge_quadrature: If True, multiplies values by quadrature weights.
eps: Small epsilon value to prevent division by zero.
Returns:
torch.Tensor: Normalized convolution tensor values.
Raises:
ValueError: If basis_norm_mode is not one of the supported modes.
"""
# exit here if no normalization is needed
......
......@@ -39,17 +39,14 @@ from torch_harmonics.cache import lru_cache
def _circle_dist(x1: torch.Tensor, x2: torch.Tensor):
"""Helper function to compute the distance on a circle"""
return torch.minimum(torch.abs(x1 - x2), torch.abs(2 * math.pi - torch.abs(x1 - x2)))
def _log_factorial(x: torch.Tensor):
"""Helper function to compute the log factorial on a torch tensor"""
return torch.lgamma(x + 1)
def _factorial(x: torch.Tensor):
    """Elementwise factorial x! computed as exp(log Gamma(x + 1)) to stay in float."""
    return torch.exp(torch.lgamma(x + 1))
......@@ -62,27 +59,13 @@ class FilterBasis(metaclass=abc.ABCMeta):
self,
kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
):
"""
Initialize the filter basis.
Parameters
-----------
kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
Shape of the kernel, can be an integer or tuple of integers
"""
self.kernel_shape = kernel_shape
@property
@abc.abstractmethod
def kernel_size(self):
    """
    Total number of basis functions in the filter kernel.

    Returns
    -------
    int
        Number of kernel basis functions; must be provided by concrete subclasses.
    """
    raise NotImplementedError
# @abc.abstractmethod
......@@ -94,10 +77,7 @@ class FilterBasis(metaclass=abc.ABCMeta):
@abc.abstractmethod
def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float):
    """
    Computes the index set that falls into the kernel's support and returns both indices and values.

    Designed for sparse evaluation of the filter basis: only entries within the
    kernel support (presumably r below r_cutoff — confirm in subclasses) are
    returned, as an (indices, values) pair. Must be implemented by subclasses.
    """
    raise NotImplementedError
......@@ -124,12 +104,7 @@ class PiecewiseLinearFilterBasis(FilterBasis):
self,
kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
):
"""
Initialize the piecewise linear filter basis.
Parameters:
kernel_shape: shape of the kernel, can be an integer or tuple of integers
"""
if isinstance(kernel_shape, int):
kernel_shape = [kernel_shape]
if len(kernel_shape) == 1:
......@@ -152,9 +127,6 @@ class PiecewiseLinearFilterBasis(FilterBasis):
return (self.kernel_shape[0] // 2) * self.kernel_shape[1] + self.kernel_shape[0] % 2
def _compute_support_vals_isotropic(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float):
"""
Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
"""
# enumerator for basis function
ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1)
......@@ -176,9 +148,6 @@ class PiecewiseLinearFilterBasis(FilterBasis):
return iidx, vals
def _compute_support_vals_anisotropic(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float):
"""
Computes the index set that falls into the anisotropic kernel's support and returns both indices and values.
"""
# enumerator for basis function
ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1)
......@@ -253,14 +222,7 @@ class MorletFilterBasis(FilterBasis):
self,
kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
):
"""
Initialize the Morlet filter basis.
Parameters
-----------
kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
Shape of the kernel, can be an integer or tuple of integers
"""
if isinstance(kernel_shape, int):
kernel_shape = [kernel_shape, kernel_shape]
if len(kernel_shape) != 2:
......@@ -270,56 +232,18 @@ class MorletFilterBasis(FilterBasis):
@property
def kernel_size(self):
    """
    Total number of Morlet basis functions: the product of the two kernel dimensions.
    """
    n_first = self.kernel_shape[0]
    n_second = self.kernel_shape[1]
    return n_first * n_second
def gaussian_window(self, r: torch.Tensor, width: float = 1.0):
    """
    Evaluate a 2D Gaussian window at radial distance `r`.

    Parameters
    ----------
    r: torch.Tensor
        Radial distances at which to evaluate the window.
    width: float
        Standard-deviation-like width parameter of the Gaussian.

    Returns
    -------
    torch.Tensor
        Window values, normalized by the 2D Gaussian factor 1 / (2*pi*width**2).
    """
    scale = 1 / (2 * math.pi * width**2)
    return scale * torch.exp(-0.5 * r**2 / (width**2))
def hann_window(self, r: torch.Tensor, width: float = 1.0):
    """
    Evaluate a Hann (raised-cosine) window at radial distance `r`.

    Parameters
    ----------
    r: torch.Tensor
        Radial distances at which to evaluate the window.
    width: float
        Half-width at which the window reaches zero.

    Returns
    -------
    torch.Tensor
        Squared-cosine window values; 1 at r=0, 0 at r=width.
    """
    phase = 0.5 * torch.pi * r / width
    cosine = torch.cos(phase)
    return torch.square(cosine)
def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float, width: float = 1.0):
"""
Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
"""
# enumerator for basis function
ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1)
......@@ -355,14 +279,7 @@ class ZernikeFilterBasis(FilterBasis):
self,
kernel_shape: Union[int, Tuple[int]],
):
"""
Initialize the Zernike filter basis.
Parameters
-----------
kernel_shape: Union[int, Tuple[int]]
Shape of the kernel, can be an integer or tuple of integers
"""
if isinstance(kernel_shape, tuple) or isinstance(kernel_shape, list):
kernel_shape = kernel_shape[0]
if not isinstance(kernel_shape, int):
......@@ -372,34 +289,11 @@ class ZernikeFilterBasis(FilterBasis):
@property
def kernel_size(self):
    """
    Number of Zernike basis functions: the triangular number of kernel_shape,
    i.e. kernel_shape * (kernel_shape + 1) / 2.
    """
    n = self.kernel_shape
    return (n * (n + 1)) // 2
def zernikeradial(self, r: torch.Tensor, n: torch.Tensor, m: torch.Tensor):
"""
Compute radial Zernike polynomials.
Parameters
-----------
r: torch.Tensor
Radial distance tensor
n: torch.Tensor
Principal quantum number
m: torch.Tensor
Azimuthal quantum number
Returns
-------
out: torch.Tensor
Radial Zernike polynomial values
"""
out = torch.zeros_like(r)
bound = (n - m) // 2 + 1
max_bound = bound.max().item()
......@@ -412,32 +306,11 @@ class ZernikeFilterBasis(FilterBasis):
return out
def zernikepoly(self, r: torch.Tensor, phi: torch.Tensor, n: torch.Tensor, l: torch.Tensor):
    """
    Evaluate the full Zernike polynomial at polar coordinates (r, phi).

    The azimuthal order is m = 2*l - n; negative m selects the sine branch,
    non-negative m the cosine branch. Both branches are evaluated and combined
    elementwise with torch.where.
    """
    m = 2 * l - n
    sin_branch = self.zernikeradial(r, n, -m) * torch.sin(m * phi)
    cos_branch = self.zernikeradial(r, n, m) * torch.cos(m * phi)
    return torch.where(m < 0, sin_branch, cos_branch)
def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float, width: float = 0.25):
"""
Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
"""
# enumerator for basis function
ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1)
......
......@@ -83,43 +83,13 @@ def _precompute_grid(n: int, grid: Optional[str]="equidistant", a: Optional[floa
@lru_cache(typed=True, copy=True)
def _precompute_longitudes(nlon: int):
r"""
Convenience routine to precompute longitudes
Parameters
-----------
nlon: int
Number of longitude points
Returns
-------
lons: torch.Tensor
Tensor of longitude points
"""
lons = torch.linspace(0, 2 * math.pi, nlon+1, dtype=torch.float64, requires_grad=False)[:-1]
return lons
@lru_cache(typed=True, copy=True)
def _precompute_latitudes(nlat: int, grid: Optional[str]="equiangular") -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Convenience routine to precompute latitudes
Parameters
-----------
nlat: int
Number of latitude points
grid: Optional[str]
Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
Returns
-------
lats: torch.Tensor
Tensor of latitude points
wlg: torch.Tensor
Tensor of quadrature weights
"""
# compute coordinates in the cosine theta domain
xlg, wlg = _precompute_grid(nlat, grid=grid, a=-1.0, b=1.0, periodic=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment