"""Normalizes convolution tensor values based on specified normalization mode.
Discretely normalizes the convolution tensor and pre-applies quadrature weights. Supports the following three normalization modes:
- "none": No normalization is applied.
This function applies different normalization strategies to the convolution tensor
- "individual": for each output latitude and filter basis function the filter is numerically integrated over the sphere and normalized so that it yields 1.
values based on the basis_norm_mode parameter. It can normalize individual basis
- "mean": the norm is computed for each output latitude and then averaged over the output latitudes. Each basis function is then normalized by this mean.
functions, compute mean normalization across all basis functions, or use support
weights. The function also optionally merges quadrature weights into the tensor.
Parameters
-----------
Args:
psi_idx: torch.Tensor
psi_idx: Index tensor for the sparse convolution tensor.
Index tensor of the convolution tensor
psi_vals: Value tensor for the sparse convolution tensor.
psi_vals: torch.Tensor
in_shape: Tuple of (nlat_in, nlon_in) representing input grid dimensions.
Values tensor of the convolution tensor
out_shape: Tuple of (nlat_out, nlon_out) representing output grid dimensions.
in_shape: Tuple[int]
kernel_size: Number of kernel basis functions.
Input shape of the convolution tensor
quad_weights: Quadrature weights for numerical integration.
out_shape: Tuple[int]
transpose_normalization: If True, applies normalization in transpose direction.
Output shape of the convolution tensor
basis_norm_mode: Normalization mode, one of ["none", "individual", "mean", "support"].
kernel_size: int
merge_quadrature: If True, multiplies values by quadrature weights.
Size of the kernel
eps: Small epsilon value to prevent division by zero.