Unverified Commit c7afb546 authored by Thorsten Kurth's avatar Thorsten Kurth Committed by GitHub
Browse files

Merge pull request #95 from NVIDIA/aparis/docs

Docstrings PR
parents b5c410c0 644465ba
Pipeline #2854 canceled with stages
......@@ -60,11 +60,45 @@ except ImportError as err:
def _normalize_convolution_tensor_s2(
psi_idx, psi_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=False, basis_norm_mode="mean", merge_quadrature=False, eps=1e-9
):
"""
Discretely normalizes the convolution tensor and pre-applies quadrature weights. Supports the following three normalization modes:
- "none": No normalization is applied.
- "individual": for each output latitude and filter basis function the filter is numerically integrated over the sphere and normalized so that it yields 1.
- "mean": the norm is computed for each output latitude and then averaged over the output latitudes. Each basis function is then normalized by this mean.
"""Normalizes convolution tensor values based on specified normalization mode.
This function applies different normalization strategies to the convolution tensor
values based on the basis_norm_mode parameter. It can normalize individual basis
functions, compute mean normalization across all basis functions, or use support
weights. The function also optionally merges quadrature weights into the tensor.
Parameters
-----------
psi_idx: torch.Tensor
Index tensor for the sparse convolution tensor.
psi_vals: torch.Tensor
Value tensor for the sparse convolution tensor.
in_shape: Tuple[int]
Tuple of (nlat_in, nlon_in) representing input grid dimensions.
out_shape: Tuple[int]
Tuple of (nlat_out, nlon_out) representing output grid dimensions.
kernel_size: int
Number of kernel basis functions.
quad_weights: torch.Tensor
Quadrature weights for numerical integration.
transpose_normalization: bool
If True, applies normalization in transpose direction.
basis_norm_mode: str
Normalization mode, one of ["none", "individual", "mean", "support"].
merge_quadrature: bool
If True, multiplies values by quadrature weights.
eps: float
Small epsilon value to prevent division by zero.
Returns
-------
torch.Tensor
Normalized convolution tensor values.
Raises
------
ValueError
If basis_norm_mode is not one of the supported modes.
"""
# exit here if no normalization is needed
......@@ -109,7 +143,6 @@ def _normalize_convolution_tensor_s2(
# compute the support
support[ik, ilat] = torch.sum(q[iidx])
# loop over values and renormalize
for ik in range(kernel_size):
for ilat in range(nlat_out):
......@@ -132,7 +165,6 @@ def _normalize_convolution_tensor_s2(
if merge_quadrature:
psi_vals[iidx] = psi_vals[iidx] * q[iidx]
if transpose_normalization and merge_quadrature:
psi_vals = psi_vals / correction_factor
......@@ -144,13 +176,13 @@ def _precompute_convolution_tensor_s2(
in_shape: Tuple[int],
out_shape: Tuple[int],
filter_basis: FilterBasis,
grid_in: Optional[str]="equiangular",
grid_out: Optional[str]="equiangular",
theta_cutoff: Optional[float]=0.01 * math.pi,
theta_eps: Optional[float]=1e-3,
transpose_normalization: Optional[bool]=False,
basis_norm_mode: Optional[str]="mean",
merge_quadrature: Optional[bool]=False,
grid_in: Optional[str] = "equiangular",
grid_out: Optional[str] = "equiangular",
theta_cutoff: Optional[float] = 0.01 * math.pi,
theta_eps: Optional[float] = 1e-3,
transpose_normalization: Optional[bool] = False,
basis_norm_mode: Optional[str] = "mean",
merge_quadrature: Optional[bool] = False,
):
"""
Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
......@@ -166,6 +198,37 @@ def _precompute_convolution_tensor_s2(
\cos(\alpha)\cos(\gamma)-\cos(\beta)\sin(\alpha)\sin(\gamma)
\end{bmatrix}}
$$
Parameters
-----------
in_shape: Tuple[int]
Input shape of the convolution tensor
out_shape: Tuple[int]
Output shape of the convolution tensor
filter_basis: FilterBasis
Filter basis functions
grid_in: str
Input grid type
grid_out: str
Output grid type
theta_cutoff: float
Theta cutoff for the filter basis functions
theta_eps: float
Epsilon for the theta cutoff
transpose_normalization: bool
Whether to normalize the convolution tensor in the transpose direction
basis_norm_mode: str
Mode for basis normalization
merge_quadrature: bool
Whether to merge the quadrature weights into the convolution tensor
Returns
-------
out_idx: torch.Tensor
Index tensor of the convolution tensor
out_vals: torch.Tensor
Values tensor of the convolution tensor
"""
assert len(in_shape) == 2
......@@ -268,6 +331,26 @@ def _precompute_convolution_tensor_s2(
class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
"""
Abstract base class for discrete-continuous convolutions
Parameters
-----------
in_channels: int
Number of input channels
out_channels: int
Number of output channels
kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
Shape of the kernel
basis_type: Optional[str]
Type of the basis functions
groups: Optional[int]
Number of groups
bias: Optional[bool]
Whether to use bias
Returns
-------
out: torch.Tensor
Output tensor
"""
def __init__(
......@@ -316,6 +399,40 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
"""
Discrete-continuous (DISCO) convolutions on the 2-Sphere as described in [1].
Parameters
-----------
in_channels: int
Number of input channels
out_channels: int
Number of output channels
in_shape: Tuple[int]
Input shape of the convolution tensor
out_shape: Tuple[int]
Output shape of the convolution tensor
kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
Shape of the kernel
basis_type: Optional[str]
Type of the basis functions
basis_norm_mode: Optional[str]
Mode for basis normalization
groups: Optional[int]
Number of groups
grid_in: Optional[str]
Input grid type
grid_out: Optional[str]
Output grid type
bias: Optional[bool]
Whether to use bias
theta_cutoff: Optional[float]
Theta cutoff for the filter basis functions
Returns
-------
out: torch.Tensor
Output tensor
References
----------
[1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
"""
......@@ -382,9 +499,6 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
self.psi = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out)
def extra_repr(self):
r"""
Pretty print module
"""
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"
@property
......@@ -420,6 +534,40 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
"""
Discrete-continuous (DISCO) transpose convolutions on the 2-Sphere as described in [1].
Parameters
-----------
in_channels: int
Number of input channels
out_channels: int
Number of output channels
in_shape: Tuple[int]
Input shape of the convolution tensor
out_shape: Tuple[int]
Output shape of the convolution tensor
kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
Shape of the kernel
basis_type: Optional[str]
Type of the basis functions
basis_norm_mode: Optional[str]
Mode for basis normalization
groups: Optional[int]
Number of groups
grid_in: Optional[str]
Input grid type
grid_out: Optional[str]
Output grid type
bias: Optional[bool]
Whether to use bias
theta_cutoff: Optional[float]
Theta cutoff for the filter basis functions
Returns
--------
out: torch.Tensor
Output tensor
References
----------
[1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
"""
......@@ -487,9 +635,6 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
self.psi_st = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, semi_transposed=True)
def extra_repr(self):
r"""
Pretty print module
"""
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"
@property
......@@ -497,6 +642,7 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()
def forward(self, x: torch.Tensor) -> torch.Tensor:
# extract shape
B, C, H, W = x.shape
x = x.reshape(B, self.groups, self.groupsize, H, W)
......
......@@ -74,6 +74,32 @@ def _split_distributed_convolution_tensor_s2(
in_shape: Tuple[int],
out_shape: Tuple[int],
):
"""
Splits a pre-computed convolution tensor along the latitude dimension for distributed processing.
This function takes a convolution tensor that was generated by the serial routine and filters
it to only include entries corresponding to the local latitude slice assigned to this process.
The filtering is done based on the polar group rank and the computed split shapes.
Parameters
----------
idx: torch.Tensor
Indices of the pre-computed convolution tensor
vals: torch.Tensor
Values of the pre-computed convolution tensor
in_shape: Tuple[int]
Shape of the input tensor (nlat_in, nlon_in)
out_shape: Tuple[int]
Shape of the output tensor (nlat_out, nlon_out)
Returns
-------
idx: torch.Tensor
Filtered indices corresponding to the local latitude slice
vals: torch.Tensor
Filtered values corresponding to the local latitude slice
"""
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
......@@ -102,10 +128,43 @@ def _split_distributed_convolution_tensor_s2(
class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
"""
Distributed version of Discrete-continuous convolutions (DISCO) on the 2-Sphere as described in [1].
    We assume the data can be split in polar and azimuthal directions.
Parameters
----------
in_channels: int
Number of input channels
out_channels: int
Number of output channels
in_shape: Tuple[int]
Shape of the input tensor
out_shape: Tuple[int]
Shape of the output tensor
kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
Shape of the kernel
basis_type: Optional[str]
Type of basis to use
basis_norm_mode: Optional[str]
Normalization mode for the filter basis
groups: Optional[int]
Number of groups
grid_in: Optional[str]
Grid type for the input tensor
grid_out: Optional[str]
Grid type for the output tensor
bias: Optional[bool]
Whether to use bias
theta_cutoff: Optional[float]
Theta cutoff for the filter basis
Returns
-------
out: torch.Tensor
Output tensor
References
----------
[1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
We assume the data can be splitted in polar and azimuthal directions.
"""
def __init__(
......@@ -192,9 +251,6 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
self.psi = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, self.nlat_in_local, self.nlat_out_local)
def extra_repr(self):
r"""
Pretty print module
"""
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"
@property
......@@ -247,6 +303,40 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
"""
Discrete-continuous transpose convolutions (DISCO) on the 2-Sphere as described in [1].
Parameters
----------
in_channels: int
Number of input channels
out_channels: int
Number of output channels
in_shape: Tuple[int]
Shape of the input tensor
out_shape: Tuple[int]
Shape of the output tensor
kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
Shape of the kernel
basis_type: Optional[str]
Type of basis to use
basis_norm_mode: Optional[str]
Normalization mode for the filter basis
groups: Optional[int]
Number of groups
grid_in: Optional[str]
Grid type for the input tensor
grid_out: Optional[str]
Grid type for the output tensor
bias: Optional[bool]
Whether to use bias
theta_cutoff: Optional[float]
Theta cutoff for the filter basis
Returns
-------
out: torch.Tensor
Output tensor
References
----------
[1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    We assume the data can be split in polar and azimuthal directions.
......@@ -339,9 +429,6 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
self.psi_st = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, self.nlat_in_local, self.nlat_out_local, semi_transposed=True)
def extra_repr(self):
r"""
Pretty print module
"""
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"
@property
......
......@@ -43,6 +43,32 @@ from torch_harmonics.distributed import compute_split_shapes
class DistributedResampleS2(nn.Module):
"""
Distributed resampling module for spherical data on the 2-sphere.
This module performs distributed resampling of spherical data across multiple processes,
supporting both upscaling and downscaling operations. The data is distributed across
polar and azimuthal directions, and the module handles the necessary communication
and interpolation operations.
Parameters
-----------
nlat_in : int
Number of input latitude points
nlon_in : int
Number of input longitude points
nlat_out : int
Number of output latitude points
nlon_out : int
Number of output longitude points
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
mode : str, optional
Interpolation mode ("bilinear" or "bilinear-spherical"), by default "bilinear"
"""
def __init__(
self,
nlat_in: int,
......@@ -127,12 +153,10 @@ class DistributedResampleS2(nn.Module):
self.skip_resampling = (nlon_in == nlon_out) and (nlat_in == nlat_out) and (grid_in == grid_out)
def extra_repr(self):
r"""
Pretty print module
"""
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}"
def _upscale_longitudes(self, x: torch.Tensor):
"""Upscale the longitude dimension using interpolation."""
# do the interpolation
lwgt = self.lon_weights.to(x.dtype)
if self.mode == "bilinear":
......@@ -147,6 +171,7 @@ class DistributedResampleS2(nn.Module):
return x
def _expand_poles(self, x: torch.Tensor):
"""Expand the data to include pole values for interpolation."""
x_north = x[..., 0, :].sum(dim=-1, keepdims=True)
x_south = x[..., -1, :].sum(dim=-1, keepdims=True)
x_count = torch.tensor([x.shape[-1]], dtype=torch.long, device=x.device, requires_grad=False)
......@@ -169,6 +194,7 @@ class DistributedResampleS2(nn.Module):
return x
def _upscale_latitudes(self, x: torch.Tensor):
"""Upscale the latitude dimension using interpolation."""
# do the interpolation
lwgt = self.lat_weights.to(x.dtype)
if self.mode == "bilinear":
......
......@@ -48,20 +48,36 @@ class DistributedRealSHT(nn.Module):
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
The SHT is applied to the last two dimensions of the input
Parameters
----------
nlat: int
Number of latitude points
nlon: int
Number of longitude points
lmax: int
Maximum spherical harmonic degree
mmax: int
Maximum spherical harmonic order
grid: str
Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
norm: str
Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
csphase: bool
Whether to apply the Condon-Shortley phase factor, by default True
Returns
-------
x: torch.Tensor
Tensor of shape (..., lmax, mmax)
References
----------
[1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
[2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
"""
def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
"""
        Distributed SHT layer. Expects the last 3 dimensions of the input tensor to be channels, latitude, longitude.
Parameters:
nlat: input grid resolution in the latitudinal direction
nlon: input grid resolution in the longitudinal direction
grid: grid in the latitude direction (for now only tensor product grids are supported)
"""
super().__init__()
self.nlat = nlat
......@@ -115,9 +131,6 @@ class DistributedRealSHT(nn.Module):
self.register_buffer('weights', weights, persistent=False)
def extra_repr(self):
"""
Pretty print module
"""
return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'
def forward(self, x: torch.Tensor):
......@@ -168,9 +181,31 @@ class DistributedInverseRealSHT(nn.Module):
"""
Defines a module for computing the inverse (real-valued) SHT.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
nlat, nlon: Output dimensions
lmax, mmax: Input dimensions (spherical coefficients). For convenience, these are inferred from the output dimensions
Parameters
----------
nlat: int
Number of latitude points
nlon: int
Number of longitude points
lmax: int
Maximum spherical harmonic degree
mmax: int
Maximum spherical harmonic order
grid: str
Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
norm: str
Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
csphase: bool
Whether to apply the Condon-Shortley phase factor, by default True
Returns
-------
x: torch.Tensor
Tensor of shape (..., lmax, mmax)
References
----------
[1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
[2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
"""
......@@ -226,9 +261,6 @@ class DistributedInverseRealSHT(nn.Module):
self.register_buffer('pct', pct, persistent=False)
def extra_repr(self):
"""
Pretty print module
"""
return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'
def forward(self, x: torch.Tensor):
......@@ -282,19 +314,35 @@ class DistributedRealVectorSHT(nn.Module):
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
The SHT is applied to the last three dimensions of the input.
Parameters
----------
nlat: int
Number of latitude points
nlon: int
Number of longitude points
lmax: int
Maximum spherical harmonic degree
mmax: int
Maximum spherical harmonic order
grid: str
Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
norm: str
Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
csphase: bool
Whether to apply the Condon-Shortley phase factor, by default True
Returns
-------
x: torch.Tensor
Tensor of shape (..., lmax, mmax)
References
----------
[1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
[2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
"""
def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
"""
Initializes the vector SHT Layer, precomputing the necessary quadrature weights
Parameters:
nlat: input grid resolution in the latitudinal direction
nlon: input grid resolution in the longitudinal direction
grid: type of grid the data lives on
"""
super().__init__()
......@@ -355,9 +403,6 @@ class DistributedRealVectorSHT(nn.Module):
def extra_repr(self):
"""
Pretty print module
"""
return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'
def forward(self, x: torch.Tensor):
......@@ -425,6 +470,30 @@ class DistributedInverseRealVectorSHT(nn.Module):
Defines a module for computing the inverse (real-valued) vector SHT.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
Parameters
----------
nlat: int
Number of latitude points
nlon: int
Number of longitude points
lmax: int
Maximum spherical harmonic degree
mmax: int
Maximum spherical harmonic order
grid: str
Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
norm: str
Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
csphase: bool
Whether to apply the Condon-Shortley phase factor, by default True
Returns
-------
x: torch.Tensor
Tensor of shape (..., lmax, mmax)
References
----------
[1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
[2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
"""
......@@ -478,9 +547,6 @@ class DistributedInverseRealVectorSHT(nn.Module):
self.register_buffer('dpct', dpct, persistent=False)
def extra_repr(self):
"""
Pretty print module
"""
return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'
def forward(self, x: torch.Tensor):
......
......@@ -39,6 +39,7 @@ from .utils import is_initialized, is_distributed_polar, is_distributed_azimuth
# helper routine to compute uneven splitting in balanced way:
def compute_split_shapes(size: int, num_chunks: int) -> List[int]:
"""Compute the split shapes for a given size and number of chunks."""
# treat trivial case first
if num_chunks == 1:
......@@ -59,6 +60,8 @@ def compute_split_shapes(size: int, num_chunks: int) -> List[int]:
def split_tensor_along_dim(tensor, dim, num_chunks):
"""Split a tensor along a given dimension into a given number of chunks."""
assert dim < tensor.dim(), f"Error, tensor dimension is {tensor.dim()} which cannot be split along {dim}"
assert (tensor.shape[dim] >= num_chunks), f"Error, cannot split dim {dim} of size {tensor.shape[dim]} into \
{num_chunks} chunks. Empty slices are currently not supported."
......@@ -71,6 +74,7 @@ def split_tensor_along_dim(tensor, dim, num_chunks):
def _transpose(tensor, dim0, dim1, dim1_split_sizes, group=None, async_op=False):
# get comm params
comm_size = dist.get_world_size(group=group)
comm_rank = dist.get_rank(group=group)
......@@ -99,6 +103,7 @@ class distributed_transpose_azimuth(torch.autograd.Function):
@staticmethod
@custom_fwd(device_type="cuda")
def forward(ctx, x, dims, dim1_split_sizes):
# WAR for a potential contig check torch bug for channels last contig tensors
xlist, dim0_split_sizes, _ = _transpose(x, dims[0], dims[1], dim1_split_sizes, group=azimuth_group())
x = torch.cat(xlist, dim=dims[1])
......@@ -124,6 +129,7 @@ class distributed_transpose_polar(torch.autograd.Function):
@staticmethod
@custom_fwd(device_type="cuda")
def forward(ctx, x, dim, dim1_split_sizes):
# WAR for a potential contig check torch bug for channels last contig tensors
xlist, dim0_split_sizes, _ = _transpose(x, dim[0], dim[1], dim1_split_sizes, group=polar_group())
x = torch.cat(xlist, dim=dim[1])
......@@ -134,6 +140,7 @@ class distributed_transpose_polar(torch.autograd.Function):
@staticmethod
@custom_bwd(device_type="cuda")
def backward(ctx, go):
dim = ctx.dim
dim0_split_sizes = ctx.dim0_split_sizes
# WAR for a potential contig check torch bug for channels last contig tensors
......@@ -144,7 +151,6 @@ class distributed_transpose_polar(torch.autograd.Function):
# we need those additional primitives for distributed matrix multiplications
def _reduce(input_, use_fp32=True, group=None):
"""All-reduce the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
if dist.get_world_size(group=group) == 1:
......@@ -165,7 +171,6 @@ def _reduce(input_, use_fp32=True, group=None):
def _split(input_, dim_, group=None):
"""Split the tensor along its last dimension and keep the corresponding slice."""
# Bypass the function if we are using only 1 GPU.
comm_size = dist.get_world_size(group=group)
if comm_size == 1:
......@@ -182,7 +187,6 @@ def _split(input_, dim_, group=None):
def _gather(input_, dim_, shapes_, group=None):
"""Gather unevenly split tensors across ranks"""
comm_size = dist.get_world_size(group=group)
......@@ -215,7 +219,6 @@ def _gather(input_, dim_, shapes_, group=None):
def _reduce_scatter(input_, dim_, use_fp32=True, group=None):
"""All-reduce the input tensor across model parallel group and scatter it back."""
# Bypass the function if we are using only 1 GPU.
if dist.get_world_size(group=group) == 1:
......@@ -244,7 +247,6 @@ def _reduce_scatter(input_, dim_, use_fp32=True, group=None):
class _CopyToPolarRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chunk to the rank."""
@staticmethod
def symbolic(graph, input_):
......@@ -253,11 +255,13 @@ class _CopyToPolarRegion(torch.autograd.Function):
@staticmethod
@custom_fwd(device_type="cuda")
def forward(ctx, input_):
return input_
@staticmethod
@custom_bwd(device_type="cuda")
def backward(ctx, grad_output):
if is_distributed_polar():
return _reduce(grad_output, group=polar_group())
else:
......@@ -265,7 +269,6 @@ class _CopyToPolarRegion(torch.autograd.Function):
class _CopyToAzimuthRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chunk to the rank."""
@staticmethod
def symbolic(graph, input_):
......@@ -274,11 +277,13 @@ class _CopyToAzimuthRegion(torch.autograd.Function):
@staticmethod
@custom_fwd(device_type="cuda")
def forward(ctx, input_):
return input_
@staticmethod
@custom_bwd(device_type="cuda")
def backward(ctx, grad_output):
if is_distributed_azimuth():
return _reduce(grad_output, group=azimuth_group())
else:
......@@ -286,7 +291,6 @@ class _CopyToAzimuthRegion(torch.autograd.Function):
class _ScatterToPolarRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chunk to the rank."""
@staticmethod
def symbolic(graph, input_, dim_):
......@@ -314,7 +318,6 @@ class _ScatterToPolarRegion(torch.autograd.Function):
class _GatherFromPolarRegion(torch.autograd.Function):
"""Gather the input and keep it on the rank."""
@staticmethod
def symbolic(graph, input_, dim_, shapes_):
......@@ -339,7 +342,6 @@ class _GatherFromPolarRegion(torch.autograd.Function):
class _ReduceFromPolarRegion(torch.autograd.Function):
"""All-reduce the input from the polar region."""
@staticmethod
def symbolic(graph, input_):
......@@ -363,8 +365,7 @@ class _ReduceFromPolarRegion(torch.autograd.Function):
class _ReduceFromAzimuthRegion(torch.autograd.Function):
"""All-reduce the input from the azimuth region."""
@staticmethod
def symbolic(graph, input_):
if is_distributed_azimuth():
......@@ -387,7 +388,6 @@ class _ReduceFromAzimuthRegion(torch.autograd.Function):
class _ReduceFromScatterToPolarRegion(torch.autograd.Function):
"""All-reduce the input from the polar region and scatter back to polar region."""
@staticmethod
def symbolic(graph, input_, dim_):
......@@ -418,7 +418,6 @@ class _ReduceFromScatterToPolarRegion(torch.autograd.Function):
class _GatherFromCopyToPolarRegion(torch.autograd.Function):
"""Gather the input from the polar region and register BWD AR, basically the inverse of reduce-scatter"""
@staticmethod
def symbolic(graph, input_, dim_, shapes_):
......
......@@ -55,6 +55,27 @@ def get_quadrature_weights(nlat: int, nlon: int, grid: str, tile: bool = False,
class DiceLossS2(nn.Module):
"""
Dice loss for spherical segmentation tasks.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
grid : str, optional
Grid type, by default "equiangular"
weight : torch.Tensor, optional
Class weights, by default None
smooth : float, optional
Smoothing factor, by default 0
ignore_index : int, optional
Index to ignore in loss computation, by default -100
mode : str, optional
Aggregation mode ("micro" or "macro"), by default "micro"
"""
def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, smooth: float = 0, ignore_index: int = -100, mode: str = "micro"):
super().__init__()
......@@ -113,6 +134,24 @@ class DiceLossS2(nn.Module):
class CrossEntropyLossS2(nn.Module):
"""
Cross-entropy loss for spherical classification tasks.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
grid : str, optional
Grid type, by default "equiangular"
weight : torch.Tensor, optional
Class weights, by default None
smooth : float, optional
Label smoothing factor, by default 0
ignore_index : int, optional
Index to ignore in loss computation, by default -100
"""
def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, smooth: float = 0, ignore_index: int = -100):
......@@ -141,6 +180,24 @@ class CrossEntropyLossS2(nn.Module):
class FocalLossS2(nn.Module):
"""
Focal loss for spherical classification tasks.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
grid : str, optional
Grid type, by default "equiangular"
weight : torch.Tensor, optional
Class weights, by default None
smooth : float, optional
Label smoothing factor, by default 0
ignore_index : int, optional
Index to ignore in loss computation, by default -100
"""
def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, smooth: float = 0, ignore_index: int = -100):
......@@ -286,10 +343,19 @@ class NormalLossS2(SphericalLossBase):
Surface normals are computed by calculating gradients in latitude and longitude
directions using FFT, then constructing 3D normal vectors that are normalized.
Args:
nlat (int): Number of latitude points
nlon (int): Number of longitude points
grid (str, optional): Grid type. Defaults to "equiangular".
Parameters
----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
grid : str, optional
Grid type, by default "equiangular"
Returns
-------
torch.Tensor
Combined loss term
"""
def __init__(self, nlat: int, nlon: int, grid: str = "equiangular"):
......
......@@ -49,6 +49,31 @@ def _get_stats_multiclass(
quad_weights: torch.Tensor,
ignore_index: Optional[int],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute multiclass statistics (TP, FP, FN, TN) on the sphere using quadrature weights.
This function computes true positives, false positives, false negatives, and true negatives
for multiclass classification on spherical data, properly weighted by quadrature weights
to account for the spherical geometry.
Parameters
-----------
output : torch.LongTensor
Predicted class labels
target : torch.LongTensor
Ground truth class labels
num_classes : int
Number of classes in the classification task
quad_weights : torch.Tensor
Quadrature weights for spherical integration
ignore_index : Optional[int]
Index to ignore in the computation (e.g., for padding or invalid regions)
Returns
-------
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
Tuple containing (tp_count, fp_count, fn_count, tn_count) for each class
"""
batch_size, *dims = output.shape
num_elements = torch.prod(torch.tensor(dims)).long()
......@@ -88,10 +113,46 @@ def _get_stats_multiclass(
def _predict_classes(logits: torch.Tensor) -> torch.Tensor:
"""
Convert logits to class predictions using softmax and argmax.
Parameters
-----------
logits : torch.Tensor
Input logits tensor
Returns
-------
torch.Tensor
Predicted class labels
"""
return torch.argmax(torch.softmax(logits, dim=1), dim=1, keepdim=False)
class BaseMetricS2(nn.Module):
"""
Base class for spherical metrics that properly handle spherical geometry.
This class provides the foundation for computing metrics on spherical data
by using quadrature weights to account for the non-uniform area distribution
on the sphere.
Parameters
----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
grid : str, optional
Grid type ("equiangular", "legendre-gauss", etc.), by default "equiangular"
weight : torch.Tensor, optional
Class weights for weighted averaging, by default None
ignore_index : int, optional
Index to ignore in computations, by default -100
mode : str, optional
Averaging mode ("micro" or "macro"), by default "micro"
"""
def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, ignore_index: int = -100, mode: str = "micro"):
super().__init__()
......@@ -108,6 +169,7 @@ class BaseMetricS2(nn.Module):
self.register_buffer("weight", weight.unsqueeze(0))
def _forward(self, pred: torch.Tensor, truth: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# convert logits to class predictions
pred_class = _predict_classes(pred)
......@@ -138,6 +200,28 @@ class BaseMetricS2(nn.Module):
class IntersectionOverUnionS2(BaseMetricS2):
    """
    Intersection over Union (IoU) metric on the sphere.

    Computes the IoU score for multiclass classification of spherical data.
    Quadrature weights make every grid point contribute in proportion to the
    surface area it represents, so the metric is not biased towards the poles.

    Parameters
    ----------
    nlat : int
        Number of latitude points
    nlon : int
        Number of longitude points
    grid : str, optional
        Grid type ("equiangular", "legendre-gauss", etc.), by default "equiangular"
    weight : torch.Tensor, optional
        Class weights for weighted averaging, by default None
    ignore_index : int, optional
        Index to ignore in computations, by default -100
    mode : str, optional
        Averaging mode ("micro" or "macro"), by default "micro"
    """

    def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, ignore_index: int = -100, mode: str = "micro"):
        # all setup (quadrature weights, buffers) is handled by the base class
        super().__init__(nlat=nlat, nlon=nlon, grid=grid, weight=weight, ignore_index=ignore_index, mode=mode)
......@@ -162,6 +246,28 @@ class IntersectionOverUnionS2(BaseMetricS2):
class AccuracyS2(BaseMetricS2):
    """
    Accuracy metric on the sphere.

    Computes the accuracy score for multiclass classification of spherical
    data. Quadrature weights make every grid point contribute in proportion to
    the surface area it represents, so the metric is not biased towards the poles.

    Parameters
    ----------
    nlat : int
        Number of latitude points
    nlon : int
        Number of longitude points
    grid : str, optional
        Grid type ("equiangular", "legendre-gauss", etc.), by default "equiangular"
    weight : torch.Tensor, optional
        Class weights for weighted averaging, by default None
    ignore_index : int, optional
        Index to ignore in computations, by default -100
    mode : str, optional
        Averaging mode ("micro" or "macro"), by default "micro"
    """

    def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, ignore_index: int = -100, mode: str = "micro"):
        # all setup (quadrature weights, buffers) is handled by the base class
        super().__init__(nlat=nlat, nlon=nlon, grid=grid, weight=weight, ignore_index=ignore_index, mode=mode)
......
......@@ -41,9 +41,11 @@ from torch_harmonics import InverseRealSHT
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
......@@ -75,19 +77,27 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
Parameters
----------
tensor: torch.Tensor
an n-dimensional `torch.Tensor`
mean: float
the mean of the normal distribution
std: float
the standard deviation of the normal distribution
a: float
the minimum cutoff value, by default -2.0
b: float
the maximum cutoff value
Examples
--------
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
......@@ -102,6 +112,20 @@ def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
Parameters
----------
x : torch.Tensor
Input tensor
drop_prob : float, optional
Probability of dropping a path, by default 0.0
training : bool, optional
Whether the model is in training mode, by default False
Returns
-------
torch.Tensor
Output tensor
"""
if drop_prob == 0.0 or not training:
return x
......@@ -114,17 +138,47 @@ def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -
class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This module implements stochastic depth regularization by randomly dropping
    entire residual paths during training, which helps with regularization and
    training of very deep networks.

    Parameters
    ----------
    drop_prob : float, optional
        Probability of dropping a path, by default 0.0 (no paths dropped).
        The previous default of None made `drop_path` raise a TypeError in
        training mode (it computes `1 - drop_prob`); 0.0 preserves the
        intended no-op behavior and is safe in both train and eval modes.
    """

    def __init__(self, drop_prob=0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # delegate to the functional implementation; it is a no-op when
        # drop_prob == 0.0 or when the module is in eval mode
        return drop_path(x, self.drop_prob, self.training)
class PatchEmbed(nn.Module):
"""
Patch embedding layer for vision transformers.
This module splits input images into patches and projects them to a
higher dimensional embedding space using convolutional layers.
Parameters
----------
img_size : tuple, optional
Input image size (height, width), by default (224, 224)
patch_size : tuple, optional
Patch size (height, width), by default (16, 16)
in_chans : int, optional
Number of input channels, by default 3
embed_dim : int, optional
Embedding dimension, by default 768
"""
def __init__(self, img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768):
super(PatchEmbed, self).__init__()
self.red_img_size = ((img_size[0] // patch_size[0]), (img_size[1] // patch_size[1]))
......@@ -137,6 +191,7 @@ class PatchEmbed(nn.Module):
self.proj.bias.is_shared_mp = ["spatial"]
def forward(self, x):
# gather input
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
......@@ -146,6 +201,32 @@ class PatchEmbed(nn.Module):
class MLP(nn.Module):
"""
Multi-layer perceptron with optional checkpointing.
This module implements a feed-forward network with two linear layers
and an activation function, with optional dropout and gradient checkpointing.
Parameters
----------
in_features : int
Number of input features
hidden_features : int, optional
Number of hidden features, by default None (same as in_features)
out_features : int, optional
Number of output features, by default None (same as in_features)
act_layer : nn.Module, optional
Activation layer, by default nn.ReLU
output_bias : bool, optional
Whether to use bias in output layer, by default False
drop_rate : float, optional
Dropout rate, by default 0.0
checkpointing : bool, optional
Whether to use gradient checkpointing, by default False
gain : float, optional
Gain factor for weight initialization, by default 1.0
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, output_bias=False, drop_rate=0.0, checkpointing=False, gain=1.0):
super(MLP, self).__init__()
self.checkpointing = checkpointing
......@@ -179,9 +260,11 @@ class MLP(nn.Module):
@torch.jit.ignore
def checkpoint_forward(self, x):
    # Run the wrapped forward (self.fwd) under gradient checkpointing:
    # activations are recomputed in the backward pass to save memory.
    # Marked @torch.jit.ignore so TorchScript compilation skips this method.
    return checkpoint(self.fwd, x)
def forward(self, x):
if self.checkpointing:
return self.checkpoint_forward(x)
else:
......@@ -190,9 +273,23 @@ class MLP(nn.Module):
class RealFFT2(nn.Module):
"""
Helper routine to wrap FFT similarly to the SHT.
This module provides a wrapper around PyTorch's real FFT2D that mimics
the interface of spherical harmonic transforms for consistency.
Parameters
----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
lmax : int, optional
Maximum spherical harmonic degree, by default None (same as nlat)
mmax : int, optional
Maximum spherical harmonic order, by default None (nlon//2 + 1)
"""
def __init__(self, nlat, nlon, lmax=None, mmax=None):
super(RealFFT2, self).__init__()
......@@ -202,6 +299,7 @@ class RealFFT2(nn.Module):
self.mmax = mmax or self.nlon // 2 + 1
def forward(self, x):
    # real-valued 2D FFT over the last two (spatial) dims; "ortho" makes the
    # transform unitary so the matching inverse reconstructs the input
    y = torch.fft.rfft2(x, dim=(-2, -1), norm="ortho")
    # truncate to the retained modes: along dim -2 keep the first
    # ceil(lmax/2) rows and the last floor(lmax/2) rows (rfft2 lays out
    # non-negative frequencies first, then negative frequencies at the end
    # of that dim), and keep only the first mmax columns along dim -1
    y = torch.cat((y[..., : math.ceil(self.lmax / 2), : self.mmax], y[..., -math.floor(self.lmax / 2) :, : self.mmax]), dim=-2)
    return y
......@@ -209,9 +307,23 @@ class RealFFT2(nn.Module):
class InverseRealFFT2(nn.Module):
"""
Helper routine to wrap inverse FFT similarly to the SHT.
This module provides a wrapper around PyTorch's inverse real FFT2D that mimics
the interface of inverse spherical harmonic transforms for consistency.
Parameters
----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
lmax : int, optional
Maximum spherical harmonic degree, by default None (same as nlat)
mmax : int, optional
Maximum spherical harmonic order, by default None (nlon//2 + 1)
"""
def __init__(self, nlat, nlon, lmax=None, mmax=None):
super(InverseRealFFT2, self).__init__()
......@@ -221,14 +333,34 @@ class InverseRealFFT2(nn.Module):
self.mmax = mmax or self.nlon // 2 + 1
def forward(self, x):
    # inverse real 2D FFT back to the (nlat, nlon) spatial grid; passing
    # s=(nlat, nlon) zero-pads the truncated spectrum so the output matches
    # the configured resolution, and "ortho" mirrors the forward transform
    return torch.fft.irfft2(x, dim=(-2, -1), s=(self.nlat, self.nlon), norm="ortho")
class LayerNorm(nn.Module):
"""
Wrapper class that moves the channel dimension to the end.
This module provides a layer normalization that works with channel-first
tensors by temporarily transposing the channel dimension to the end,
applying normalization, and then transposing back.
Parameters
----------
in_channels : int
Number of input channels
eps : float, optional
Epsilon for numerical stability, by default 1e-05
elementwise_affine : bool, optional
Whether to use learnable affine parameters, by default True
bias : bool, optional
Whether to use bias, by default True
device : torch.device, optional
Device to place the module on, by default None
dtype : torch.dtype, optional
Data type for the module, by default None
"""
def __init__(self, in_channels, eps=1e-05, elementwise_affine=True, bias=True, device=None, dtype=None):
super().__init__()
......@@ -246,8 +378,27 @@ class SpectralConvS2(nn.Module):
Spectral Convolution according to Driscoll & Healy. Designed for convolutions on the two-sphere S2
using the Spherical Harmonic Transforms in torch-harmonics, but supports convolutions on the periodic
domain via the RealFFT2 and InverseRealFFT2 wrappers.
Parameters
----------
forward_transform : nn.Module
Forward transform (SHT or FFT)
inverse_transform : nn.Module
Inverse transform (ISHT or IFFT)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
gain : float, optional
Gain factor for weight initialization, by default 2.0
operator_type : str, optional
Type of spectral operator ("driscoll-healy", "diagonal", "block-diagonal"), by default "driscoll-healy"
lr_scale_exponent : int, optional
Learning rate scaling exponent, by default 0
bias : bool, optional
Whether to use bias, by default False
"""
def __init__(self, forward_transform, inverse_transform, in_channels, out_channels, gain=2.0, operator_type="driscoll-healy", lr_scale_exponent=0, bias=False):
super().__init__()
......@@ -307,13 +458,25 @@ class SpectralConvS2(nn.Module):
return x, residual
class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta):
"""
Abstract base class for position embeddings.
This class defines the interface for position embedding modules
that add positional information to input tensors.
Parameters
----------
img_shape : tuple, optional
Image shape (height, width), by default (480, 960)
grid : str, optional
Grid type, by default "equiangular"
num_chans : int, optional
Number of channels, by default 1
"""
def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
super().__init__()
self.img_shape = img_shape
......@@ -323,17 +486,28 @@ class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta):
return x + self.position_embeddings
class SequencePositionEmbedding(PositionEmbedding):
"""
Standard sequence-based position embedding.
This module implements sinusoidal position embeddings similar to those
used in the original Transformer paper, adapted for 2D spatial data.
Parameters
----------
img_shape : tuple, optional
Image shape (height, width), by default (480, 960)
grid : str, optional
Grid type, by default "equiangular"
num_chans : int, optional
Number of channels, by default 1
"""
def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)
with torch.no_grad():
# alternating custom position embeddings
pos = torch.arange(self.img_shape[0] * self.img_shape[1]).reshape(1, 1, *self.img_shape).repeat(1, self.num_chans, 1, 1)
k = torch.arange(self.num_chans).reshape(1, self.num_chans, 1, 1)
......@@ -344,13 +518,26 @@ class SequencePositionEmbedding(PositionEmbedding):
# register tensor
self.register_buffer("position_embeddings", pos_embed.float())
class SpectralPositionEmbedding(PositionEmbedding):
"""
Spectral position embeddings for spherical transformers.
This module creates position embeddings in the spectral domain using
spherical harmonic functions, which are particularly suitable for
spherical data processing.
Parameters
----------
img_shape : tuple, optional
Image shape (height, width), by default (480, 960)
grid : str, optional
Grid type, by default "equiangular"
num_chans : int, optional
Number of channels, by default 1
"""
def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)
# compute maximum required frequency and prepare isht
......@@ -382,11 +569,24 @@ class SpectralPositionEmbedding(PositionEmbedding):
class LearnablePositionEmbedding(PositionEmbedding):
"""
Learnable position embeddings for spherical transformers.
This module provides learnable position embeddings that can be either
latitude-only or full latitude-longitude embeddings.
Parameters
----------
img_shape : tuple, optional
Image shape (height, width), by default (480, 960)
grid : str, optional
Grid type, by default "equiangular"
num_chans : int, optional
Number of channels, by default 1
embed_type : str, optional
Embedding type ("lat" or "latlon"), by default "lat"
"""
def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1, embed_type="lat"):
super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)
if embed_type == "latlon":
......@@ -418,4 +618,4 @@ class LearnablePositionEmbedding(PositionEmbedding):
# pos_embed = torch.where(k % 2 == 0, torch.sin(k * (lons + lats)), torch.cos(k * (lons - lats)))
# # register tensor
# self.register_buffer("position_embeddings", pos_embed.float())
......@@ -50,6 +50,35 @@ def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
return (kernel_shape[0] + 1) * theta_cutoff_factor[basis_type] * math.pi / float(nlat - 1)
class DiscreteContinuousEncoder(nn.Module):
"""
Discrete-continuous encoder for spherical neural operators.
This module performs downsampling using discrete-continuous convolutions on the sphere,
reducing the spatial resolution while maintaining the spectral properties of the data.
Parameters
----------
in_shape : tuple, optional
Input shape (nlat, nlon), by default (721, 1440)
out_shape : tuple, optional
Output shape (nlat, nlon), by default (480, 960)
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
inp_chans : int, optional
Number of input channels, by default 2
out_chans : int, optional
Number of output channels, by default 2
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
groups : int, optional
Number of groups for grouped convolution, by default 1
bias : bool, optional
Whether to use bias, by default False
"""
def __init__(
self,
in_shape=(721, 1440),
......@@ -81,6 +110,7 @@ class DiscreteContinuousEncoder(nn.Module):
)
def forward(self, x):
dtype = x.dtype
with amp.autocast(device_type="cuda", enabled=False):
......@@ -92,6 +122,37 @@ class DiscreteContinuousEncoder(nn.Module):
class DiscreteContinuousDecoder(nn.Module):
"""
Discrete-continuous decoder for spherical neural operators.
This module performs upsampling using either spherical harmonic transforms or resampling,
followed by discrete-continuous convolutions to restore spatial resolution.
Parameters
----------
in_shape : tuple, optional
Input shape (nlat, nlon), by default (480, 960)
out_shape : tuple, optional
Output shape (nlat, nlon), by default (721, 1440)
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
inp_chans : int, optional
Number of input channels, by default 2
out_chans : int, optional
Number of output channels, by default 2
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
groups : int, optional
Number of groups for grouped convolution, by default 1
bias : bool, optional
Whether to use bias, by default False
upsample_sht : bool, optional
Whether to use SHT for upsampling, by default False
"""
def __init__(
self,
in_shape=(480, 960),
......@@ -132,6 +193,7 @@ class DiscreteContinuousDecoder(nn.Module):
)
def forward(self, x):
dtype = x.dtype
with amp.autocast(device_type="cuda", enabled=False):
......@@ -146,6 +208,46 @@ class DiscreteContinuousDecoder(nn.Module):
class SphericalNeuralOperatorBlock(nn.Module):
"""
Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.
Parameters
----------
forward_transform : torch.nn.Module
Forward transform to use for the block
inverse_transform : torch.nn.Module
Inverse transform to use for the block
input_dim : int
Input dimension
output_dim : int
Output dimension
conv_type : str, optional
Type of convolution to use, by default "local"
mlp_ratio : float, optional
MLP expansion ratio, by default 2.0
drop_rate : float, optional
Dropout rate, by default 0.0
drop_path : float, optional
Drop path rate, by default 0.0
act_layer : torch.nn.Module, optional
Activation function to use, by default nn.GELU
norm_layer : str, optional
Type of normalization to use, by default "none"
inner_skip : str, optional
Type of inner skip connection to use, by default "none"
outer_skip : str, optional
Type of outer skip connection to use, by default "identity"
use_mlp : bool, optional
Whether to use MLP layers, by default True
disco_kernel_shape : tuple, optional
Kernel shape for discrete-continuous convolution, by default (3, 3)
disco_basis_type : str, optional
Filter basis type for discrete-continuous convolution, by default "morlet"
bias : bool, optional
Whether to use bias, by default False
Returns
-------
torch.Tensor
Output tensor
"""
def __init__(
......@@ -279,53 +381,58 @@ class LocalSphericalNeuralOperator(nn.Module):
operators to accureately model both types of solution operators [1]. The architecture is based on the Spherical
Fourier Neural Operator [2] and improves upon it with local integral operators in both the Neural Operator blocks,
as well as in the encoder and decoders.
Parameters
----------
img_size : tuple, optional
    Input image size (nlat, nlon), by default (128, 256)
grid : str, optional
Grid type for input/output, by default "equiangular"
grid_internal : str, optional
Grid type for internal processing, by default "legendre-gauss"
scale_factor : int, optional
Scale factor to use, by default 3
Scale factor for resolution changes, by default 3
in_chans : int, optional
Number of input channels, by default 3
out_chans : int, optional
Number of output channels, by default 3
embed_dim : int, optional
Dimension of the embeddings, by default 256
Embedding dimension, by default 256
num_layers : int, optional
Number of layers in the network, by default 4
Number of layers, by default 4
activation_function : str, optional
Activation function to use, by default "gelu"
encoder_kernel_shape : int, optional
size of the encoder kernel
filter_basis_type: Optional[str]: str, optional
filter basis type
use_mlp : int, optional
Whether to use MLPs in the SFNO blocks, by default True
mlp_ratio : int, optional
Ratio of MLP to use, by default 2.0
Activation function name, by default "gelu"
kernel_shape : tuple, optional
Kernel shape for convolutions, by default (3, 3)
encoder_kernel_shape : tuple, optional
Kernel shape for encoder, by default (3, 3)
filter_basis_type : str, optional
Filter basis type, by default "morlet"
use_mlp : bool, optional
Whether to use MLP layers, by default True
mlp_ratio : float, optional
MLP expansion ratio, by default 2.0
drop_rate : float, optional
Dropout rate, by default 0.0
drop_path_rate : float, optional
Dropout path rate, by default 0.0
Drop path rate, by default 0.0
normalization_layer : str, optional
Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"
sfno_block_frequency : int, optional
Hopw often a (global) SFNO block is used, by default 2
Frequency of SFNO blocks, by default 2
hard_thresholding_fraction : float, optional
Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0
big_skip : bool, optional
Whether to add a single large skip connection, by default True
pos_embed : bool, optional
Whether to use positional embedding, by default True
Hard thresholding fraction, by default 1.0
residual_prediction : bool, optional
Whether to use residual prediction, by default False
pos_embed : str, optional
Position embedding type, by default "none"
upsample_sht : bool, optional
Use SHT upsampling if true, else linear interpolation
bias : bool, optional
Whether to use a bias, by default False
Example
-------
>>> model = LocalSphericalNeuralOperator(
... img_shape=(128, 256),
... scale_factor=4,
......@@ -338,7 +445,7 @@ class LocalSphericalNeuralOperator(nn.Module):
torch.Size([1, 2, 128, 256])
References
----------
.. [1] Liu-Schiaffini M., Berner J., Bonev B., Kurth T., Azizzadenesheli K., Anandkumar A.;
"Neural Operators with Localized Integral and Differential Kernels" (2024).
ICML 2024, https://arxiv.org/pdf/2402.16845.
......@@ -497,6 +604,7 @@ class LocalSphericalNeuralOperator(nn.Module):
return x
def forward(self, x):
if self.residual_prediction:
residual = x
......
......@@ -54,6 +54,34 @@ def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
class OverlapPatchMerging(nn.Module):
"""
Overlap patch merging module for spherical segformer.
This module performs patch merging with overlapping patches using discrete-continuous
convolutions on the sphere, followed by layer normalization.
Parameters
----------
in_shape : tuple, optional
Input shape (nlat, nlon), by default (721, 1440)
out_shape : tuple, optional
Output shape (nlat, nlon), by default (481, 960)
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
in_channels : int, optional
Number of input channels, by default 3
out_channels : int, optional
Number of output channels, by default 64
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
bias : bool, optional
Whether to use bias, by default False
"""
def __init__(
self,
in_shape=(721, 1440),
......@@ -89,11 +117,13 @@ class OverlapPatchMerging(nn.Module):
self.apply(self._init_weights)
def _init_weights(self, m):
    # Initialize LayerNorm submodules to the identity transform
    # (weight = 1, bias = 0); other module types keep their own defaults.
    # Intended to be applied via self.apply(self._init_weights).
    if isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)
def forward(self, x):
dtype = x.dtype
with amp.autocast(device_type="cuda", enabled=False):
......@@ -109,6 +139,38 @@ class OverlapPatchMerging(nn.Module):
class MixFFN(nn.Module):
"""
Mix FFN module for spherical segformer.
This module implements a feed-forward network that combines MLP operations
with discrete-continuous convolutions on the sphere.
Parameters
----------
shape : tuple
Shape (nlat, nlon) of the input
inout_channels : int
Number of input/output channels
hidden_channels : int
Number of hidden channels in MLP
mlp_bias : bool, optional
Whether to use bias in MLP, by default True
grid : str, optional
Grid type, by default "equiangular"
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
conv_bias : bool, optional
Whether to use bias in convolution, by default False
activation : nn.Module, optional
Activation function, by default nn.GELU
use_mlp : bool, optional
Whether to use MLP instead of linear layers, by default False
drop_path : float, optional
Drop path rate, by default 0.0
"""
def __init__(
self,
shape,
......@@ -161,6 +223,7 @@ class MixFFN(nn.Module):
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
......@@ -170,7 +233,6 @@ class MixFFN(nn.Module):
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
# norm
......@@ -194,6 +256,35 @@ class MixFFN(nn.Module):
class AttentionWrapper(nn.Module):
"""
Attention wrapper for spherical segformer.
This module wraps attention mechanisms (neighborhood or global) with optional
normalization and drop path regularization.
Parameters
----------
channels : int
Number of channels
shape : tuple
Shape (nlat, nlon) of the input
grid : str
Grid type
heads : int
Number of attention heads
pre_norm : bool, optional
Whether to apply normalization before attention, by default False
attention_drop_rate : float, optional
Dropout rate for attention, by default 0.0
drop_path : float, optional
Drop path rate, by default 0.0
attention_mode : str, optional
Attention mode ("neighborhood" or "global"), by default "neighborhood"
theta_cutoff : float, optional
Cutoff radius for neighborhood attention, by default None
bias : bool, optional
Whether to use bias, by default True
"""
def __init__(
self,
channels,
......@@ -271,6 +362,49 @@ class AttentionWrapper(nn.Module):
class TransformerBlock(nn.Module):
"""
Transformer block for spherical segformer.
This block combines patch merging, attention, and Mix FFN operations
in a hierarchical structure for processing spherical data.
Parameters
----------
in_shape : tuple
Input shape (nlat, nlon)
out_shape : tuple
Output shape (nlat, nlon)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
mlp_hidden_channels : int
Number of hidden channels in MLP
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
nrep : int, optional
Number of repetitions, by default 1
heads : int, optional
Number of attention heads, by default 1
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
activation : nn.Module, optional
Activation function, by default nn.GELU
att_drop_rate : float, optional
Dropout rate for attention, by default 0.0
drop_path_rates : float, optional
Drop path rates, by default 0.0
attention_mode : str, optional
Attention mode ("neighborhood" or "global"), by default "neighborhood"
theta_cutoff : float, optional
Cutoff radius for neighborhood attention, by default None
bias : bool, optional
Whether to use bias, by default True
"""
def __init__(
self,
in_shape,
......@@ -374,6 +508,43 @@ class TransformerBlock(nn.Module):
class Upsampling(nn.Module):
"""
Upsampling module for spherical segformer.
This module performs upsampling using either discrete-continuous transposed convolutions
or bilinear resampling on spherical data.
Parameters
----------
in_shape : tuple
Input shape (nlat, nlon)
out_shape : tuple
Output shape (nlat, nlon)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
hidden_channels : int
Number of hidden channels in MLP
mlp_bias : bool, optional
Whether to use bias in MLP, by default True
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
conv_bias : bool, optional
Whether to use bias in convolution, by default False
activation : nn.Module, optional
Activation function, by default nn.GELU
use_mlp : bool, optional
Whether to use MLP instead of linear layers, by default False
upsampling_method : str, optional
Upsampling method ("conv" or "bilinear"), by default "conv"
"""
def __init__(
self,
in_shape,
......@@ -429,6 +600,7 @@ class Upsampling(nn.Module):
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
    # channel mixing first (self.mlp), then spatial upsampling to the
    # output grid (self.upsample)
    x = self.upsample(self.mlp(x))
    return x
......@@ -606,6 +778,7 @@ class SphericalSegformer(nn.Module):
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
......
......@@ -52,6 +52,36 @@ def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
return (kernel_shape[0] + 1) * theta_cutoff_factor[basis_type] * math.pi / float(nlat - 1)
class DiscreteContinuousEncoder(nn.Module):
"""
Discrete-continuous encoder for spherical transformers.
This module performs downsampling using discrete-continuous convolutions on the sphere,
reducing the spatial resolution while maintaining the spectral properties of the data.
Parameters
----------
in_shape : tuple, optional
Input shape (nlat, nlon), by default (721, 1440)
out_shape : tuple, optional
Output shape (nlat, nlon), by default (480, 960)
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
in_chans : int, optional
Number of input channels, by default 2
out_chans : int, optional
Number of output channels, by default 2
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
groups : int, optional
Number of groups for grouped convolution, by default 1
bias : bool, optional
Whether to use bias, by default False
"""
def __init__(
self,
in_shape=(721, 1440),
......@@ -83,6 +113,7 @@ class DiscreteContinuousEncoder(nn.Module):
)
def forward(self, x):
dtype = x.dtype
with amp.autocast(device_type="cuda", enabled=False):
......@@ -94,6 +125,38 @@ class DiscreteContinuousEncoder(nn.Module):
class DiscreteContinuousDecoder(nn.Module):
"""
Discrete-continuous decoder for spherical transformers.
This module performs upsampling using either spherical harmonic transforms or resampling,
followed by discrete-continuous convolutions to restore spatial resolution.
Parameters
-----------
in_shape : tuple, optional
Input shape (nlat, nlon), by default (480, 960)
out_shape : tuple, optional
Output shape (nlat, nlon), by default (721, 1440)
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
in_chans : int, optional
Number of input channels, by default 2
out_chans : int, optional
Number of output channels, by default 2
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
groups : int, optional
Number of groups for grouped convolution, by default 1
bias : bool, optional
Whether to use bias, by default False
upsample_sht : bool, optional
Whether to use SHT for upsampling, by default False
"""
def __init__(
self,
in_shape=(480, 960),
......@@ -134,6 +197,7 @@ class DiscreteContinuousDecoder(nn.Module):
)
def forward(self, x):
dtype = x.dtype
with amp.autocast(device_type="cuda", enabled=False):
......@@ -147,7 +211,45 @@ class DiscreteContinuousDecoder(nn.Module):
class SphericalAttentionBlock(nn.Module):
"""
Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.
Spherical attention block for transformers on the sphere.
This module implements a single attention block that can use either global attention
or neighborhood attention on spherical data, followed by an optional MLP.
Parameters
-----------
in_shape : tuple, optional
Input shape (nlat, nlon), by default (480, 960)
out_shape : tuple, optional
Output shape (nlat, nlon), by default (480, 960)
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
in_chans : int, optional
Number of input channels, by default 2
out_chans : int, optional
Number of output channels, by default 2
num_heads : int, optional
Number of attention heads, by default 1
mlp_ratio : float, optional
Ratio of MLP hidden dimension to output dimension, by default 2.0
drop_rate : float, optional
Dropout rate, by default 0.0
drop_path : float, optional
Drop path rate, by default 0.0
act_layer : nn.Module, optional
Activation layer, by default nn.GELU
norm_layer : str, optional
Normalization layer type, by default "none"
use_mlp : bool, optional
Whether to use MLP after attention, by default True
bias : bool, optional
Whether to use bias, by default False
attention_mode : str, optional
Attention mode ("neighborhood" or "global"), by default "neighborhood"
theta_cutoff : float, optional
Cutoff radius for neighborhood attention, by default None
"""
def __init__(
......@@ -467,6 +569,7 @@ class SphericalTransformer(nn.Module):
return x
def forward(self, x):
if self.residual_prediction:
residual = x
......
......@@ -54,6 +54,46 @@ def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
class DownsamplingBlock(nn.Module):
"""
Downsampling block for spherical U-Net architecture.
This block performs convolution operations followed by downsampling on spherical data,
using discrete-continuous convolutions to maintain spectral properties.
Parameters
-----------
in_shape : tuple
Input shape (nlat, nlon)
out_shape : tuple
Output shape (nlat, nlon)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
nrep : int, optional
Number of convolution repetitions, by default 1
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
activation : nn.Module, optional
Activation function, by default nn.ReLU
transform_skip : bool, optional
Whether to transform skip connection, by default False
drop_conv_rate : float, optional
Dropout rate for convolutions, by default 0.0
drop_path_rate : float, optional
Drop path rate, by default 0.0
drop_dense_rate : float, optional
Dropout rate for dense layers, by default 0.0
downsampling_mode : str, optional
Downsampling mode ("bilinear", "conv"), by default "bilinear"
"""
def __init__(
self,
in_shape,
......@@ -154,12 +194,14 @@ class DownsamplingBlock(nn.Module):
self.apply(self._init_weights)
def _init_weights(self, m):
    """Weight initializer applied per-module (via ``self.apply``).

    Conv2d weights get a truncated-normal init with std 0.02; Conv2d
    biases (when present) are zero-initialized. Other module types are
    left untouched.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# skip connection
residual = x
if hasattr(self, "transform_skip"):
......@@ -178,6 +220,46 @@ class DownsamplingBlock(nn.Module):
class UpsamplingBlock(nn.Module):
"""
Upsampling block for spherical U-Net architecture.
This block performs upsampling followed by convolution operations on spherical data,
using discrete-continuous convolutions to maintain spectral properties.
Parameters
-----------
in_shape : tuple
Input shape (nlat, nlon)
out_shape : tuple
Output shape (nlat, nlon)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
nrep : int, optional
Number of convolution repetitions, by default 1
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
activation : nn.Module, optional
Activation function, by default nn.ReLU
transform_skip : bool, optional
Whether to transform skip connection, by default False
drop_conv_rate : float, optional
Dropout rate for convolutions, by default 0.0
drop_path_rate : float, optional
Drop path rate, by default 0.0
drop_dense_rate : float, optional
Dropout rate for dense layers, by default 0.0
upsampling_mode : str, optional
Upsampling mode ("bilinear", "conv"), by default "bilinear"
"""
def __init__(
self,
in_shape,
......@@ -496,6 +578,7 @@ class SphericalUNet(nn.Module):
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
......
......@@ -43,6 +43,40 @@ from functools import partial
class SphericalFourierNeuralOperatorBlock(nn.Module):
"""
Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.
Parameters
----------
forward_transform : torch.nn.Module
Forward transform to use for the block
inverse_transform : torch.nn.Module
Inverse transform to use for the block
input_dim : int
Input dimension
output_dim : int
Output dimension
mlp_ratio : float, optional
MLP expansion ratio, by default 2.0
drop_rate : float, optional
Dropout rate, by default 0.0
drop_path : float, optional
Drop path rate, by default 0.0
act_layer : torch.nn.Module, optional
Activation function to use, by default nn.GELU
norm_layer : str, optional
Type of normalization to use, by default "none"
inner_skip : str, optional
Type of inner skip connection to use, by default "none"
outer_skip : str, optional
Type of outer skip connection to use, by default "identity"
use_mlp : bool, optional
Whether to use MLP layers, by default True
bias : bool, optional
Whether to use bias, by default False
Returns
-------
torch.Tensor
Output tensor
"""
def __init__(
......@@ -118,7 +152,6 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
else:
raise ValueError(f"Unknown skip connection type {outer_skip}")
def forward(self, x):
x, residual = self.global_conv(x)
......@@ -147,8 +180,12 @@ class SphericalFourierNeuralOperator(nn.Module):
Parameters
----------
img_shape : tuple, optional
img_size : tuple, optional
Shape of the input channels, by default (128, 256)
grid : str, optional
Input grid type, by default "equiangular"
grid_internal : str, optional
Internal grid type for computations, by default "legendre-gauss"
scale_factor : int, optional
Scale factor to use, by default 3
in_chans : int, optional
......@@ -172,20 +209,20 @@ class SphericalFourierNeuralOperator(nn.Module):
drop_path_rate : float, optional
Dropout path rate, by default 0.0
normalization_layer : str, optional
Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"
Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "none"
hard_thresholding_fraction : float, optional
Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0
residual_prediction : bool, optional
Whether to add a single large skip connection, by default True
pos_embed : bool, optional
Whether to use positional embedding, by default True
Whether to add a single large skip connection, by default False
pos_embed : str, optional
Type of positional embedding to use, by default "none"
bias : bool, optional
Whether to use a bias, by default False
Example:
--------
----------
>>> model = SphericalFourierNeuralOperator(
... img_shape=(128, 256),
... img_size=(128, 256),
... scale_factor=4,
... in_chans=2,
... out_chans=2,
......@@ -196,7 +233,7 @@ class SphericalFourierNeuralOperator(nn.Module):
torch.Size([1, 2, 128, 256])
References
-----------
----------
.. [1] Bonev B., Kurth T., Hundt C., Pathak, J., Baust M., Kashinath K., Anandkumar A.;
"Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere" (2023).
ICML 2023, https://arxiv.org/abs/2306.03838.
......
......@@ -37,7 +37,38 @@ from .shallow_water_equations import ShallowWaterSolver
class PdeDataset(torch.utils.data.Dataset):
"""Custom Dataset class for PDE training data"""
"""Custom Dataset class for PDE training data
Parameters
----------
dt : float
Time step
nsteps : int
Number of solver steps
dims : tuple, optional
Number of latitude and longitude points, by default (384, 768)
grid : str, optional
Grid type, by default "equiangular"
pde : str, optional
PDE type, by default "shallow water equations"
initial_condition : str, optional
Initial condition type, by default "random"
num_examples : int, optional
Number of examples, by default 32
device : torch.device, optional
Device to use, by default torch.device("cpu")
normalize : bool, optional
Whether to normalize the input and target, by default True
stream : torch.cuda.Stream, optional
CUDA stream to use, by default None
Returns
-------
inp : torch.Tensor
Input tensor
tar : torch.Tensor
Target tensor
"""
def __init__(
self,
......
......@@ -42,7 +42,27 @@ import numpy as np
class SphereSolver(nn.Module):
"""
Solver class on the sphere. Can solve the following PDEs:
- Allen-Cahn eq
- Allen-Cahn equation
- Ginzburg-Landau equation
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
dt : float
Time step size
lmax : int, optional
Maximum l mode for spherical harmonics, by default None
mmax : int, optional
Maximum m mode for spherical harmonics, by default None
grid : str, optional
Grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
radius : float, optional
Radius of the sphere, by default 1.0
coeff : float, optional
Coefficient for the PDE, by default 0.001
"""
def __init__(self, nlat, nlon, dt, lmax=None, mmax=None, grid="equiangular", radius=1.0, coeff=0.001):
......@@ -97,17 +117,15 @@ class SphereSolver(nn.Module):
self.register_buffer('invlap', invlap)
def grid2spec(self, u):
    """Compute spherical harmonic (spectral) coefficients from spatial grid data."""
    coeffs = self.sht(u)
    return coeffs
def spec2grid(self, uspec):
    """Convert spectral coefficients to spatial data.

    Parameters
    ----------
    uspec : torch.Tensor
        Spherical harmonic coefficients.

    Returns
    -------
    torch.Tensor
        Spatial grid data obtained via the inverse SHT.
    """
    # NOTE: the diff artifact that left the stale docstring as a dead
    # bare-string statement has been removed; behavior is unchanged.
    return self.isht(uspec)
def dudtspec(self, uspec, pde='allen-cahn'):
"""Compute the time derivative of spectral coefficients for different PDEs."""
if pde == 'allen-cahn':
ugrid = self.spec2grid(uspec)
u3spec = self.grid2spec(ugrid**3)
......@@ -117,20 +135,48 @@ class SphereSolver(nn.Module):
u3spec = self.grid2spec(ugrid**3)
dudtspec = uspec + (1. + 2.j)*self.coeff*self.lap*uspec - (1. + 2.j)*u3spec
else:
NotImplementedError
raise NotImplementedError(f"PDE type {pde} not implemented")
return dudtspec
def randspec(self):
    """Generate random spectral data on the sphere.

    Returns
    -------
    torch.Tensor
        Standard-normal noise with the same shape/dtype as ``self.lap``,
        scaled by 1 / (4*pi). Removed the stale duplicate docstring that
        was left behind as a dead string statement.
    """
    rspec = torch.randn_like(self.lap) / 4 / torch.pi
    return rspec
def plot_griddata(self, data, fig, cmap='twilight_shifted', vmax=None, vmin=None, projection='3d', title=None, antialiased=False):
"""
plotting routine for data on the grid. Requires cartopy for 3d plots.
Plot data on the sphere grid. Requires cartopy for 3d plots.
Parameters
-----------
data : torch.Tensor
Data to plot
fig : matplotlib.figure.Figure
Figure to plot on
cmap : str, optional
Colormap name, by default 'twilight_shifted'
vmax : float, optional
Maximum value for color scaling, by default None
vmin : float, optional
Minimum value for color scaling, by default None
projection : str, optional
Projection type ("mollweide", "3d"), by default "3d"
title : str, optional
Plot title, by default None
antialiased : bool, optional
Whether to use antialiasing, by default False
Returns
-------
matplotlib.collections.QuadMesh
The plotted image object
Raises
------
NotImplementedError
If projection type is not supported
"""
import matplotlib.pyplot as plt
......@@ -172,9 +218,10 @@ class SphereSolver(nn.Module):
plt.title(title, y=1.05)
else:
raise NotImplementedError
raise NotImplementedError(f"Projection {projection} not implemented")
return im
def plot_specdata(self, data, fig, **kwargs):
    """Plot spectral data by first transforming it back to the spatial grid."""
    grid_data = self.isht(data)
    return self.plot_griddata(grid_data, fig, **kwargs)
......@@ -41,7 +41,35 @@ import numpy as np
class ShallowWaterSolver(nn.Module):
"""
SWE solver class. Interface inspired by pyspharm and SHTns
Shallow Water Equations (SWE) solver class for spherical geometry.
Interface inspired by pyspharm and SHTns. Solves the shallow water equations
on a rotating sphere using spectral methods.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
dt : float
Time step size
lmax : int, optional
Maximum l mode for spherical harmonics, by default None
mmax : int, optional
Maximum m mode for spherical harmonics, by default None
grid : str, optional
Grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
radius : float, optional
Radius of the sphere in meters, by default 6.37122E6 (Earth radius)
omega : float, optional
Angular velocity of rotation in rad/s, by default 7.292E-5 (Earth)
gravity : float, optional
Gravitational acceleration in m/s², by default 9.80616
havg : float, optional
Average height in meters, by default 10.e3
hamp : float, optional
Height amplitude in meters, by default 120.
"""
def __init__(self, nlat, nlon, dt, lmax=None, mmax=None, grid="equiangular", radius=6.37122E6, \
......@@ -114,58 +142,43 @@ class ShallowWaterSolver(nn.Module):
self.register_buffer('quad_weights', quad_weights)
def grid2spec(self, ugrid):
    """Convert spatial data to spectral coefficients.

    Parameters
    ----------
    ugrid : torch.Tensor
        Spatial grid data.

    Returns
    -------
    torch.Tensor
        Spherical harmonic coefficients computed via the SHT.
    """
    # Removed the stale pre-refactor docstring that remained as a dead
    # bare-string statement after the docs update.
    return self.sht(ugrid)
def spec2grid(self, uspec):
    """Convert spectral coefficients to spatial data.

    Parameters
    ----------
    uspec : torch.Tensor
        Spherical harmonic coefficients.

    Returns
    -------
    torch.Tensor
        Spatial grid data computed via the inverse SHT.
    """
    # Removed the stale pre-refactor docstring that remained as a dead
    # bare-string statement after the docs update.
    return self.isht(uspec)
def vrtdivspec(self, ugrid):
    """Compute spectral vorticity and divergence from a spatial velocity field.

    Parameters
    ----------
    ugrid : torch.Tensor
        Spatial velocity field.

    Returns
    -------
    torch.Tensor
        Spectral vorticity/divergence, i.e. the vector SHT of ``ugrid``
        scaled by ``self.lap * self.radius``.
    """
    # The stale docstring ("spatial data from spectral coefficients") was a
    # wrong copy-paste and was left as a dead string statement; removed.
    vrtdivspec = self.lap * self.radius * self.vsht(ugrid)
    return vrtdivspec
def getuv(self, vrtdivspec):
    """Compute the wind vector from spectral coefficients of vorticity and divergence.

    Parameters
    ----------
    vrtdivspec : torch.Tensor
        Spectral vorticity and divergence coefficients.

    Returns
    -------
    torch.Tensor
        Spatial wind vector obtained via the inverse vector SHT.
    """
    # Removed the stale duplicate docstring left behind as a dead string.
    return self.ivsht(self.invlap * vrtdivspec / self.radius)
def gethuv(self, uspec):
    """Compute height and wind vector from spectral coefficients.

    The first channel of ``uspec`` holds the height field; the remaining
    channels hold vorticity and divergence, from which the wind is derived.

    Parameters
    ----------
    uspec : torch.Tensor
        Spectral state with height in channel 0 and vorticity/divergence after.

    Returns
    -------
    torch.Tensor
        Concatenation of the spatial height field and wind vector along dim -3.
    """
    # The stale docstring ("compute wind vector from spectral coeffs of
    # vorticity and divergence") was a wrong copy-paste left as dead code.
    hgrid = self.spec2grid(uspec[:1])
    uvgrid = self.getuv(uspec[1:])
    return torch.cat((hgrid, uvgrid), dim=-3)
def potential_vorticity(self, uspec):
    """Compute potential vorticity from spectral coefficients.

    Parameters
    ----------
    uspec : torch.Tensor
        Spectral state; channel 0 of the spatial field is the height,
        channel 1 the vorticity.

    Returns
    -------
    torch.Tensor
        Potential vorticity ``0.5 * havg * g / omega * (vrt + f) / h``.
    """
    # Removed the stale duplicate docstring left behind as a dead string.
    ugrid = self.spec2grid(uspec)
    pvrt = (0.5 * self.havg * self.gravity / self.omega) * (ugrid[1] + self.coriolis) / ugrid[0]
    return pvrt
def dimensionless(self, uspec):
    """Non-dimensionalize the state variables for dimensionless analysis.

    NOTE: mutates ``uspec`` in place and returns the same tensor; callers
    that need the dimensional values should pass a copy.

    Parameters
    ----------
    uspec : torch.Tensor
        State tensor with the height field in channel 0 and
        vorticity/divergence in the remaining channels.

    Returns
    -------
    torch.Tensor
        The input tensor, rescaled in place.
    """
    # Removed the stale duplicate docstring left behind as a dead string.
    uspec[0] = (uspec[0] - self.havg * self.gravity) / self.hamp / self.gravity
    # vorticity is measured in 1/s so we normalize using sqrt(g h) / r
    uspec[1:] = uspec[1:] * self.radius / torch.sqrt(self.gravity * self.havg)
    return uspec
def dudtspec(self, uspec):
"""
Compute time derivatives from solution represented in spectral coefficients
"""
"""Compute time derivatives from solution represented in spectral coefficients."""
dudtspec = torch.zeros_like(uspec)
# compute the derivatives - this should be incorporated into the solver:
......@@ -190,12 +203,7 @@ class ShallowWaterSolver(nn.Module):
return dudtspec
def galewsky_initial_condition(self):
"""
Initializes non-linear barotropically unstable shallow water test case of Galewsky et al. (2004, Tellus, 56A, 429-440).
[1] Galewsky; An initial-value problem for testing numerical models of the global shallow-water equations;
DOI: 10.1111/j.1600-0870.2004.00071.x; http://www-vortex.mcs.st-and.ac.uk/~rks/reprints/galewsky_etal_tellus_2004.pdf
"""
"""Initialize non-linear barotropically unstable shallow water test case."""
device = self.lap.device
umax = 80.
......@@ -233,9 +241,7 @@ class ShallowWaterSolver(nn.Module):
return torch.tril(uspec)
def random_initial_condition(self, mach=0.1) -> torch.Tensor:
"""
random initial condition on the sphere
"""
"""Generate random initial condition on the sphere."""
device = self.lap.device
ctype = torch.complex128 if self.lap.dtype == torch.float64 else torch.complex64
......@@ -281,10 +287,7 @@ class ShallowWaterSolver(nn.Module):
return torch.tril(uspec)
def timestep(self, uspec: torch.Tensor, nsteps: int) -> torch.Tensor:
"""
Integrate the solution using Adams-Bashforth / forward Euler for nsteps steps.
"""
"""Integrate the solution using Adams-Bashforth / forward Euler for nsteps steps."""
dudtspec = torch.zeros(3, 3, self.lmax, self.mmax, dtype=uspec.dtype, device=uspec.device)
# pointers to indicate the most current result
......@@ -316,6 +319,7 @@ class ShallowWaterSolver(nn.Module):
return uspec
def integrate_grid(self, ugrid, dimensionless=False, polar_opt=0):
"""Integrate the solution on the grid."""
dlon = 2 * torch.pi / self.nlon
radius = 1 if dimensionless else self.radius
if polar_opt > 0:
......@@ -326,9 +330,7 @@ class ShallowWaterSolver(nn.Module):
def plot_griddata(self, data, fig, cmap='twilight_shifted', vmax=None, vmin=None, projection='3d', title=None, antialiased=False):
"""
plotting routine for data on the grid. Requires cartopy for 3d plots.
"""
"""Plotting routine for data on the grid. Requires cartopy for 3d plots."""
import matplotlib.pyplot as plt
lons = self.lons.squeeze() - torch.pi
......
......@@ -58,21 +58,35 @@ class Stanford2D3DSDownloader:
"""
Convenience class for downloading the 2d3ds dataset [1].
Parameters
----------
base_url : str, optional
Base URL for downloading the dataset, by default DEFAULT_BASE_URL
local_dir : str, optional
Local directory to store downloaded files, by default "data"
Returns
-------
data_folders : list
List of extracted directory names
class_labels : list
List of semantic class labels
References
-----------
----------
.. [1] Armeni, I., Sax, S., Zamir, A. R., Savarese, S.;
"Joint 2D-3D-Semantic Data for Indoor Scene Understanding" (2017).
https://arxiv.org/abs/1702.01105.
"""
def __init__(self, base_url: str = DEFAULT_BASE_URL, local_dir: str = "data"):
    """Store the download configuration and ensure the local target directory exists.

    Parameters
    ----------
    base_url : str, optional
        Base URL for downloading the dataset, by default DEFAULT_BASE_URL
    local_dir : str, optional
        Local directory to store downloaded files (created if missing), by default "data"
    """
    self.base_url = base_url
    self.local_dir = local_dir
    os.makedirs(self.local_dir, exist_ok=True)
def _download_file(self, filename):
import requests
from tqdm import tqdm
......@@ -106,6 +120,7 @@ class Stanford2D3DSDownloader:
return local_path
def _extract_tar(self, tar_path):
import tarfile
with tarfile.open(tar_path) as tar:
......@@ -116,7 +131,20 @@ class Stanford2D3DSDownloader:
return extracted_dir
def download_dataset(self, file_extracted_directory_pairs=DEFAULT_TAR_FILE_PAIRS):
"""
Download and extract the complete dataset.
Parameters
-----------
file_extracted_directory_pairs : list, optional
List of (filename, extracted_folder_name) pairs, by default DEFAULT_TAR_FILE_PAIRS
Returns
-------
tuple
(data_folders, class_labels) where data_folders is a list of extracted directory names
and class_labels is the semantic label mapping
"""
import requests
data_folders = []
......@@ -133,6 +161,7 @@ class Stanford2D3DSDownloader:
return data_folders, class_labels
def _rgb_to_id(self, img, class_labels_map, class_labels_indices):
# Convert to int32 first to avoid overflow
r = img[..., 0].astype(np.int32)
g = img[..., 1].astype(np.int32)
......@@ -167,7 +196,35 @@ class Stanford2D3DSDownloader:
downsampling_factor: int = 16,
remove_alpha_channel: bool = True,
):
"""
Convert the downloaded dataset to HDF5 format for efficient loading.
Parameters
-----------
data_folders : list
List of extracted data folder names
class_labels : list
List of semantic class labels
rgb_path : str, optional
Relative path to RGB images within each data folder, by default "pano/rgb"
semantic_path : str, optional
Relative path to semantic labels within each data folder, by default "pano/semantic"
depth_path : str, optional
Relative path to depth images within each data folder, by default "pano/depth"
output_filename : str, optional
Suffix for semantic label files, by default "semantic"
dataset_file : str, optional
Output HDF5 filename, by default "stanford_2d3ds_dataset.h5"
downsampling_factor : int, optional
Factor by which to downsample images, by default 16
remove_alpha_channel : bool, optional
Whether to remove alpha channel from RGB images, by default True
Returns
-------
str
Path to the created HDF5 dataset file
"""
converted_dataset_path = os.path.join(self.local_dir, dataset_file)
from PIL import Image
......@@ -391,8 +448,24 @@ class StanfordSegmentationDataset(Dataset):
"""
Spherical segmentation dataset from [1].
Parameters
----------
dataset_file : str
Path to the HDF5 dataset file
ignore_alpha_channel : bool, optional
Whether to ignore the alpha channel in the RGB images, by default True
log_depth : bool, optional
Whether to log the depth values, by default False
exclude_polar_fraction : float, optional
Fraction of polar points to exclude, by default 0.0
Returns
-------
StanfordSegmentationDataset
Dataset object
References
-----------
----------
.. [1] Armeni, I., Sax, S., Zamir, A. R., Savarese, S.;
"Joint 2D-3D-Semantic Data for Indoor Scene Understanding" (2017).
https://arxiv.org/abs/1702.01105.
......@@ -528,8 +601,19 @@ class StanfordDepthDataset(Dataset):
"""
Spherical segmentation dataset from [1].
Parameters
----------
dataset_file : str
Path to the HDF5 dataset file
ignore_alpha_channel : bool, optional
Whether to ignore the alpha channel in the RGB images, by default True
log_depth : bool, optional
Whether to log the depth values, by default False
exclude_polar_fraction : float, optional
Fraction of polar points to exclude, by default 0.0
References
-----------
----------
.. [1] Armeni, I., Sax, S., Zamir, A. R., Savarese, S.;
"Joint 2D-3D-Semantic Data for Indoor Scene Understanding" (2017).
https://arxiv.org/abs/1702.01105.
......@@ -620,7 +704,8 @@ class StanfordDepthDataset(Dataset):
def compute_stats_s2(dataset: Dataset, normalize_target: bool = False):
"""
Compute stats using parallel welford reduction and quadrature on the sphere. The parallel welford reduction follows this article (parallel algorithm): https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
Compute stats using a parallel Welford reduction and quadrature on the sphere.
The parallel Welford reduction follows the parallel algorithm described in: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
"""
nexamples = len(dataset)
......
......@@ -39,24 +39,19 @@ from torch_harmonics.cache import lru_cache
def _circle_dist(x1: torch.Tensor, x2: torch.Tensor):
    """Periodic (shorter-arc) distance between two angles on the circle."""
    diff = torch.abs(x1 - x2)
    return torch.minimum(diff, torch.abs(2 * math.pi - diff))
def _log_factorial(x: torch.Tensor):
    """Elementwise log(x!) computed via the log-gamma function."""
    return torch.special.gammaln(x + 1)
def _factorial(x: torch.Tensor):
    """Elementwise factorial x! computed as exp(lgamma(x + 1))."""
    return torch.exp(torch.lgamma(x + 1))
class FilterBasis(metaclass=abc.ABCMeta):
"""
Abstract base class for a filter basis
"""
"""Abstract base class for a filter basis"""
def __init__(
self,
......@@ -68,6 +63,7 @@ class FilterBasis(metaclass=abc.ABCMeta):
@property
@abc.abstractmethod
def kernel_size(self):
raise NotImplementedError
# @abc.abstractmethod
......@@ -79,10 +75,7 @@ class FilterBasis(metaclass=abc.ABCMeta):
@abc.abstractmethod
def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float):
"""
Computes the index set that falls into the kernel's support and returns both indices and values.
This routine is designed for sparse evaluations of the filter basis.
"""
raise NotImplementedError
......@@ -101,9 +94,7 @@ def get_filter_basis(kernel_shape: Union[int, Tuple[int], Tuple[int, int]], basi
class PiecewiseLinearFilterBasis(FilterBasis):
"""
Tensor-product basis on a disk constructed from piecewise linear basis functions.
"""
"""Tensor-product basis on a disk constructed from piecewise linear basis functions."""
def __init__(
self,
......@@ -121,12 +112,10 @@ class PiecewiseLinearFilterBasis(FilterBasis):
@property
def kernel_size(self):
    """Compute the number of basis functions in the kernel.

    Half of the first kernel_shape entry spans the second entry, plus one
    extra basis function when the first entry is odd.
    """
    return (self.kernel_shape[0] // 2) * self.kernel_shape[1] + self.kernel_shape[0] % 2
def _compute_support_vals_isotropic(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float):
"""
Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
"""
# enumerator for basis function
ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1)
......@@ -148,9 +137,6 @@ class PiecewiseLinearFilterBasis(FilterBasis):
return iidx, vals
def _compute_support_vals_anisotropic(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float):
"""
Computes the index set that falls into the anisotropic kernel's support and returns both indices and values.
"""
# enumerator for basis function
ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1)
......@@ -209,6 +195,7 @@ class PiecewiseLinearFilterBasis(FilterBasis):
return iidx, vals
def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float):
"""Computes the index set that falls into the kernel's support and returns both indices and values."""
if self.kernel_shape[1] > 1:
return self._compute_support_vals_anisotropic(r, phi, r_cutoff=r_cutoff)
......@@ -217,9 +204,7 @@ class PiecewiseLinearFilterBasis(FilterBasis):
class MorletFilterBasis(FilterBasis):
"""
Morlet-style filter basis on the disk. A Gaussian is multiplied with a Fourier basis in x and y directions
"""
"""Morlet-style filter basis on the disk. A Gaussian is multiplied with a Fourier basis in x and y directions."""
def __init__(
self,
......@@ -235,18 +220,18 @@ class MorletFilterBasis(FilterBasis):
@property
def kernel_size(self):
    """Number of basis functions: the product of the two kernel_shape entries."""
    return self.kernel_shape[0] * self.kernel_shape[1]
def gaussian_window(self, r: torch.Tensor, width: float = 1.0):
    """Gaussian window at radii ``r``, scaled by 1 / (2*pi*width**2)."""
    variance = width**2
    return torch.exp(-0.5 * r**2 / variance) / (2 * math.pi * variance)
def hann_window(self, r: torch.Tensor, width: float = 1.0):
    """Hann (cosine-squared) window at radii ``r``; equals 1 at r=0 and 0 at r=width."""
    return torch.cos(0.5 * torch.pi * r / width).pow(2)
def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float, width: float = 1.0):
"""
Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
"""
# enumerator for basis function
ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1)
......@@ -274,9 +259,7 @@ class MorletFilterBasis(FilterBasis):
class ZernikeFilterBasis(FilterBasis):
"""
Zernike polynomials which are defined on the disk. See https://en.wikipedia.org/wiki/Zernike_polynomials
"""
"""Zernike polynomials which are defined on the disk. See https://en.wikipedia.org/wiki/Zernike_polynomials"""
def __init__(
self,
......@@ -292,9 +275,11 @@ class ZernikeFilterBasis(FilterBasis):
@property
def kernel_size(self):
return (self.kernel_shape * (self.kernel_shape + 1)) // 2
def zernikeradial(self, r: torch.Tensor, n: torch.Tensor, m: torch.Tensor):
out = torch.zeros_like(r)
bound = (n - m) // 2 + 1
max_bound = bound.max().item()
......@@ -307,13 +292,11 @@ class ZernikeFilterBasis(FilterBasis):
return out
def zernikepoly(self, r: torch.Tensor, phi: torch.Tensor, n: torch.Tensor, l: torch.Tensor):
    """Evaluate the Zernike polynomial with azimuthal order m = 2*l - n at (r, phi).

    Negative m selects the sine branch, non-negative m the cosine branch.
    """
    m = 2 * l - n
    sin_branch = self.zernikeradial(r, n, -m) * torch.sin(m * phi)
    cos_branch = self.zernikeradial(r, n, m) * torch.cos(m * phi)
    return torch.where(m < 0, sin_branch, cos_branch)
def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float, width: float = 0.25):
"""
Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
"""
# enumerator for basis function
ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1)
......
......@@ -37,18 +37,37 @@ from torch_harmonics.cache import lru_cache
def clm(l: int, m: int) -> float:
    """Return the normalization factor that orthonormalizes the spherical harmonics.

    Computes ``sqrt((2l+1)/(4*pi)) * sqrt((l-m)!/(l+m)!)`` for degree ``l``
    and order ``m``. The stale pre-refactor docstring that remained as a
    dead bare-string statement has been removed.

    Parameters
    ----------
    l : int
        Degree of the spherical harmonic.
    m : int
        Order of the spherical harmonic.

    Returns
    -------
    float
        Orthonormalization factor c^l_m.
    """
    return math.sqrt((2*l + 1) / 4 / math.pi) * math.sqrt(math.factorial(l-m) / math.factorial(l+m))
def legpoly(mmax: int, lmax: int, x: torch.Tensor, norm: Optional[str]="ortho", inverse: Optional[bool]=False, csphase: Optional[bool]=True) -> torch.Tensor:
r"""
"""
Computes the values of (-1)^m c^l_m P^l_m(x) at the positions specified by x.
The resulting tensor has shape (mmax, lmax, len(x)). The Condon-Shortley Phase (-1)^m
can be turned off optionally.
method of computation follows
Parameters
-----------
mmax: int
Maximum order of the spherical harmonics
lmax: int
Maximum degree of the spherical harmonics
x: torch.Tensor
Tensor of positions at which to evaluate the Legendre polynomials
norm: Optional[str]
Normalization of the Legendre polynomials
inverse: Optional[bool]
Whether to compute the inverse Legendre polynomials
csphase: Optional[bool]
Whether to apply the Condon-Shortley phase (-1)^m
Returns
-------
out: torch.Tensor
Tensor of Legendre polynomial values
References
----------
[1] Schaeffer, N.; Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
[2] Rapp, R.H.; A Fortran Program for the Computation of Gravimetric Quantities from High Degree Spherical Harmonic Expansions, Ohio State University Columbus; report; 1982;
https://apps.dtic.mil/sti/citations/ADA123406
......@@ -94,12 +113,33 @@ def legpoly(mmax: int, lmax: int, x: torch.Tensor, norm: Optional[str]="ortho",
@lru_cache(typed=True, copy=True)
def _precompute_legpoly(mmax: int , lmax: int, t: torch.Tensor,
norm: Optional[str]="ortho", inverse: Optional[bool]=False, csphase: Optional[bool]=True) -> torch.Tensor:
r"""
"""
Computes the values of (-1)^m c^l_m P^l_m(\cos \theta) at the positions specified by t (theta).
The resulting tensor has shape (mmax, lmax, len(x)). The Condon-Shortley Phase (-1)^m
can be turned off optionally.
    The method of computation follows the references listed below.
Parameters
-----------
mmax: int
Maximum order of the spherical harmonics
lmax: int
Maximum degree of the spherical harmonics
t: torch.Tensor
Tensor of positions at which to evaluate the Legendre polynomials
norm: Optional[str]
Normalization of the Legendre polynomials
inverse: Optional[bool]
Whether to compute the inverse Legendre polynomials
csphase: Optional[bool]
Whether to apply the Condon-Shortley phase (-1)^m
Returns
-------
out: torch.Tensor
Tensor of Legendre polynomial values
References
----------
[1] Schaeffer, N.; Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
[2] Rapp, R.H.; A Fortran Program for the Computation of Gravimetric Quantities from High Degree Spherical Harmonic Expansions, Ohio State University Columbus; report; 1982;
https://apps.dtic.mil/sti/citations/ADA123406
......@@ -111,13 +151,34 @@ def _precompute_legpoly(mmax: int , lmax: int, t: torch.Tensor,
@lru_cache(typed=True, copy=True)
def _precompute_dlegpoly(mmax: int, lmax: int, t: torch.Tensor,
norm: Optional[str]="ortho", inverse: Optional[bool]=False, csphase: Optional[bool]=True) -> torch.Tensor:
r"""
"""
Computes the values of the derivatives $\frac{d}{d \theta} P^m_l(\cos \theta)$
at the positions specified by t (theta), as well as $\frac{1}{\sin \theta} P^m_l(\cos \theta)$,
needed for the computation of the vector spherical harmonics. The resulting tensor has shape
(2, mmax, lmax, len(t)).
    The computation follows the reference listed below.
Parameters
-----------
mmax: int
Maximum order of the spherical harmonics
lmax: int
Maximum degree of the spherical harmonics
t: torch.Tensor
Tensor of positions at which to evaluate the Legendre polynomials
norm: Optional[str]
Normalization of the Legendre polynomials
inverse: Optional[bool]
Whether to compute the inverse Legendre polynomials
csphase: Optional[bool]
Whether to apply the Condon-Shortley phase (-1)^m
Returns
-------
out: torch.Tensor
Tensor of Legendre polynomial values
References
----------
[2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
"""
......
......@@ -58,6 +58,28 @@ def get_projection(
central_latitude=0,
central_longitude=0,
):
"""
Get a cartopy projection object for map plotting.
Parameters
-----------
projection : str
Projection type ("orthographic", "robinson", "platecarree", "mollweide")
central_latitude : float, optional
Central latitude for the projection, by default 0
central_longitude : float, optional
Central longitude for the projection, by default 0
Returns
-------
cartopy.crs.Projection
Cartopy projection object
Raises
------
ValueError
If projection type is not supported
"""
if projection == "orthographic":
proj = ccrs.Orthographic(central_latitude=central_latitude, central_longitude=central_longitude)
elif projection == "robinson":
......@@ -77,6 +99,40 @@ def plot_sphere(
):
"""
Plots a function defined on the sphere using pcolormesh
Parameters
-----------
data : numpy.ndarray or torch.Tensor
Data to plot with shape (nlat, nlon)
fig : matplotlib.figure.Figure, optional
Figure to plot on, by default None (creates new figure)
projection : str, optional
Map projection type, by default "robinson"
cmap : str, optional
Colormap name, by default "RdBu"
title : str, optional
Plot title, by default None
colorbar : bool, optional
Whether to add a colorbar, by default False
coastlines : bool, optional
Whether to add coastlines, by default False
gridlines : bool, optional
Whether to add gridlines, by default False
central_latitude : float, optional
Central latitude for projection, by default 0
central_longitude : float, optional
Central longitude for projection, by default 0
lon : numpy.ndarray, optional
Longitude coordinates, by default None (auto-generated)
lat : numpy.ndarray, optional
Latitude coordinates, by default None (auto-generated)
**kwargs
Additional arguments passed to pcolormesh
Returns
-------
matplotlib.collections.QuadMesh
The plotted image object
"""
# make sure cartopy exist
......@@ -126,6 +182,28 @@ def plot_sphere(
def imshow_sphere(data, fig=None, projection="robinson", title=None, central_latitude=0, central_longitude=0, **kwargs):
"""
Displays an image on the sphere
Parameters
-----------
data : numpy.ndarray or torch.Tensor
Data to display with shape (nlat, nlon)
fig : matplotlib.figure.Figure, optional
Figure to plot on, by default None (creates new figure)
projection : str, optional
Map projection type, by default "robinson"
title : str, optional
Plot title, by default None
central_latitude : float, optional
Central latitude for projection, by default 0
central_longitude : float, optional
Central longitude for projection, by default 0
**kwargs
Additional arguments passed to imshow
Returns
-------
matplotlib.image.AxesImage
The displayed image object
"""
# make sure cartopy exist
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment