Commit 43e3e720 authored by Andrea Paris's avatar Andrea Paris Committed by Boris Bonev
Browse files

updated docstring

parent c44d4b23
......@@ -75,49 +75,29 @@ def _split_distributed_convolution_tensor_s2(
out_shape: Tuple[int],
):
"""
Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
The output tensor has shape kernel_shape x nlat_out x (nlat_in * nlon_in).
The rotation of the Euler angles uses the YZY convention, which applied to the northpole $(0,0,1)^T$ yields
$$
Y(\alpha) Z(\beta) Y(\gamma) n =
{\begin{bmatrix}
\cos(\gamma)\sin(\alpha) + \cos(\alpha)\cos(\beta)\sin(\gamma) \\
\sin(\beta)\sin(\gamma) \\
\cos(\alpha)\cos(\gamma)-\cos(\beta)\sin(\alpha)\sin(\gamma)
\end{bmatrix}}
$$
Splits a pre-computed convolution tensor along the latitude dimension for distributed processing.
This function takes a convolution tensor that was generated by the serial routine and filters
it to only include entries corresponding to the local latitude slice assigned to this process.
The filtering is done based on the polar group rank and the computed split shapes.
Parameters
----------
idx: torch.Tensor
Indices of the pre-computed convolution tensor
vals: torch.Tensor
Values of the pre-computed convolution tensor
in_shape: Tuple[int]
Shape of the input tensor
Shape of the input tensor (nlat_in, nlon_in)
out_shape: Tuple[int]
Shape of the output tensor
filter_basis: FilterBasis
Filter basis to use
grid_in: str
Grid type for the input tensor
grid_out: str
Grid type for the output tensor
theta_cutoff: float
Theta cutoff for the filter basis
theta_eps: float
Epsilon for the theta cutoff
transpose_normalization: bool
Whether to transpose the normalization
basis_norm_mode: str
Normalization mode for the filter basis
merge_quadrature: bool
Whether to merge the quadrature weights
Shape of the output tensor (nlat_out, nlon_out)
Returns
-------
out_idx: torch.Tensor
Indices of the output tensor
out_vals: torch.Tensor
Values of the output tensor
idx: torch.Tensor
Filtered indices corresponding to the local latitude slice
vals: torch.Tensor
Filtered values corresponding to the local latitude slice
"""
assert len(in_shape) == 2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment