Unverified commit c7afb546 authored by Thorsten Kurth, committed by GitHub

Merge pull request #95 from NVIDIA/aparis/docs

Docstrings PR
parents b5c410c0 644465ba
@@ -60,11 +60,45 @@ except ImportError as err:

def _normalize_convolution_tensor_s2(
    psi_idx, psi_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=False, basis_norm_mode="mean", merge_quadrature=False, eps=1e-9
):
    """Normalizes convolution tensor values based on specified normalization mode.

    This function applies different normalization strategies to the convolution tensor
    values based on the basis_norm_mode parameter. It can normalize individual basis
    functions, compute mean normalization across all basis functions, or use support
    weights. The function also optionally merges quadrature weights into the tensor.

    Parameters
    ----------
    psi_idx: torch.Tensor
        Index tensor for the sparse convolution tensor.
    psi_vals: torch.Tensor
        Value tensor for the sparse convolution tensor.
    in_shape: Tuple[int]
        Tuple of (nlat_in, nlon_in) representing input grid dimensions.
    out_shape: Tuple[int]
        Tuple of (nlat_out, nlon_out) representing output grid dimensions.
    kernel_size: int
        Number of kernel basis functions.
    quad_weights: torch.Tensor
        Quadrature weights for numerical integration.
    transpose_normalization: bool
        If True, applies normalization in transpose direction.
    basis_norm_mode: str
        Normalization mode, one of ["none", "individual", "mean", "support"].
    merge_quadrature: bool
        If True, multiplies values by quadrature weights.
    eps: float
        Small epsilon value to prevent division by zero.

    Returns
    -------
    torch.Tensor
        Normalized convolution tensor values.

    Raises
    ------
    ValueError
        If basis_norm_mode is not one of the supported modes.
    """
    # exit here if no normalization is needed
@@ -109,7 +143,6 @@ def _normalize_convolution_tensor_s2(
                # compute the support
                support[ik, ilat] = torch.sum(q[iidx])

    # loop over values and renormalize
    for ik in range(kernel_size):
        for ilat in range(nlat_out):
@@ -132,7 +165,6 @@ def _normalize_convolution_tensor_s2(
            if merge_quadrature:
                psi_vals[iidx] = psi_vals[iidx] * q[iidx]

    if transpose_normalization and merge_quadrature:
        psi_vals = psi_vals / correction_factor
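# A minimal sketch of the "individual" normalization mode: each filter basis
# function is numerically integrated against the quadrature weights and rescaled
# to unit integral. Names (normalize_individual, vals, q) are illustrative
# placeholders, not the library's internals.
import torch

def normalize_individual(vals: torch.Tensor, q: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    # vals: filter values at the sampled points, q: matching quadrature weights
    norm = torch.sum(vals * q)      # numerical integral of the filter over the sphere
    return vals / (norm + eps)      # afterwards, sum(vals * q) ≈ 1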

@@ -144,13 +176,13 @@ def _precompute_convolution_tensor_s2(
    in_shape: Tuple[int],
    out_shape: Tuple[int],
    filter_basis: FilterBasis,
    grid_in: Optional[str] = "equiangular",
    grid_out: Optional[str] = "equiangular",
    theta_cutoff: Optional[float] = 0.01 * math.pi,
    theta_eps: Optional[float] = 1e-3,
    transpose_normalization: Optional[bool] = False,
    basis_norm_mode: Optional[str] = "mean",
    merge_quadrature: Optional[bool] = False,
):
    """
    Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
@@ -166,6 +198,37 @@ def _precompute_convolution_tensor_s2(
        \cos(\alpha)\cos(\gamma)-\cos(\beta)\sin(\alpha)\sin(\gamma)
    \end{bmatrix}}
    $$

    Parameters
    ----------
    in_shape: Tuple[int]
        Input shape of the convolution tensor
    out_shape: Tuple[int]
        Output shape of the convolution tensor
    filter_basis: FilterBasis
        Filter basis functions
    grid_in: str
        Input grid type
    grid_out: str
        Output grid type
    theta_cutoff: float
        Theta cutoff for the filter basis functions
    theta_eps: float
        Epsilon for the theta cutoff
    transpose_normalization: bool
        Whether to normalize the convolution tensor in the transpose direction
    basis_norm_mode: str
        Mode for basis normalization
    merge_quadrature: bool
        Whether to merge the quadrature weights into the convolution tensor

    Returns
    -------
    out_idx: torch.Tensor
        Index tensor of the convolution tensor
    out_vals: torch.Tensor
        Values tensor of the convolution tensor
    """
    assert len(in_shape) == 2
@@ -268,6 +331,26 @@ def _precompute_convolution_tensor_s2(

class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
    """
    Abstract base class for discrete-continuous convolutions

    Parameters
    ----------
    in_channels: int
        Number of input channels
    out_channels: int
        Number of output channels
    kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
        Shape of the kernel
    basis_type: Optional[str]
        Type of the basis functions
    groups: Optional[int]
        Number of groups
    bias: Optional[bool]
        Whether to use bias

    Returns
    -------
    out: torch.Tensor
        Output tensor
    """
    def __init__(
@@ -316,6 +399,40 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
    """
    Discrete-continuous (DISCO) convolutions on the 2-Sphere as described in [1].

    Parameters
    ----------
    in_channels: int
        Number of input channels
    out_channels: int
        Number of output channels
    in_shape: Tuple[int]
        Input shape of the convolution tensor
    out_shape: Tuple[int]
        Output shape of the convolution tensor
    kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
        Shape of the kernel
    basis_type: Optional[str]
        Type of the basis functions
    basis_norm_mode: Optional[str]
        Mode for basis normalization
    groups: Optional[int]
        Number of groups
    grid_in: Optional[str]
        Input grid type
    grid_out: Optional[str]
        Output grid type
    bias: Optional[bool]
        Whether to use bias
    theta_cutoff: Optional[float]
        Theta cutoff for the filter basis functions

    Returns
    -------
    out: torch.Tensor
        Output tensor

    References
    ----------
    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """
@@ -382,9 +499,6 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
        self.psi = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out)

    def extra_repr(self):
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
@@ -420,6 +534,40 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
    """
    Discrete-continuous (DISCO) transpose convolutions on the 2-Sphere as described in [1].

    Parameters
    ----------
    in_channels: int
        Number of input channels
    out_channels: int
        Number of output channels
    in_shape: Tuple[int]
        Input shape of the convolution tensor
    out_shape: Tuple[int]
        Output shape of the convolution tensor
    kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
        Shape of the kernel
    basis_type: Optional[str]
        Type of the basis functions
    basis_norm_mode: Optional[str]
        Mode for basis normalization
    groups: Optional[int]
        Number of groups
    grid_in: Optional[str]
        Input grid type
    grid_out: Optional[str]
        Output grid type
    bias: Optional[bool]
        Whether to use bias
    theta_cutoff: Optional[float]
        Theta cutoff for the filter basis functions

    Returns
    -------
    out: torch.Tensor
        Output tensor

    References
    ----------
    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """
@@ -487,9 +635,6 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
        self.psi_st = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, semi_transposed=True)

    def extra_repr(self):
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
@@ -497,6 +642,7 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
        return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # extract shape
        B, C, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, H, W)

@@ -74,6 +74,32 @@ def _split_distributed_convolution_tensor_s2(
    in_shape: Tuple[int],
    out_shape: Tuple[int],
):
"""
Splits a pre-computed convolution tensor along the latitude dimension for distributed processing.
This function takes a convolution tensor that was generated by the serial routine and filters
it to only include entries corresponding to the local latitude slice assigned to this process.
The filtering is done based on the polar group rank and the computed split shapes.
Parameters
----------
idx: torch.Tensor
Indices of the pre-computed convolution tensor
vals: torch.Tensor
Values of the pre-computed convolution tensor
in_shape: Tuple[int]
Shape of the input tensor (nlat_in, nlon_in)
out_shape: Tuple[int]
Shape of the output tensor (nlat_out, nlon_out)
Returns
-------
idx: torch.Tensor
Filtered indices corresponding to the local latitude slice
vals: torch.Tensor
Filtered values corresponding to the local latitude slice
"""
    nlat_in, nlon_in = in_shape
    nlat_out, nlon_out = out_shape
@@ -102,10 +128,43 @@ def _split_distributed_convolution_tensor_s2(

class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
    """
    Distributed version of Discrete-continuous convolutions (DISCO) on the 2-Sphere as described in [1].
    We assume the data can be split in polar and azimuthal directions.

    Parameters
    ----------
    in_channels: int
        Number of input channels
    out_channels: int
        Number of output channels
    in_shape: Tuple[int]
        Shape of the input tensor
    out_shape: Tuple[int]
        Shape of the output tensor
    kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
        Shape of the kernel
    basis_type: Optional[str]
        Type of basis to use
    basis_norm_mode: Optional[str]
        Normalization mode for the filter basis
    groups: Optional[int]
        Number of groups
    grid_in: Optional[str]
        Grid type for the input tensor
    grid_out: Optional[str]
        Grid type for the output tensor
    bias: Optional[bool]
        Whether to use bias
    theta_cutoff: Optional[float]
        Theta cutoff for the filter basis

    Returns
    -------
    out: torch.Tensor
        Output tensor

    References
    ----------
    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """
    def __init__(
@@ -192,9 +251,6 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
        self.psi = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, self.nlat_in_local, self.nlat_out_local)

    def extra_repr(self):
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
@@ -247,6 +303,40 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
    """
    Distributed version of discrete-continuous transpose convolutions (DISCO) on the 2-Sphere as described in [1].

    Parameters
    ----------
    in_channels: int
        Number of input channels
    out_channels: int
        Number of output channels
    in_shape: Tuple[int]
        Shape of the input tensor
    out_shape: Tuple[int]
        Shape of the output tensor
    kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
        Shape of the kernel
    basis_type: Optional[str]
        Type of basis to use
    basis_norm_mode: Optional[str]
        Normalization mode for the filter basis
    groups: Optional[int]
        Number of groups
    grid_in: Optional[str]
        Grid type for the input tensor
    grid_out: Optional[str]
        Grid type for the output tensor
    bias: Optional[bool]
        Whether to use bias
    theta_cutoff: Optional[float]
        Theta cutoff for the filter basis

    Returns
    -------
    out: torch.Tensor
        Output tensor

    References
    ----------
    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603

    We assume the data can be split in polar and azimuthal directions.
@@ -339,9 +429,6 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
        self.psi_st = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, self.nlat_in_local, self.nlat_out_local, semi_transposed=True)

    def extra_repr(self):
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property

@@ -43,6 +43,32 @@ from torch_harmonics.distributed import compute_split_shapes

class DistributedResampleS2(nn.Module):
"""
Distributed resampling module for spherical data on the 2-sphere.
This module performs distributed resampling of spherical data across multiple processes,
supporting both upscaling and downscaling operations. The data is distributed across
polar and azimuthal directions, and the module handles the necessary communication
and interpolation operations.
Parameters
-----------
nlat_in : int
Number of input latitude points
nlon_in : int
Number of input longitude points
nlat_out : int
Number of output latitude points
nlon_out : int
Number of output longitude points
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
mode : str, optional
Interpolation mode ("bilinear" or "bilinear-spherical"), by default "bilinear"
"""
    def __init__(
        self,
        nlat_in: int,
@@ -127,12 +153,10 @@ class DistributedResampleS2(nn.Module):
        self.skip_resampling = (nlon_in == nlon_out) and (nlat_in == nlat_out) and (grid_in == grid_out)

    def extra_repr(self):
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}"

    def _upscale_longitudes(self, x: torch.Tensor):
        """Upscale the longitude dimension using interpolation."""
        # do the interpolation
        lwgt = self.lon_weights.to(x.dtype)
        if self.mode == "bilinear":
@@ -147,6 +171,7 @@ class DistributedResampleS2(nn.Module):
        return x
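# A minimal sketch of the periodic linear interpolation that "bilinear" mode
# amounts to along the longitude dimension; index and weight names are
# illustrative, not the module's internals.
import torch

def upscale_longitudes_sketch(x: torch.Tensor, nlon_out: int) -> torch.Tensor:
    nlon_in = x.shape[-1]
    pos = torch.arange(nlon_out) * nlon_in / nlon_out  # fractional source positions
    left = pos.floor().long() % nlon_in
    right = (left + 1) % nlon_in                       # periodic wrap in longitude
    frac = (pos - pos.floor()).to(x.dtype)
    return (1.0 - frac) * x[..., left] + frac * x[..., right]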
    def _expand_poles(self, x: torch.Tensor):
        """Expand the data to include pole values for interpolation."""
        x_north = x[..., 0, :].sum(dim=-1, keepdims=True)
        x_south = x[..., -1, :].sum(dim=-1, keepdims=True)
        x_count = torch.tensor([x.shape[-1]], dtype=torch.long, device=x.device, requires_grad=False)
@@ -169,6 +194,7 @@ class DistributedResampleS2(nn.Module):
        return x

    def _upscale_latitudes(self, x: torch.Tensor):
        """Upscale the latitude dimension using interpolation."""
        # do the interpolation
        lwgt = self.lat_weights.to(x.dtype)
        if self.mode == "bilinear":

@@ -48,19 +48,35 @@ class DistributedRealSHT(nn.Module):
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
    The SHT is applied to the last two dimensions of the input.

    Parameters
    ----------
    nlat: int
        Number of latitude points
    nlon: int
        Number of longitude points
    lmax: int
        Maximum spherical harmonic degree
    mmax: int
        Maximum spherical harmonic order
    grid: str
        Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
    norm: str
        Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
    csphase: bool
        Whether to apply the Condon-Shortley phase factor, by default True

    Returns
    -------
    x: torch.Tensor
        Tensor of shape (..., lmax, mmax)

    References
    ----------
    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """
    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
        super().__init__()
@@ -115,9 +131,6 @@ class DistributedRealSHT(nn.Module):
        self.register_buffer('weights', weights, persistent=False)

    def extra_repr(self):
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor):
@@ -168,9 +181,31 @@ class DistributedInverseRealSHT(nn.Module):
    """
    Defines a module for computing the inverse (real-valued) SHT.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.

    Parameters
    ----------
    nlat: int
        Number of latitude points
    nlon: int
        Number of longitude points
    lmax: int
        Maximum spherical harmonic degree
    mmax: int
        Maximum spherical harmonic order
    grid: str
        Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
    norm: str
        Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
    csphase: bool
        Whether to apply the Condon-Shortley phase factor, by default True

    Returns
    -------
    x: torch.Tensor
        Tensor of shape (..., nlat, nlon)

    References
    ----------
    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """
@@ -226,9 +261,6 @@ class DistributedInverseRealSHT(nn.Module):
        self.register_buffer('pct', pct, persistent=False)

    def extra_repr(self):
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor):
@@ -282,19 +314,35 @@ class DistributedRealVectorSHT(nn.Module):
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
    The SHT is applied to the last three dimensions of the input.

    Parameters
    ----------
    nlat: int
        Number of latitude points
    nlon: int
        Number of longitude points
    lmax: int
        Maximum spherical harmonic degree
    mmax: int
        Maximum spherical harmonic order
    grid: str
        Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
    norm: str
        Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
    csphase: bool
        Whether to apply the Condon-Shortley phase factor, by default True

    Returns
    -------
    x: torch.Tensor
        Tensor of shape (..., lmax, mmax)

    References
    ----------
    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """
    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
        super().__init__()
@@ -355,9 +403,6 @@ class DistributedRealVectorSHT(nn.Module):

    def extra_repr(self):
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor):
@@ -425,6 +470,30 @@ class DistributedInverseRealVectorSHT(nn.Module):
    Defines a module for computing the inverse (real-valued) vector SHT.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.

    Parameters
    ----------
    nlat: int
        Number of latitude points
    nlon: int
        Number of longitude points
    lmax: int
        Maximum spherical harmonic degree
    mmax: int
        Maximum spherical harmonic order
    grid: str
        Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
    norm: str
        Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
    csphase: bool
        Whether to apply the Condon-Shortley phase factor, by default True

    Returns
    -------
    x: torch.Tensor
        Tensor of shape (..., nlat, nlon)

    References
    ----------
    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """
@@ -478,9 +547,6 @@ class DistributedInverseRealVectorSHT(nn.Module):
        self.register_buffer('dpct', dpct, persistent=False)

    def extra_repr(self):
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor):

@@ -39,6 +39,7 @@ from .utils import is_initialized, is_distributed_polar, is_distributed_azimuth

# helper routine to compute uneven splitting in balanced way:
def compute_split_shapes(size: int, num_chunks: int) -> List[int]:
    """Compute the split shapes for a given size and number of chunks."""
    # treat trivial case first
    if num_chunks == 1:
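# A minimal sketch of the balanced-splitting idea (not necessarily the exact
# routine): chunk sizes differ by at most one, with the larger chunks first.
def compute_split_shapes_sketch(size: int, num_chunks: int):
    base, rem = divmod(size, num_chunks)
    return [base + 1 if i < rem else base for i in range(num_chunks)]

# compute_split_shapes_sketch(10, 4) -> [3, 3, 2, 2]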
@@ -59,6 +60,8 @@ def compute_split_shapes(size: int, num_chunks: int) -> List[int]:

def split_tensor_along_dim(tensor, dim, num_chunks):
    """Split a tensor along a given dimension into a given number of chunks."""
    assert dim < tensor.dim(), f"Error, tensor dimension is {tensor.dim()} which cannot be split along {dim}"
    assert (tensor.shape[dim] >= num_chunks), f"Error, cannot split dim {dim} of size {tensor.shape[dim]} into \
        {num_chunks} chunks. Empty slices are currently not supported."
@@ -71,6 +74,7 @@ def split_tensor_along_dim(tensor, dim, num_chunks):

def _transpose(tensor, dim0, dim1, dim1_split_sizes, group=None, async_op=False):
    # get comm params
    comm_size = dist.get_world_size(group=group)
    comm_rank = dist.get_rank(group=group)
@@ -99,6 +103,7 @@ class distributed_transpose_azimuth(torch.autograd.Function):

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x, dims, dim1_split_sizes):
        # WAR for a potential contig check torch bug for channels last contig tensors
        xlist, dim0_split_sizes, _ = _transpose(x, dims[0], dims[1], dim1_split_sizes, group=azimuth_group())
        x = torch.cat(xlist, dim=dims[1])
@@ -124,6 +129,7 @@ class distributed_transpose_polar(torch.autograd.Function):

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x, dim, dim1_split_sizes):
        # WAR for a potential contig check torch bug for channels last contig tensors
        xlist, dim0_split_sizes, _ = _transpose(x, dim[0], dim[1], dim1_split_sizes, group=polar_group())
        x = torch.cat(xlist, dim=dim[1])
@@ -134,6 +140,7 @@ class distributed_transpose_polar(torch.autograd.Function):

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, go):
        dim = ctx.dim
        dim0_split_sizes = ctx.dim0_split_sizes
        # WAR for a potential contig check torch bug for channels last contig tensors
@@ -144,7 +151,6 @@ class distributed_transpose_polar(torch.autograd.Function):
# we need those additional primitives for distributed matrix multiplications
def _reduce(input_, use_fp32=True, group=None):
    # Bypass the function if we are using only 1 GPU.
    if dist.get_world_size(group=group) == 1:
@@ -165,7 +171,6 @@ def _reduce(input_, use_fp32=True, group=None):

def _split(input_, dim_, group=None):
    # Bypass the function if we are using only 1 GPU.
    comm_size = dist.get_world_size(group=group)
    if comm_size == 1:
@@ -182,7 +187,6 @@ def _split(input_, dim_, group=None):

def _gather(input_, dim_, shapes_, group=None):
    comm_size = dist.get_world_size(group=group)
@@ -215,7 +219,6 @@ def _gather(input_, dim_, shapes_, group=None):

def _reduce_scatter(input_, dim_, use_fp32=True, group=None):
    # Bypass the function if we are using only 1 GPU.
    if dist.get_world_size(group=group) == 1:
@@ -244,7 +247,6 @@ def _reduce_scatter(input_, dim_, use_fp32=True, group=None):

class _CopyToPolarRegion(torch.autograd.Function):

    @staticmethod
    def symbolic(graph, input_):
@@ -253,11 +255,13 @@ class _CopyToPolarRegion(torch.autograd.Function):

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_):
        return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        if is_distributed_polar():
            return _reduce(grad_output, group=polar_group())
        else:
@@ -265,7 +269,6 @@ class _CopyToPolarRegion(torch.autograd.Function):

class _CopyToAzimuthRegion(torch.autograd.Function):

    @staticmethod
    def symbolic(graph, input_):
@@ -274,11 +277,13 @@ class _CopyToAzimuthRegion(torch.autograd.Function):

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_):
        return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        if is_distributed_azimuth():
            return _reduce(grad_output, group=azimuth_group())
        else:
@@ -286,7 +291,6 @@ class _CopyToAzimuthRegion(torch.autograd.Function):

class _ScatterToPolarRegion(torch.autograd.Function):

    @staticmethod
    def symbolic(graph, input_, dim_):
@@ -314,7 +318,6 @@ class _ScatterToPolarRegion(torch.autograd.Function):

class _GatherFromPolarRegion(torch.autograd.Function):

    @staticmethod
    def symbolic(graph, input_, dim_, shapes_):
@@ -339,7 +342,6 @@ class _GatherFromPolarRegion(torch.autograd.Function):

class _ReduceFromPolarRegion(torch.autograd.Function):

    @staticmethod
    def symbolic(graph, input_):
@@ -363,7 +365,6 @@ class _ReduceFromPolarRegion(torch.autograd.Function):

class _ReduceFromAzimuthRegion(torch.autograd.Function):

    @staticmethod
    def symbolic(graph, input_):
@@ -387,7 +388,6 @@ class _ReduceFromAzimuthRegion(torch.autograd.Function):

class _ReduceFromScatterToPolarRegion(torch.autograd.Function):

    @staticmethod
    def symbolic(graph, input_, dim_):
@@ -418,7 +418,6 @@ class _ReduceFromScatterToPolarRegion(torch.autograd.Function):

class _GatherFromCopyToPolarRegion(torch.autograd.Function):

    @staticmethod
    def symbolic(graph, input_, dim_, shapes_):
@@ -55,6 +55,27 @@ def get_quadrature_weights(nlat: int, nlon: int, grid: str, tile: bool = False,

class DiceLossS2(nn.Module):
    """
    Dice loss for spherical segmentation tasks.

    Parameters
    ----------
    nlat : int
        Number of latitude points
    nlon : int
        Number of longitude points
    grid : str, optional
        Grid type, by default "equiangular"
    weight : torch.Tensor, optional
        Class weights, by default None
    smooth : float, optional
        Smoothing factor, by default 0
    ignore_index : int, optional
        Index to ignore in loss computation, by default -100
    mode : str, optional
        Aggregation mode ("micro" or "macro"), by default "micro"
    """
    def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, smooth: float = 0, ignore_index: int = -100, mode: str = "micro"):
        super().__init__()
@@ -113,6 +134,24 @@ class DiceLossS2(nn.Module):

class CrossEntropyLossS2(nn.Module):
    """
    Cross-entropy loss for spherical classification tasks.

    Parameters
    ----------
    nlat : int
        Number of latitude points
    nlon : int
        Number of longitude points
    grid : str, optional
        Grid type, by default "equiangular"
    weight : torch.Tensor, optional
        Class weights, by default None
    smooth : float, optional
        Label smoothing factor, by default 0
    ignore_index : int, optional
        Index to ignore in loss computation, by default -100
    """

    def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, smooth: float = 0, ignore_index: int = -100):
@@ -141,6 +180,24 @@ class CrossEntropyLossS2(nn.Module):

class FocalLossS2(nn.Module):
"""
Focal loss for spherical classification tasks.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
grid : str, optional
Grid type, by default "equiangular"
weight : torch.Tensor, optional
Class weights, by default None
smooth : float, optional
Label smoothing factor, by default 0
ignore_index : int, optional
Index to ignore in loss computation, by default -100
"""
def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, smooth: float = 0, ignore_index: int = -100): def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, smooth: float = 0, ignore_index: int = -100):
@@ -286,10 +343,19 @@ class NormalLossS2(SphericalLossBase):
    Surface normals are computed by calculating gradients in latitude and longitude
    directions using FFT, then constructing 3D normal vectors that are normalized.

    Parameters
    ----------
    nlat : int
        Number of latitude points
    nlon : int
        Number of longitude points
    grid : str, optional
        Grid type, by default "equiangular"

    Returns
    -------
    torch.Tensor
        Combined loss term
    """

    def __init__(self, nlat: int, nlon: int, grid: str = "equiangular"):
@@ -49,6 +49,31 @@ def _get_stats_multiclass(
    quad_weights: torch.Tensor,
    ignore_index: Optional[int],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute multiclass statistics (TP, FP, FN, TN) on the sphere using quadrature weights.

    This function computes true positives, false positives, false negatives, and true negatives
    for multiclass classification on spherical data, properly weighted by quadrature weights
    to account for the spherical geometry.

    Parameters
    ----------
    output : torch.LongTensor
        Predicted class labels
    target : torch.LongTensor
        Ground truth class labels
    num_classes : int
        Number of classes in the classification task
    quad_weights : torch.Tensor
        Quadrature weights for spherical integration
    ignore_index : Optional[int]
        Index to ignore in the computation (e.g., for padding or invalid regions)

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
        Tuple containing (tp_count, fp_count, fn_count, tn_count) for each class
    """
    batch_size, *dims = output.shape
    num_elements = torch.prod(torch.tensor(dims)).long()
@@ -88,10 +113,46 @@ def _get_stats_multiclass(

def _predict_classes(logits: torch.Tensor) -> torch.Tensor:
    """
    Convert logits to class predictions using softmax and argmax.

    Parameters
    ----------
    logits : torch.Tensor
        Input logits tensor

    Returns
    -------
    torch.Tensor
        Predicted class labels
    """
    return torch.argmax(torch.softmax(logits, dim=1), dim=1, keepdim=False)
class BaseMetricS2(nn.Module):
    """
    Base class for spherical metrics that properly handle spherical geometry.

    This class provides the foundation for computing metrics on spherical data
    by using quadrature weights to account for the non-uniform area distribution
    on the sphere.

    Parameters
    ----------
    nlat : int
        Number of latitude points
    nlon : int
        Number of longitude points
    grid : str, optional
        Grid type ("equiangular", "legendre-gauss", etc.), by default "equiangular"
    weight : torch.Tensor, optional
        Class weights for weighted averaging, by default None
    ignore_index : int, optional
        Index to ignore in computations, by default -100
    mode : str, optional
        Averaging mode ("micro" or "macro"), by default "micro"
    """

    def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, ignore_index: int = -100, mode: str = "micro"):
        super().__init__()
@@ -108,6 +169,7 @@ class BaseMetricS2(nn.Module):
        self.register_buffer("weight", weight.unsqueeze(0))

    def _forward(self, pred: torch.Tensor, truth: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # convert logits to class predictions
        pred_class = _predict_classes(pred)
@@ -138,6 +200,28 @@ class BaseMetricS2(nn.Module):
class IntersectionOverUnionS2(BaseMetricS2):
    """
    Intersection over Union (IoU) metric for spherical data.

    Computes the IoU score for multiclass classification on the sphere,
    properly weighted by quadrature weights to account for spherical geometry.

    Parameters
    ----------
    nlat : int
        Number of latitude points
    nlon : int
        Number of longitude points
    grid : str, optional
        Grid type ("equiangular", "legendre-gauss", etc.), by default "equiangular"
    weight : torch.Tensor, optional
        Class weights for weighted averaging, by default None
    ignore_index : int, optional
        Index to ignore in computations, by default -100
    mode : str, optional
        Averaging mode ("micro" or "macro"), by default "micro"
    """

    def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, ignore_index: int = -100, mode: str = "micro"):
        super().__init__(nlat, nlon, grid, weight, ignore_index, mode)
@@ -162,6 +246,28 @@ class IntersectionOverUnionS2(BaseMetricS2):

class AccuracyS2(BaseMetricS2):
    """
    Accuracy metric for spherical data.

    Computes the accuracy score for multiclass classification on the sphere,
    properly weighted by quadrature weights to account for spherical geometry.

    Parameters
    ----------
    nlat : int
        Number of latitude points
    nlon : int
        Number of longitude points
    grid : str, optional
        Grid type ("equiangular", "legendre-gauss", etc.), by default "equiangular"
    weight : torch.Tensor, optional
        Class weights for weighted averaging, by default None
    ignore_index : int, optional
        Index to ignore in computations, by default -100
    mode : str, optional
        Averaging mode ("micro" or "macro"), by default "micro"
    """

    def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, ignore_index: int = -100, mode: str = "micro"):
        super().__init__(nlat, nlon, grid, weight, ignore_index, mode)
@@ -41,9 +41,11 @@ from torch_harmonics import InverseRealSHT

def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
@@ -75,19 +77,27 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):

def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    """Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Parameters
    ----------
    tensor: torch.Tensor
        an n-dimensional `torch.Tensor`
    mean: float
        the mean of the normal distribution
    std: float
        the standard deviation of the normal distribution
    a: float
        the minimum cutoff value, by default -2.0
    b: float
        the maximum cutoff value, by default 2.0

    Examples
    --------
    >>> w = torch.empty(3, 5)
    >>> nn.init.trunc_normal_(w)
    """
@@ -102,6 +112,20 @@ def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor
    drop_prob : float, optional
        Probability of dropping a path, by default 0.0
    training : bool, optional
        Whether the model is in training mode, by default False

    Returns
    -------
    torch.Tensor
        Output tensor
    """
    if drop_prob == 0.0 or not training:
        return x
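# A hedged sketch of the remainder of the computation (the standard stochastic
# depth recipe): one Bernoulli draw per sample, broadcast over the remaining
# dimensions, rescaled by the keep probability so the expected value matches
# evaluation mode.
import torch

def drop_path_sketch(x: torch.Tensor, drop_prob: float) -> torch.Tensor:
    keep_prob = 1.0 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # one draw per sample
    mask = torch.bernoulli(torch.full(shape, keep_prob, device=x.device, dtype=x.dtype))
    return x / keep_prob * mask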
...@@ -114,17 +138,47 @@ def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) - ...@@ -114,17 +138,47 @@ def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -
class DropPath(nn.Module): class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" """
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This module implements stochastic depth regularization by randomly dropping
entire residual paths during training, which helps with regularization and
training of very deep networks.
Parameters
----------
drop_prob : float, optional
Probability of dropping a path, by default None
"""
def __init__(self, drop_prob=None): def __init__(self, drop_prob=None):
super(DropPath, self).__init__() super(DropPath, self).__init__()
self.drop_prob = drop_prob self.drop_prob = drop_prob
def forward(self, x): def forward(self, x):
return drop_path(x, self.drop_prob, self.training) return drop_path(x, self.drop_prob, self.training)
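# A minimal usage sketch for DropPath as defined above:
sd = DropPath(drop_prob=0.1)
sd.train()
y = sd(torch.randn(8, 64, 32, 32))  # in training, whole samples are zeroed with prob 0.1
sd.eval()
z = sd(torch.randn(8, 64, 32, 32))  # at inference, drop_path is the identity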
class PatchEmbed(nn.Module): class PatchEmbed(nn.Module):
"""
Patch embedding layer for vision transformers.
This module splits input images into patches and projects them to a
higher dimensional embedding space using convolutional layers.
Parameters
----------
img_size : tuple, optional
Input image size (height, width), by default (224, 224)
patch_size : tuple, optional
Patch size (height, width), by default (16, 16)
in_chans : int, optional
Number of input channels, by default 3
embed_dim : int, optional
Embedding dimension, by default 768
"""
def __init__(self, img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768): def __init__(self, img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768):
super(PatchEmbed, self).__init__() super(PatchEmbed, self).__init__()
self.red_img_size = ((img_size[0] // patch_size[0]), (img_size[1] // patch_size[1])) self.red_img_size = ((img_size[0] // patch_size[0]), (img_size[1] // patch_size[1]))
...@@ -137,6 +191,7 @@ class PatchEmbed(nn.Module): ...@@ -137,6 +191,7 @@ class PatchEmbed(nn.Module):
self.proj.bias.is_shared_mp = ["spatial"] self.proj.bias.is_shared_mp = ["spatial"]
def forward(self, x): def forward(self, x):
# gather input # gather input
B, C, H, W = x.shape B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
...@@ -146,6 +201,32 @@ class PatchEmbed(nn.Module): ...@@ -146,6 +201,32 @@ class PatchEmbed(nn.Module):
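# A minimal usage sketch for PatchEmbed; the output layout (flattened tokens vs. a
# 2D feature map) is elided in this hunk, so only the reduced grid is noted here:
patch_embed = PatchEmbed(img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768)
tokens = patch_embed(torch.randn(2, 3, 224, 224))  # reduced grid: (224 // 16, 224 // 16) = (14, 14)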
class MLP(nn.Module): class MLP(nn.Module):
"""
Multi-layer perceptron with optional checkpointing.
This module implements a feed-forward network with two linear layers
and an activation function, with optional dropout and gradient checkpointing.
Parameters
----------
in_features : int
Number of input features
hidden_features : int, optional
Number of hidden features, by default None (same as in_features)
out_features : int, optional
Number of output features, by default None (same as in_features)
act_layer : nn.Module, optional
Activation layer, by default nn.ReLU
output_bias : bool, optional
Whether to use bias in output layer, by default False
drop_rate : float, optional
Dropout rate, by default 0.0
checkpointing : bool, optional
Whether to use gradient checkpointing, by default False
gain : float, optional
Gain factor for weight initialization, by default 1.0
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, output_bias=False, drop_rate=0.0, checkpointing=False, gain=1.0): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, output_bias=False, drop_rate=0.0, checkpointing=False, gain=1.0):
super(MLP, self).__init__() super(MLP, self).__init__()
self.checkpointing = checkpointing self.checkpointing = checkpointing
...@@ -179,9 +260,11 @@ class MLP(nn.Module): ...@@ -179,9 +260,11 @@ class MLP(nn.Module):
@torch.jit.ignore @torch.jit.ignore
def checkpoint_forward(self, x): def checkpoint_forward(self, x):
return checkpoint(self.fwd, x) return checkpoint(self.fwd, x)
def forward(self, x): def forward(self, x):
if self.checkpointing: if self.checkpointing:
return self.checkpoint_forward(x) return self.checkpoint_forward(x)
else: else:
...@@ -190,7 +273,21 @@ class MLP(nn.Module): ...@@ -190,7 +273,21 @@ class MLP(nn.Module):
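# A minimal usage sketch for MLP; whether the two layers are nn.Linear or 1x1
# nn.Conv2d is elided in this hunk, so the channel-first map below is an assumption:
mlp = MLP(in_features=256, hidden_features=512, drop_rate=0.1)
out = mlp(torch.randn(4, 256, 60, 120))  # shape-preserving: (4, 256, 60, 120)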
class RealFFT2(nn.Module): class RealFFT2(nn.Module):
""" """
Helper routine to wrap FFT similarly to the SHT Helper routine to wrap FFT similarly to the SHT.
This module provides a wrapper around PyTorch's real FFT2D that mimics
the interface of spherical harmonic transforms for consistency.
Parameters
----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
lmax : int, optional
Maximum spherical harmonic degree, by default None (same as nlat)
mmax : int, optional
Maximum spherical harmonic order, by default None (nlon//2 + 1)
""" """
def __init__(self, nlat, nlon, lmax=None, mmax=None): def __init__(self, nlat, nlon, lmax=None, mmax=None):
...@@ -202,6 +299,7 @@ class RealFFT2(nn.Module): ...@@ -202,6 +299,7 @@ class RealFFT2(nn.Module):
self.mmax = mmax or self.nlon // 2 + 1 self.mmax = mmax or self.nlon // 2 + 1
def forward(self, x): def forward(self, x):
y = torch.fft.rfft2(x, dim=(-2, -1), norm="ortho") y = torch.fft.rfft2(x, dim=(-2, -1), norm="ortho")
y = torch.cat((y[..., : math.ceil(self.lmax / 2), : self.mmax], y[..., -math.floor(self.lmax / 2) :, : self.mmax]), dim=-2) y = torch.cat((y[..., : math.ceil(self.lmax / 2), : self.mmax], y[..., -math.floor(self.lmax / 2) :, : self.mmax]), dim=-2)
return y return y
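# Shape check following the forward method above: the two slices along dim -2
# contribute ceil(lmax / 2) + floor(lmax / 2) = lmax rows, truncated to mmax columns:
fft2 = RealFFT2(nlat=64, nlon=128, lmax=32, mmax=33)
coeffs = fft2(torch.randn(2, 3, 64, 128))  # complex tensor of shape (2, 3, 32, 33)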
...@@ -209,7 +307,21 @@ class RealFFT2(nn.Module): ...@@ -209,7 +307,21 @@ class RealFFT2(nn.Module):
class InverseRealFFT2(nn.Module): class InverseRealFFT2(nn.Module):
""" """
Helper routine to wrap FFT similarly to the SHT Helper routine to wrap inverse FFT similarly to the SHT.
This module provides a wrapper around PyTorch's inverse real FFT2D that mimics
the interface of inverse spherical harmonic transforms for consistency.
Parameters
----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
lmax : int, optional
Maximum spherical harmonic degree, by default None (same as nlat)
mmax : int, optional
Maximum spherical harmonic order, by default None (nlon//2 + 1)
""" """
def __init__(self, nlat, nlon, lmax=None, mmax=None): def __init__(self, nlat, nlon, lmax=None, mmax=None):
...@@ -221,12 +333,32 @@ class InverseRealFFT2(nn.Module): ...@@ -221,12 +333,32 @@ class InverseRealFFT2(nn.Module):
self.mmax = mmax or self.nlon // 2 + 1 self.mmax = mmax or self.nlon // 2 + 1
def forward(self, x): def forward(self, x):
return torch.fft.irfft2(x, dim=(-2, -1), s=(self.nlat, self.nlon), norm="ortho") return torch.fft.irfft2(x, dim=(-2, -1), s=(self.nlat, self.nlon), norm="ortho")
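# Round-trip sketch: with all modes retained (lmax = nlat, mmax = nlon // 2 + 1),
# the concatenation in RealFFT2.forward reproduces the raw rfft2 layout, so the
# inverse recovers the input up to floating-point error:
fwd = RealFFT2(nlat=64, nlon=128)
inv = InverseRealFFT2(nlat=64, nlon=128)
x = torch.randn(2, 3, 64, 128)
assert torch.allclose(inv(fwd(x)), x, atol=1e-5)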
class LayerNorm(nn.Module): class LayerNorm(nn.Module):
""" """
Wrapper class that moves the channel dimension to the end Wrapper class that moves the channel dimension to the end.
This module provides a layer normalization that works with channel-first
tensors by temporarily transposing the channel dimension to the end,
applying normalization, and then transposing back.
Parameters
----------
in_channels : int
Number of input channels
eps : float, optional
Epsilon for numerical stability, by default 1e-05
elementwise_affine : bool, optional
Whether to use learnable affine parameters, by default True
bias : bool, optional
Whether to use bias, by default True
device : torch.device, optional
Device to place the module on, by default None
dtype : torch.dtype, optional
Data type for the module, by default None
""" """
def __init__(self, in_channels, eps=1e-05, elementwise_affine=True, bias=True, device=None, dtype=None): def __init__(self, in_channels, eps=1e-05, elementwise_affine=True, bias=True, device=None, dtype=None):
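# Equivalence sketch, assuming the wrapper matches nn.LayerNorm with the channel
# axis moved last (the forward implementation is elided in this hunk):
#   norm = LayerNorm(in_channels=64)
#   x = torch.randn(2, 64, 30, 60)
#   ref = torch.nn.LayerNorm(64)(x.transpose(1, -1)).transpose(1, -1)
#   torch.allclose(norm(x), ref, atol=1e-5)  # expected True with default affine init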
...@@ -246,6 +378,25 @@ class SpectralConvS2(nn.Module): ...@@ -246,6 +378,25 @@ class SpectralConvS2(nn.Module):
Spectral Convolution according to Driscoll & Healy. Designed for convolutions on the two-sphere S2 Spectral Convolution according to Driscoll & Healy. Designed for convolutions on the two-sphere S2
using the Spherical Harmonic Transforms in torch-harmonics, but supports convolutions on the periodic using the Spherical Harmonic Transforms in torch-harmonics, but supports convolutions on the periodic
domain via the RealFFT2 and InverseRealFFT2 wrappers. domain via the RealFFT2 and InverseRealFFT2 wrappers.
Parameters
----------
forward_transform : nn.Module
Forward transform (SHT or FFT)
inverse_transform : nn.Module
Inverse transform (ISHT or IFFT)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
gain : float, optional
Gain factor for weight initialization, by default 2.0
operator_type : str, optional
Type of spectral operator ("driscoll-healy", "diagonal", "block-diagonal"), by default "driscoll-healy"
lr_scale_exponent : int, optional
Learning rate scaling exponent, by default 0
bias : bool, optional
Whether to use bias, by default False
""" """
def __init__(self, forward_transform, inverse_transform, in_channels, out_channels, gain=2.0, operator_type="driscoll-healy", lr_scale_exponent=0, bias=False): def __init__(self, forward_transform, inverse_transform, in_channels, out_channels, gain=2.0, operator_type="driscoll-healy", lr_scale_exponent=0, bias=False):
...@@ -307,13 +458,25 @@ class SpectralConvS2(nn.Module): ...@@ -307,13 +458,25 @@ class SpectralConvS2(nn.Module):
return x, residual return x, residual
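# A minimal usage sketch; assumes the torch-harmonics SHT pair exposes the mode
# counts this module reads (the __init__ body is elided in this hunk):
import torch_harmonics as th
sht = th.RealSHT(64, 128, grid="equiangular")
isht = th.InverseRealSHT(64, 128, grid="equiangular")
conv = SpectralConvS2(sht, isht, in_channels=16, out_channels=16, operator_type="diagonal")
y, res = conv(torch.randn(2, 16, 64, 128))  # forward returns output and residual, per above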
class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta): class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta):
""" """
Returns standard sequence based position embedding Abstract base class for position embeddings.
This class defines the interface for position embedding modules
that add positional information to input tensors.
Parameters
----------
img_shape : tuple, optional
Image shape (height, width), by default (480, 960)
grid : str, optional
Grid type, by default "equiangular"
num_chans : int, optional
Number of channels, by default 1
""" """
def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1): def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
super().__init__() super().__init__()
self.img_shape = img_shape self.img_shape = img_shape
...@@ -323,17 +486,28 @@ class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta): ...@@ -323,17 +486,28 @@ class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta):
return x + self.position_embeddings return x + self.position_embeddings
class SequencePositionEmbedding(PositionEmbedding): class SequencePositionEmbedding(PositionEmbedding):
""" """
Returns standard sequence based position embedding Standard sequence-based position embedding.
This module implements sinusoidal position embeddings similar to those
used in the original Transformer paper, adapted for 2D spatial data.
Parameters
----------
img_shape : tuple, optional
Image shape (height, width), by default (480, 960)
grid : str, optional
Grid type, by default "equiangular"
num_chans : int, optional
Number of channels, by default 1
""" """
def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1): def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans) super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)
with torch.no_grad(): with torch.no_grad():
# alternating custom position embeddings # alternating custom position embeddings
pos = torch.arange(self.img_shape[0] * self.img_shape[1]).reshape(1, 1, *self.img_shape).repeat(1, self.num_chans, 1, 1) pos = torch.arange(self.img_shape[0] * self.img_shape[1]).reshape(1, 1, *self.img_shape).repeat(1, self.num_chans, 1, 1)
k = torch.arange(self.num_chans).reshape(1, self.num_chans, 1, 1) k = torch.arange(self.num_chans).reshape(1, self.num_chans, 1, 1)
...@@ -344,13 +518,26 @@ class SequencePositionEmbedding(PositionEmbedding): ...@@ -344,13 +518,26 @@ class SequencePositionEmbedding(PositionEmbedding):
# register tensor # register tensor
self.register_buffer("position_embeddings", pos_embed.float()) self.register_buffer("position_embeddings", pos_embed.float())
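# The embedding computation between these lines is elided; given the "alternating"
# comment and the flattened position index pos with channel index k, a natural
# reading (an assumption, not shown in this hunk) is the standard sinusoidal form:
#   pos_embed[:, k] = sin(pos / T ** (k / num_chans))        for even k
#   pos_embed[:, k] = cos(pos / T ** ((k - 1) / num_chans))  for odd k,  T = 10000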
class SpectralPositionEmbedding(PositionEmbedding): class SpectralPositionEmbedding(PositionEmbedding):
""" """
Returns position embeddings for the spherical transformer Spectral position embeddings for spherical transformers.
This module creates position embeddings in the spectral domain using
spherical harmonic functions, which are particularly suitable for
spherical data processing.
Parameters
----------
img_shape : tuple, optional
Image shape (height, width), by default (480, 960)
grid : str, optional
Grid type, by default "equiangular"
num_chans : int, optional
Number of channels, by default 1
""" """
def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1): def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans) super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)
# compute maximum required frequency and prepare isht # compute maximum required frequency and prepare isht
...@@ -382,11 +569,24 @@ class SpectralPositionEmbedding(PositionEmbedding): ...@@ -382,11 +569,24 @@ class SpectralPositionEmbedding(PositionEmbedding):
class LearnablePositionEmbedding(PositionEmbedding): class LearnablePositionEmbedding(PositionEmbedding):
""" """
Returns position embeddings for the spherical transformer Learnable position embeddings for spherical transformers.
This module provides learnable position embeddings that can be either
latitude-only or full latitude-longitude embeddings.
Parameters
----------
img_shape : tuple, optional
Image shape (height, width), by default (480, 960)
grid : str, optional
Grid type, by default "equiangular"
num_chans : int, optional
Number of channels, by default 1
embed_type : str, optional
Embedding type ("lat" or "latlon"), by default "lat"
""" """
def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1, embed_type="lat"): def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1, embed_type="lat"):
super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans) super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)
if embed_type == "latlon": if embed_type == "latlon":
......
...@@ -50,6 +50,35 @@ def _compute_cutoff_radius(nlat, kernel_shape, basis_type): ...@@ -50,6 +50,35 @@ def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
return (kernel_shape[0] + 1) * theta_cutoff_factor[basis_type] * math.pi / float(nlat - 1) return (kernel_shape[0] + 1) * theta_cutoff_factor[basis_type] * math.pi / float(nlat - 1)
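# Worked example with a hypothetical factor value (the theta_cutoff_factor table is
# defined outside this hunk): for nlat = 481, kernel_shape = (3, 3) and a factor of
# 1.0, theta_cutoff = (3 + 1) * 1.0 * pi / 480 = pi / 120 ~ 0.026 rad (~1.5 degrees).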
class DiscreteContinuousEncoder(nn.Module): class DiscreteContinuousEncoder(nn.Module):
"""
Discrete-continuous encoder for spherical neural operators.
This module performs downsampling using discrete-continuous convolutions on the sphere,
reducing the spatial resolution while maintaining the spectral properties of the data.
Parameters
----------
in_shape : tuple, optional
Input shape (nlat, nlon), by default (721, 1440)
out_shape : tuple, optional
Output shape (nlat, nlon), by default (480, 960)
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
inp_chans : int, optional
Number of input channels, by default 2
out_chans : int, optional
Number of output channels, by default 2
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
groups : int, optional
Number of groups for grouped convolution, by default 1
bias : bool, optional
Whether to use bias, by default False
"""
def __init__( def __init__(
self, self,
in_shape=(721, 1440), in_shape=(721, 1440),
...@@ -81,6 +110,7 @@ class DiscreteContinuousEncoder(nn.Module): ...@@ -81,6 +110,7 @@ class DiscreteContinuousEncoder(nn.Module):
) )
def forward(self, x): def forward(self, x):
dtype = x.dtype dtype = x.dtype
with amp.autocast(device_type="cuda", enabled=False): with amp.autocast(device_type="cuda", enabled=False):
...@@ -92,6 +122,37 @@ class DiscreteContinuousEncoder(nn.Module): ...@@ -92,6 +122,37 @@ class DiscreteContinuousEncoder(nn.Module):
class DiscreteContinuousDecoder(nn.Module): class DiscreteContinuousDecoder(nn.Module):
"""
Discrete-continuous decoder for spherical neural operators.
This module performs upsampling using either spherical harmonic transforms or resampling,
followed by discrete-continuous convolutions to restore spatial resolution.
Parameters
----------
in_shape : tuple, optional
Input shape (nlat, nlon), by default (480, 960)
out_shape : tuple, optional
Output shape (nlat, nlon), by default (721, 1440)
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
inp_chans : int, optional
Number of input channels, by default 2
out_chans : int, optional
Number of output channels, by default 2
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
groups : int, optional
Number of groups for grouped convolution, by default 1
bias : bool, optional
Whether to use bias, by default False
upsample_sht : bool, optional
Whether to use SHT for upsampling, by default False
"""
def __init__( def __init__(
self, self,
in_shape=(480, 960), in_shape=(480, 960),
...@@ -132,6 +193,7 @@ class DiscreteContinuousDecoder(nn.Module): ...@@ -132,6 +193,7 @@ class DiscreteContinuousDecoder(nn.Module):
) )
def forward(self, x): def forward(self, x):
dtype = x.dtype dtype = x.dtype
with amp.autocast(device_type="cuda", enabled=False): with amp.autocast(device_type="cuda", enabled=False):
...@@ -146,6 +208,46 @@ class DiscreteContinuousDecoder(nn.Module): ...@@ -146,6 +208,46 @@ class DiscreteContinuousDecoder(nn.Module):
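# A minimal usage sketch for the encoder/decoder pair defined in this file, using
# the documented defaults for grids, kernel shape, and basis type:
enc = DiscreteContinuousEncoder(in_shape=(721, 1440), out_shape=(480, 960), inp_chans=2, out_chans=32)
dec = DiscreteContinuousDecoder(in_shape=(480, 960), out_shape=(721, 1440), inp_chans=32, out_chans=2)
y = dec(enc(torch.randn(1, 2, 721, 1440)))  # restored to (1, 2, 721, 1440)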
class SphericalNeuralOperatorBlock(nn.Module): class SphericalNeuralOperatorBlock(nn.Module):
""" """
Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks. Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.
Parameters
----------
forward_transform : torch.nn.Module
Forward transform to use for the block
inverse_transform : torch.nn.Module
Inverse transform to use for the block
input_dim : int
Input dimension
output_dim : int
Output dimension
conv_type : str, optional
Type of convolution to use, by default "local"
mlp_ratio : float, optional
MLP expansion ratio, by default 2.0
drop_rate : float, optional
Dropout rate, by default 0.0
drop_path : float, optional
Drop path rate, by default 0.0
act_layer : torch.nn.Module, optional
Activation function to use, by default nn.GELU
norm_layer : str, optional
Type of normalization to use, by default "none"
inner_skip : str, optional
Type of inner skip connection to use, by default "none"
outer_skip : str, optional
Type of outer skip connection to use, by default "identity"
use_mlp : bool, optional
Whether to use MLP layers, by default True
disco_kernel_shape : tuple, optional
Kernel shape for discrete-continuous convolution, by default (3, 3)
disco_basis_type : str, optional
Filter basis type for discrete-continuous convolution, by default "morlet"
bias : bool, optional
Whether to use bias, by default False
Returns
-------
torch.Tensor
Output tensor
""" """
def __init__( def __init__(
...@@ -281,51 +383,56 @@ class LocalSphericalNeuralOperator(nn.Module): ...@@ -281,51 +383,56 @@ class LocalSphericalNeuralOperator(nn.Module):
as well as in the encoder and decoders. as well as in the encoder and decoders.
Parameters Parameters
----------- ----------
img_shape : tuple, optional img_size : tuple, optional
Shape of the input channels, by default (128, 256) Input image size (nlat, nlon), by default (128, 256)
kernel_shape: tuple, int grid : str, optional
Grid type for input/output, by default "equiangular"
grid_internal : str, optional
Grid type for internal processing, by default "legendre-gauss"
scale_factor : int, optional scale_factor : int, optional
Scale factor to use, by default 3 Scale factor for resolution changes, by default 3
in_chans : int, optional in_chans : int, optional
Number of input channels, by default 3 Number of input channels, by default 3
out_chans : int, optional out_chans : int, optional
Number of output channels, by default 3 Number of output channels, by default 3
embed_dim : int, optional embed_dim : int, optional
Dimension of the embeddings, by default 256 Embedding dimension, by default 256
num_layers : int, optional num_layers : int, optional
Number of layers in the network, by default 4 Number of layers, by default 4
activation_function : str, optional activation_function : str, optional
Activation function to use, by default "gelu" Activation function name, by default "gelu"
encoder_kernel_shape : int, optional kernel_shape : tuple, optional
size of the encoder kernel Kernel shape for convolutions, by default (3, 3)
filter_basis_type: Optional[str]: str, optional encoder_kernel_shape : tuple, optional
filter basis type Kernel shape for encoder, by default (3, 3)
use_mlp : int, optional filter_basis_type : str, optional
Whether to use MLPs in the SFNO blocks, by default True Filter basis type, by default "morlet"
mlp_ratio : int, optional use_mlp : bool, optional
Ratio of MLP to use, by default 2.0 Whether to use MLP layers, by default True
mlp_ratio : float, optional
MLP expansion ratio, by default 2.0
drop_rate : float, optional drop_rate : float, optional
Dropout rate, by default 0.0 Dropout rate, by default 0.0
drop_path_rate : float, optional drop_path_rate : float, optional
Dropout path rate, by default 0.0 Drop path rate, by default 0.0
normalization_layer : str, optional normalization_layer : str, optional
Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm" Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"
sfno_block_frequency : int, optional sfno_block_frequency : int, optional
Hopw often a (global) SFNO block is used, by default 2 Frequency of SFNO blocks, by default 2
hard_thresholding_fraction : float, optional hard_thresholding_fraction : float, optional
Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0 Hard thresholding fraction, by default 1.0
big_skip : bool, optional residual_prediction : bool, optional
Whether to add a single large skip connection, by default True Whether to use residual prediction, by default False
pos_embed : bool, optional pos_embed : str, optional
Whether to use positional embedding, by default True Position embedding type, by default "none"
upsample_sht : bool, optional upsample_sht : bool, optional
Use SHT upsampling if true, else linear interpolation Use SHT upsampling if true, else linear interpolation
bias : bool, optional bias : bool, optional
Whether to use a bias, by default False Whether to use a bias, by default False
Example Example
----------- ----------
>>> model = LocalSphericalNeuralOperator( >>> model = LocalSphericalNeuralOperator(
... img_shape=(128, 256), ... img_size=(128, 256),
... scale_factor=4, ... scale_factor=4,
...@@ -338,7 +445,7 @@ class LocalSphericalNeuralOperator(nn.Module): ...@@ -338,7 +445,7 @@ class LocalSphericalNeuralOperator(nn.Module):
torch.Size([1, 2, 128, 256]) torch.Size([1, 2, 128, 256])
References References
----------- ----------
.. [1] Liu-Schiaffini M., Berner J., Bonev B., Kurth T., Azizzadenesheli K., Anandkumar A.; .. [1] Liu-Schiaffini M., Berner J., Bonev B., Kurth T., Azizzadenesheli K., Anandkumar A.;
"Neural Operators with Localized Integral and Differential Kernels" (2024). "Neural Operators with Localized Integral and Differential Kernels" (2024).
ICML 2024, https://arxiv.org/pdf/2402.16845. ICML 2024, https://arxiv.org/pdf/2402.16845.
...@@ -497,6 +604,7 @@ class LocalSphericalNeuralOperator(nn.Module): ...@@ -497,6 +604,7 @@ class LocalSphericalNeuralOperator(nn.Module):
return x return x
def forward(self, x): def forward(self, x):
if self.residual_prediction: if self.residual_prediction:
residual = x residual = x
......
...@@ -54,6 +54,34 @@ def _compute_cutoff_radius(nlat, kernel_shape, basis_type): ...@@ -54,6 +54,34 @@ def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
class OverlapPatchMerging(nn.Module): class OverlapPatchMerging(nn.Module):
"""
Overlap patch merging module for spherical segformer.
This module performs patch merging with overlapping patches using discrete-continuous
convolutions on the sphere, followed by layer normalization.
Parameters
----------
in_shape : tuple, optional
Input shape (nlat, nlon), by default (721, 1440)
out_shape : tuple, optional
Output shape (nlat, nlon), by default (481, 960)
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
in_channels : int, optional
Number of input channels, by default 3
out_channels : int, optional
Number of output channels, by default 64
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
bias : bool, optional
Whether to use bias, by default False
"""
def __init__( def __init__(
self, self,
in_shape=(721, 1440), in_shape=(721, 1440),
...@@ -89,11 +117,13 @@ class OverlapPatchMerging(nn.Module): ...@@ -89,11 +117,13 @@ class OverlapPatchMerging(nn.Module):
self.apply(self._init_weights) self.apply(self._init_weights)
def _init_weights(self, m): def _init_weights(self, m):
if isinstance(m, nn.LayerNorm): if isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0) nn.init.constant_(m.weight, 1.0)
def forward(self, x): def forward(self, x):
dtype = x.dtype dtype = x.dtype
with amp.autocast(device_type="cuda", enabled=False): with amp.autocast(device_type="cuda", enabled=False):
...@@ -109,6 +139,38 @@ class OverlapPatchMerging(nn.Module): ...@@ -109,6 +139,38 @@ class OverlapPatchMerging(nn.Module):
class MixFFN(nn.Module): class MixFFN(nn.Module):
"""
Mix FFN module for spherical segformer.
This module implements a feed-forward network that combines MLP operations
with discrete-continuous convolutions on the sphere.
Parameters
----------
shape : tuple
Shape (nlat, nlon) of the input
inout_channels : int
Number of input/output channels
hidden_channels : int
Number of hidden channels in MLP
mlp_bias : bool, optional
Whether to use bias in MLP, by default True
grid : str, optional
Grid type, by default "equiangular"
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
conv_bias : bool, optional
Whether to use bias in convolution, by default False
activation : nn.Module, optional
Activation function, by default nn.GELU
use_mlp : bool, optional
Whether to use MLP instead of linear layers, by default False
drop_path : float, optional
Drop path rate, by default 0.0
"""
def __init__( def __init__(
self, self,
shape, shape,
...@@ -161,6 +223,7 @@ class MixFFN(nn.Module): ...@@ -161,6 +223,7 @@ class MixFFN(nn.Module):
self.apply(self._init_weights) self.apply(self._init_weights)
def _init_weights(self, m): def _init_weights(self, m):
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02) nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None: if m.bias is not None:
...@@ -170,7 +233,6 @@ class MixFFN(nn.Module): ...@@ -170,7 +233,6 @@ class MixFFN(nn.Module):
nn.init.constant_(m.weight, 1.0) nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x residual = x
# norm # norm
...@@ -194,6 +256,35 @@ class MixFFN(nn.Module): ...@@ -194,6 +256,35 @@ class MixFFN(nn.Module):
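# A minimal usage sketch for MixFFN; given the residual connection above, the block
# is presumably shape-preserving (channel counts here are illustrative assumptions):
ffn = MixFFN(shape=(120, 240), inout_channels=96, hidden_channels=192)
y = ffn(torch.randn(1, 96, 120, 240))  # same shape out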
class AttentionWrapper(nn.Module): class AttentionWrapper(nn.Module):
"""
Attention wrapper for spherical segformer.
This module wraps attention mechanisms (neighborhood or global) with optional
normalization and drop path regularization.
Parameters
----------
channels : int
Number of channels
shape : tuple
Shape (nlat, nlon) of the input
grid : str
Grid type
heads : int
Number of attention heads
pre_norm : bool, optional
Whether to apply normalization before attention, by default False
attention_drop_rate : float, optional
Dropout rate for attention, by default 0.0
drop_path : float, optional
Drop path rate, by default 0.0
attention_mode : str, optional
Attention mode ("neighborhood" or "global"), by default "neighborhood"
theta_cutoff : float, optional
Cutoff radius for neighborhood attention, by default None
bias : bool, optional
Whether to use bias, by default True
"""
def __init__( def __init__(
self, self,
channels, channels,
...@@ -271,6 +362,49 @@ class AttentionWrapper(nn.Module): ...@@ -271,6 +362,49 @@ class AttentionWrapper(nn.Module):
class TransformerBlock(nn.Module): class TransformerBlock(nn.Module):
"""
Transformer block for spherical segformer.
This block combines patch merging, attention, and Mix FFN operations
in a hierarchical structure for processing spherical data.
Parameters
----------
in_shape : tuple
Input shape (nlat, nlon)
out_shape : tuple
Output shape (nlat, nlon)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
mlp_hidden_channels : int
Number of hidden channels in MLP
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
nrep : int, optional
Number of repetitions, by default 1
heads : int, optional
Number of attention heads, by default 1
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
activation : nn.Module, optional
Activation function, by default nn.GELU
att_drop_rate : float, optional
Dropout rate for attention, by default 0.0
drop_path_rates : float, optional
Drop path rates, by default 0.0
attention_mode : str, optional
Attention mode ("neighborhood" or "global"), by default "neighborhood"
theta_cutoff : float, optional
Cutoff radius for neighborhood attention, by default None
bias : bool, optional
Whether to use bias, by default True
"""
def __init__( def __init__(
self, self,
in_shape, in_shape,
...@@ -374,6 +508,43 @@ class TransformerBlock(nn.Module): ...@@ -374,6 +508,43 @@ class TransformerBlock(nn.Module):
class Upsampling(nn.Module): class Upsampling(nn.Module):
"""
Upsampling module for spherical segformer.
This module performs upsampling using either discrete-continuous transposed convolutions
or bilinear resampling on spherical data.
Parameters
----------
in_shape : tuple
Input shape (nlat, nlon)
out_shape : tuple
Output shape (nlat, nlon)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
hidden_channels : int
Number of hidden channels in MLP
mlp_bias : bool, optional
Whether to use bias in MLP, by default True
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
conv_bias : bool, optional
Whether to use bias in convolution, by default False
activation : nn.Module, optional
Activation function, by default nn.GELU
use_mlp : bool, optional
Whether to use MLP instead of linear layers, by default False
upsampling_method : str, optional
Upsampling method ("conv" or "bilinear"), by default "conv"
"""
def __init__( def __init__(
self, self,
in_shape, in_shape,
...@@ -429,6 +600,7 @@ class Upsampling(nn.Module): ...@@ -429,6 +600,7 @@ class Upsampling(nn.Module):
nn.init.constant_(m.weight, 1.0) nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.upsample(self.mlp(x)) x = self.upsample(self.mlp(x))
return x return x
...@@ -606,6 +778,7 @@ class SphericalSegformer(nn.Module): ...@@ -606,6 +778,7 @@ class SphericalSegformer(nn.Module):
self.apply(self._init_weights) self.apply(self._init_weights)
def _init_weights(self, m): def _init_weights(self, m):
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02) nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None: if m.bias is not None:
......
...@@ -52,6 +52,36 @@ def _compute_cutoff_radius(nlat, kernel_shape, basis_type): ...@@ -52,6 +52,36 @@ def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
return (kernel_shape[0] + 1) * theta_cutoff_factor[basis_type] * math.pi / float(nlat - 1) return (kernel_shape[0] + 1) * theta_cutoff_factor[basis_type] * math.pi / float(nlat - 1)
class DiscreteContinuousEncoder(nn.Module): class DiscreteContinuousEncoder(nn.Module):
"""
Discrete-continuous encoder for spherical transformers.
This module performs downsampling using discrete-continuous convolutions on the sphere,
reducing the spatial resolution while maintaining the spectral properties of the data.
Parameters
----------
in_shape : tuple, optional
Input shape (nlat, nlon), by default (721, 1440)
out_shape : tuple, optional
Output shape (nlat, nlon), by default (480, 960)
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
in_chans : int, optional
Number of input channels, by default 2
out_chans : int, optional
Number of output channels, by default 2
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
groups : int, optional
Number of groups for grouped convolution, by default 1
bias : bool, optional
Whether to use bias, by default False
"""
def __init__( def __init__(
self, self,
in_shape=(721, 1440), in_shape=(721, 1440),
...@@ -83,6 +113,7 @@ class DiscreteContinuousEncoder(nn.Module): ...@@ -83,6 +113,7 @@ class DiscreteContinuousEncoder(nn.Module):
) )
def forward(self, x): def forward(self, x):
dtype = x.dtype dtype = x.dtype
with amp.autocast(device_type="cuda", enabled=False): with amp.autocast(device_type="cuda", enabled=False):
...@@ -94,6 +125,38 @@ class DiscreteContinuousEncoder(nn.Module): ...@@ -94,6 +125,38 @@ class DiscreteContinuousEncoder(nn.Module):
class DiscreteContinuousDecoder(nn.Module): class DiscreteContinuousDecoder(nn.Module):
"""
Discrete-continuous decoder for spherical transformers.
This module performs upsampling using either spherical harmonic transforms or resampling,
followed by discrete-continuous convolutions to restore spatial resolution.
Parameters
----------
in_shape : tuple, optional
Input shape (nlat, nlon), by default (480, 960)
out_shape : tuple, optional
Output shape (nlat, nlon), by default (721, 1440)
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
in_chans : int, optional
Number of input channels, by default 2
out_chans : int, optional
Number of output channels, by default 2
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
groups : int, optional
Number of groups for grouped convolution, by default 1
bias : bool, optional
Whether to use bias, by default False
upsample_sht : bool, optional
Whether to use SHT for upsampling, by default False
"""
def __init__( def __init__(
self, self,
in_shape=(480, 960), in_shape=(480, 960),
...@@ -134,6 +197,7 @@ class DiscreteContinuousDecoder(nn.Module): ...@@ -134,6 +197,7 @@ class DiscreteContinuousDecoder(nn.Module):
) )
def forward(self, x): def forward(self, x):
dtype = x.dtype dtype = x.dtype
with amp.autocast(device_type="cuda", enabled=False): with amp.autocast(device_type="cuda", enabled=False):
...@@ -147,7 +211,45 @@ class DiscreteContinuousDecoder(nn.Module): ...@@ -147,7 +211,45 @@ class DiscreteContinuousDecoder(nn.Module):
class SphericalAttentionBlock(nn.Module): class SphericalAttentionBlock(nn.Module):
""" """
Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks. Spherical attention block for transformers on the sphere.
This module implements a single attention block that can use either global attention
or neighborhood attention on spherical data, followed by an optional MLP.
Parameters
----------
in_shape : tuple, optional
Input shape (nlat, nlon), by default (480, 960)
out_shape : tuple, optional
Output shape (nlat, nlon), by default (480, 960)
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
in_chans : int, optional
Number of input channels, by default 2
out_chans : int, optional
Number of output channels, by default 2
num_heads : int, optional
Number of attention heads, by default 1
mlp_ratio : float, optional
Ratio of MLP hidden dimension to output dimension, by default 2.0
drop_rate : float, optional
Dropout rate, by default 0.0
drop_path : float, optional
Drop path rate, by default 0.0
act_layer : nn.Module, optional
Activation layer, by default nn.GELU
norm_layer : str, optional
Normalization layer type, by default "none"
use_mlp : bool, optional
Whether to use MLP after attention, by default True
bias : bool, optional
Whether to use bias, by default False
attention_mode : str, optional
Attention mode ("neighborhood" or "global"), by default "neighborhood"
theta_cutoff : float, optional
Cutoff radius for neighborhood attention, by default None
""" """
def __init__( def __init__(
...@@ -467,6 +569,7 @@ class SphericalTransformer(nn.Module): ...@@ -467,6 +569,7 @@ class SphericalTransformer(nn.Module):
return x return x
def forward(self, x): def forward(self, x):
if self.residual_prediction: if self.residual_prediction:
residual = x residual = x
......
...@@ -54,6 +54,46 @@ def _compute_cutoff_radius(nlat, kernel_shape, basis_type): ...@@ -54,6 +54,46 @@ def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
class DownsamplingBlock(nn.Module): class DownsamplingBlock(nn.Module):
"""
Downsampling block for spherical U-Net architecture.
This block performs convolution operations followed by downsampling on spherical data,
using discrete-continuous convolutions to maintain spectral properties.
Parameters
----------
in_shape : tuple
Input shape (nlat, nlon)
out_shape : tuple
Output shape (nlat, nlon)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
nrep : int, optional
Number of convolution repetitions, by default 1
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
activation : nn.Module, optional
Activation function, by default nn.ReLU
transform_skip : bool, optional
Whether to transform skip connection, by default False
drop_conv_rate : float, optional
Dropout rate for convolutions, by default 0.0
drop_path_rate : float, optional
Drop path rate, by default 0.0
drop_dense_rate : float, optional
Dropout rate for dense layers, by default 0.0
downsampling_mode : str, optional
Downsampling mode ("bilinear", "conv"), by default "bilinear"
"""
def __init__( def __init__(
self, self,
in_shape, in_shape,
...@@ -154,12 +194,14 @@ class DownsamplingBlock(nn.Module): ...@@ -154,12 +194,14 @@ class DownsamplingBlock(nn.Module):
self.apply(self._init_weights) self.apply(self._init_weights)
def _init_weights(self, m): def _init_weights(self, m):
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02) nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None: if m.bias is not None:
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
# skip connection # skip connection
residual = x residual = x
if hasattr(self, "transform_skip"): if hasattr(self, "transform_skip"):
...@@ -178,6 +220,46 @@ class DownsamplingBlock(nn.Module): ...@@ -178,6 +220,46 @@ class DownsamplingBlock(nn.Module):
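# A minimal usage sketch for DownsamplingBlock; shapes and channel counts are
# illustrative assumptions, not values taken from this diff:
down = DownsamplingBlock(in_shape=(128, 256), out_shape=(64, 128), in_channels=32, out_channels=64)
y = down(torch.randn(1, 32, 128, 256))  # expected (1, 64, 64, 128)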
class UpsamplingBlock(nn.Module): class UpsamplingBlock(nn.Module):
"""
Upsampling block for spherical U-Net architecture.
This block performs upsampling followed by convolution operations on spherical data,
using discrete-continuous convolutions to maintain spectral properties.
Parameters
----------
in_shape : tuple
Input shape (nlat, nlon)
out_shape : tuple
Output shape (nlat, nlon)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
nrep : int, optional
Number of convolution repetitions, by default 1
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
activation : nn.Module, optional
Activation function, by default nn.ReLU
transform_skip : bool, optional
Whether to transform skip connection, by default False
drop_conv_rate : float, optional
Dropout rate for convolutions, by default 0.0
drop_path_rate : float, optional
Drop path rate, by default 0.0
drop_dense_rate : float, optional
Dropout rate for dense layers, by default 0.0
upsampling_mode : str, optional
Upsampling mode ("bilinear", "conv"), by default "bilinear"
"""
def __init__( def __init__(
self, self,
in_shape, in_shape,
...@@ -496,6 +578,7 @@ class SphericalUNet(nn.Module): ...@@ -496,6 +578,7 @@ class SphericalUNet(nn.Module):
self.apply(self._init_weights) self.apply(self._init_weights)
def _init_weights(self, m): def _init_weights(self, m):
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02) nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None: if m.bias is not None:
......
...@@ -43,6 +43,40 @@ from functools import partial ...@@ -43,6 +43,40 @@ from functools import partial
class SphericalFourierNeuralOperatorBlock(nn.Module): class SphericalFourierNeuralOperatorBlock(nn.Module):
""" """
Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks. Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.
Parameters
----------
forward_transform : torch.nn.Module
Forward transform to use for the block
inverse_transform : torch.nn.Module
Inverse transform to use for the block
input_dim : int
Input dimension
output_dim : int
Output dimension
mlp_ratio : float, optional
MLP expansion ratio, by default 2.0
drop_rate : float, optional
Dropout rate, by default 0.0
drop_path : float, optional
Drop path rate, by default 0.0
act_layer : torch.nn.Module, optional
Activation function to use, by default nn.GELU
norm_layer : str, optional
Type of normalization to use, by default "none"
inner_skip : str, optional
Type of inner skip connection to use, by default "none"
outer_skip : str, optional
Type of outer skip connection to use, by default "identity"
use_mlp : bool, optional
Whether to use MLP layers, by default True
bias : bool, optional
Whether to use bias, by default False
Returns
-------
torch.Tensor
Output tensor
""" """
def __init__( def __init__(
...@@ -118,7 +152,6 @@ class SphericalFourierNeuralOperatorBlock(nn.Module): ...@@ -118,7 +152,6 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
else: else:
raise ValueError(f"Unknown skip connection type {outer_skip}") raise ValueError(f"Unknown skip connection type {outer_skip}")
def forward(self, x): def forward(self, x):
x, residual = self.global_conv(x) x, residual = self.global_conv(x)
...@@ -147,8 +180,12 @@ class SphericalFourierNeuralOperator(nn.Module): ...@@ -147,8 +180,12 @@ class SphericalFourierNeuralOperator(nn.Module):
Parameters Parameters
---------- ----------
img_shape : tuple, optional img_size : tuple, optional
Shape of the input channels, by default (128, 256) Shape of the input channels, by default (128, 256)
grid : str, optional
Input grid type, by default "equiangular"
grid_internal : str, optional
Internal grid type for computations, by default "legendre-gauss"
scale_factor : int, optional scale_factor : int, optional
Scale factor to use, by default 3 Scale factor to use, by default 3
in_chans : int, optional in_chans : int, optional
...@@ -172,20 +209,20 @@ class SphericalFourierNeuralOperator(nn.Module): ...@@ -172,20 +209,20 @@ class SphericalFourierNeuralOperator(nn.Module):
drop_path_rate : float, optional drop_path_rate : float, optional
Dropout path rate, by default 0.0 Dropout path rate, by default 0.0
normalization_layer : str, optional normalization_layer : str, optional
Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm" Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "none"
hard_thresholding_fraction : float, optional hard_thresholding_fraction : float, optional
Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0 Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0
residual_prediction : bool, optional residual_prediction : bool, optional
Whether to add a single large skip connection, by default True Whether to add a single large skip connection, by default False
pos_embed : bool, optional pos_embed : str, optional
Whether to use positional embedding, by default True Type of positional embedding to use, by default "none"
bias : bool, optional bias : bool, optional
Whether to use a bias, by default False Whether to use a bias, by default False
Example: Example
-------- ----------
>>> model = SphericalFourierNeuralOperator( >>> model = SphericalFourierNeuralOperator(
... img_shape=(128, 256), ... img_size=(128, 256),
... scale_factor=4, ... scale_factor=4,
... in_chans=2, ... in_chans=2,
... out_chans=2, ... out_chans=2,
...@@ -196,7 +233,7 @@ class SphericalFourierNeuralOperator(nn.Module): ...@@ -196,7 +233,7 @@ class SphericalFourierNeuralOperator(nn.Module):
torch.Size([1, 2, 128, 256]) torch.Size([1, 2, 128, 256])
References References
----------- ----------
.. [1] Bonev B., Kurth T., Hundt C., Pathak, J., Baust M., Kashinath K., Anandkumar A.; .. [1] Bonev B., Kurth T., Hundt C., Pathak, J., Baust M., Kashinath K., Anandkumar A.;
"Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere" (2023). "Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere" (2023).
ICML 2023, https://arxiv.org/abs/2306.03838. ICML 2023, https://arxiv.org/abs/2306.03838.
......
...@@ -37,7 +37,38 @@ from .shallow_water_equations import ShallowWaterSolver ...@@ -37,7 +37,38 @@ from .shallow_water_equations import ShallowWaterSolver
class PdeDataset(torch.utils.data.Dataset): class PdeDataset(torch.utils.data.Dataset):
"""Custom Dataset class for PDE training data""" """Custom Dataset class for PDE training data
Parameters
----------
dt : float
Time step
nsteps : int
Number of solver steps
dims : tuple, optional
Number of latitude and longitude points, by default (384, 768)
grid : str, optional
Grid type, by default "equiangular"
pde : str, optional
PDE type, by default "shallow water equations"
initial_condition : str, optional
Initial condition type, by default "random"
num_examples : int, optional
Number of examples, by default 32
device : torch.device, optional
Device to use, by default torch.device("cpu")
normalize : bool, optional
Whether to normalize the input and target, by default True
stream : torch.cuda.Stream, optional
CUDA stream to use, by default None
Returns
-------
inp : torch.Tensor
Input tensor
tar : torch.Tensor
Target tensor
"""
def __init__( def __init__(
self, self,
......
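# A minimal usage sketch for PdeDataset; dt and nsteps values are hypothetical:
dataset = PdeDataset(dt=480, nsteps=1, dims=(384, 768), num_examples=32, normalize=True)
inp, tar = dataset[0]  # input/target pair separated by nsteps solver steps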
...@@ -42,7 +42,27 @@ import numpy as np ...@@ -42,7 +42,27 @@ import numpy as np
class SphereSolver(nn.Module): class SphereSolver(nn.Module):
""" """
Solver class on the sphere. Can solve the following PDEs: Solver class on the sphere. Can solve the following PDEs:
- Allen-Cahn eq - Allen-Cahn equation
- Ginzburg-Landau equation
Parameters
----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
dt : float
Time step size
lmax : int, optional
Maximum l mode for spherical harmonics, by default None
mmax : int, optional
Maximum m mode for spherical harmonics, by default None
grid : str, optional
Grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
radius : float, optional
Radius of the sphere, by default 1.0
coeff : float, optional
Coefficient for the PDE, by default 0.001
""" """
def __init__(self, nlat, nlon, dt, lmax=None, mmax=None, grid="equiangular", radius=1.0, coeff=0.001): def __init__(self, nlat, nlon, dt, lmax=None, mmax=None, grid="equiangular", radius=1.0, coeff=0.001):
...@@ -97,16 +117,14 @@ class SphereSolver(nn.Module): ...@@ -97,16 +117,14 @@ class SphereSolver(nn.Module):
self.register_buffer('invlap', invlap) self.register_buffer('invlap', invlap)
def grid2spec(self, u): def grid2spec(self, u):
"""spectral coefficients from spatial data"""
return self.sht(u) return self.sht(u)
def spec2grid(self, uspec): def spec2grid(self, uspec):
"""spatial data from spectral coefficients""" """Convert spectral coefficients to spatial data."""
return self.isht(uspec) return self.isht(uspec)
def dudtspec(self, uspec, pde='allen-cahn'): def dudtspec(self, uspec, pde='allen-cahn'):
"""Compute the time derivative of spectral coefficients for different PDEs."""
if pde == 'allen-cahn': if pde == 'allen-cahn':
ugrid = self.spec2grid(uspec) ugrid = self.spec2grid(uspec)
...@@ -117,20 +135,48 @@ class SphereSolver(nn.Module): ...@@ -117,20 +135,48 @@ class SphereSolver(nn.Module):
u3spec = self.grid2spec(ugrid**3) u3spec = self.grid2spec(ugrid**3)
dudtspec = uspec + (1. + 2.j)*self.coeff*self.lap*uspec - (1. + 2.j)*u3spec dudtspec = uspec + (1. + 2.j)*self.coeff*self.lap*uspec - (1. + 2.j)*u3spec
else: else:
NotImplementedError raise NotImplementedError(f"PDE type {pde} not implemented")
return dudtspec return dudtspec
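# Read off the branch above: in spectral space this integrates the complex
# Ginzburg-Landau equation
#   du/dt = u + (1 + 2i) * coeff * Lap(u) - (1 + 2i) * u**3,
# with the cubic term formed on the grid and transformed back (u3spec).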
def randspec(self): def randspec(self):
"""random data on the sphere""" """Generate random spectral data on the sphere."""
rspec = torch.randn_like(self.lap) / 4 / torch.pi rspec = torch.randn_like(self.lap) / 4 / torch.pi
return rspec return rspec
def plot_griddata(self, data, fig, cmap='twilight_shifted', vmax=None, vmin=None, projection='3d', title=None, antialiased=False): def plot_griddata(self, data, fig, cmap='twilight_shifted', vmax=None, vmin=None, projection='3d', title=None, antialiased=False):
""" """
plotting routine for data on the grid. Requires cartopy for 3d plots. Plot data on the sphere grid. Requires cartopy for 3d plots.
Parameters
-----------
data : torch.Tensor
Data to plot
fig : matplotlib.figure.Figure
Figure to plot on
cmap : str, optional
Colormap name, by default 'twilight_shifted'
vmax : float, optional
Maximum value for color scaling, by default None
vmin : float, optional
Minimum value for color scaling, by default None
projection : str, optional
Projection type ("mollweide", "3d"), by default "3d"
title : str, optional
Plot title, by default None
antialiased : bool, optional
Whether to use antialiasing, by default False
Returns
-------
matplotlib.collections.QuadMesh
The plotted image object
Raises
------
NotImplementedError
If projection type is not supported
""" """
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
...@@ -172,9 +218,10 @@ class SphereSolver(nn.Module): ...@@ -172,9 +218,10 @@ class SphereSolver(nn.Module):
plt.title(title, y=1.05) plt.title(title, y=1.05)
else: else:
raise NotImplementedError raise NotImplementedError(f"Projection {projection} not implemented")
return im return im
def plot_specdata(self, data, fig, **kwargs): def plot_specdata(self, data, fig, **kwargs):
"""Plot spectral data by converting to spatial data first."""
return self.plot_griddata(self.isht(data), fig, **kwargs) return self.plot_griddata(self.isht(data), fig, **kwargs)
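# A minimal usage sketch for SphereSolver; dt is a hypothetical value and the
# time-stepping loop itself is not part of this diff:
solver = SphereSolver(nlat=128, nlon=256, dt=1e-3)
uspec = solver.randspec()                        # random spectral initial condition
dudt = solver.dudtspec(uspec, pde='allen-cahn')  # spectral tendency
ugrid = solver.spec2grid(uspec)                  # back to the spatial grid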
...@@ -41,7 +41,35 @@ import numpy as np ...@@ -41,7 +41,35 @@ import numpy as np
class ShallowWaterSolver(nn.Module): class ShallowWaterSolver(nn.Module):
""" """
SWE solver class. Interface inspired bu pyspharm and SHTns Shallow Water Equations (SWE) solver class for spherical geometry.
Interface inspired by pyspharm and SHTns. Solves the shallow water equations
on a rotating sphere using spectral methods.
Parameters
----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
dt : float
Time step size
lmax : int, optional
Maximum l mode for spherical harmonics, by default None
mmax : int, optional
Maximum m mode for spherical harmonics, by default None
grid : str, optional
Grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
radius : float, optional
Radius of the sphere in meters, by default 6.37122E6 (Earth radius)
omega : float, optional
Angular velocity of rotation in rad/s, by default 7.292E-5 (Earth)
gravity : float, optional
Gravitational acceleration in m/s², by default 9.80616
havg : float, optional
Average height in meters, by default 10.e3
hamp : float, optional
Height amplitude in meters, by default 120.
""" """
def __init__(self, nlat, nlon, dt, lmax=None, mmax=None, grid="equiangular", radius=6.37122E6, \ def __init__(self, nlat, nlon, dt, lmax=None, mmax=None, grid="equiangular", radius=6.37122E6, \
...@@ -114,58 +142,43 @@ class ShallowWaterSolver(nn.Module): ...@@ -114,58 +142,43 @@ class ShallowWaterSolver(nn.Module):
self.register_buffer('quad_weights', quad_weights) self.register_buffer('quad_weights', quad_weights)
def grid2spec(self, ugrid): def grid2spec(self, ugrid):
""" """Convert spatial data to spectral coefficients."""
spectral coefficients from spatial data
"""
return self.sht(ugrid) return self.sht(ugrid)
def spec2grid(self, uspec): def spec2grid(self, uspec):
""" """Convert spectral coefficients to spatial data."""
spatial data from spectral coefficients
"""
return self.isht(uspec) return self.isht(uspec)
def vrtdivspec(self, ugrid): def vrtdivspec(self, ugrid):
"""spatial data from spectral coefficients""" """Compute vorticity and divergence from velocity field."""
vrtdivspec = self.lap * self.radius * self.vsht(ugrid) vrtdivspec = self.lap * self.radius * self.vsht(ugrid)
return vrtdivspec return vrtdivspec
def getuv(self, vrtdivspec): def getuv(self, vrtdivspec):
""" """Compute wind vector from spectral coefficients of vorticity and divergence."""
compute wind vector from spectral coeffs of vorticity and divergence
"""
return self.ivsht( self.invlap * vrtdivspec / self.radius) return self.ivsht( self.invlap * vrtdivspec / self.radius)
def gethuv(self, uspec): def gethuv(self, uspec):
""" """Compute height and wind vector from spectral coefficients."""
compute wind vector from spectral coeffs of vorticity and divergence
"""
hgrid = self.spec2grid(uspec[:1]) hgrid = self.spec2grid(uspec[:1])
uvgrid = self.getuv(uspec[1:]) uvgrid = self.getuv(uspec[1:])
return torch.cat((hgrid, uvgrid), dim=-3) return torch.cat((hgrid, uvgrid), dim=-3)
def potential_vorticity(self, uspec): def potential_vorticity(self, uspec):
""" """Compute potential vorticity from spectral coefficients."""
Compute potential vorticity
"""
ugrid = self.spec2grid(uspec) ugrid = self.spec2grid(uspec)
pvrt = (0.5 * self.havg * self.gravity / self.omega) * (ugrid[1] + self.coriolis) / ugrid[0] pvrt = (0.5 * self.havg * self.gravity / self.omega) * (ugrid[1] + self.coriolis) / ugrid[0]
return pvrt return pvrt
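# As computed above, with ugrid[0] the geopotential g*h (cf. dimensionless below)
# and ugrid[1] the relative vorticity zeta:
#   q = (h_avg * g / (2 * Omega)) * (zeta + f) / (g * h),
# i.e. absolute vorticity over layer depth, scaled to O(1) by 0.5 * h_avg * g / Omega.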
def dimensionless(self, uspec): def dimensionless(self, uspec):
""" """Remove dimensions from variables for dimensionless analysis."""
Remove dimensions from variables
"""
uspec[0] = (uspec[0] - self.havg * self.gravity) / self.hamp / self.gravity uspec[0] = (uspec[0] - self.havg * self.gravity) / self.hamp / self.gravity
# vorticity is measured in 1/s so we normalize using sqrt(g h) / r # vorticity is measured in 1/s so we normalize using sqrt(g h) / r
uspec[1:] = uspec[1:] * self.radius / torch.sqrt(self.gravity * self.havg) uspec[1:] = uspec[1:] * self.radius / torch.sqrt(self.gravity * self.havg)
return uspec return uspec
def dudtspec(self, uspec): def dudtspec(self, uspec):
""" """Compute time derivatives from solution represented in spectral coefficients."""
Compute time derivatives from solution represented in spectral coefficients
"""
dudtspec = torch.zeros_like(uspec) dudtspec = torch.zeros_like(uspec)
# compute the derivatives - this should be incorporated into the solver: # compute the derivatives - this should be incorporated into the solver:
...@@ -190,12 +203,7 @@ class ShallowWaterSolver(nn.Module): ...@@ -190,12 +203,7 @@ class ShallowWaterSolver(nn.Module):
return dudtspec return dudtspec
def galewsky_initial_condition(self): def galewsky_initial_condition(self):
""" """Initialize non-linear barotropically unstable shallow water test case."""
Initializes non-linear barotropically unstable shallow water test case of Galewsky et al. (2004, Tellus, 56A, 429-440).
[1] Galewsky; An initial-value problem for testing numerical models of the global shallow-water equations;
DOI: 10.1111/j.1600-0870.2004.00071.x; http://www-vortex.mcs.st-and.ac.uk/~rks/reprints/galewsky_etal_tellus_2004.pdf
"""
device = self.lap.device device = self.lap.device
umax = 80. umax = 80.
...@@ -233,9 +241,7 @@ class ShallowWaterSolver(nn.Module): ...@@ -233,9 +241,7 @@ class ShallowWaterSolver(nn.Module):
return torch.tril(uspec) return torch.tril(uspec)
def random_initial_condition(self, mach=0.1) -> torch.Tensor: def random_initial_condition(self, mach=0.1) -> torch.Tensor:
""" """Generate random initial condition on the sphere."""
random initial condition on the sphere
"""
device = self.lap.device device = self.lap.device
ctype = torch.complex128 if self.lap.dtype == torch.float64 else torch.complex64 ctype = torch.complex128 if self.lap.dtype == torch.float64 else torch.complex64
...@@ -281,10 +287,7 @@ class ShallowWaterSolver(nn.Module): ...@@ -281,10 +287,7 @@ class ShallowWaterSolver(nn.Module):
return torch.tril(uspec) return torch.tril(uspec)
def timestep(self, uspec: torch.Tensor, nsteps: int) -> torch.Tensor: def timestep(self, uspec: torch.Tensor, nsteps: int) -> torch.Tensor:
""" """Integrate the solution using Adams-Bashforth / forward Euler for nsteps steps."""
Integrate the solution using Adams-Bashforth / forward Euler for nsteps steps.
"""
dudtspec = torch.zeros(3, 3, self.lmax, self.mmax, dtype=uspec.dtype, device=uspec.device) dudtspec = torch.zeros(3, 3, self.lmax, self.mmax, dtype=uspec.dtype, device=uspec.device)
# pointers to indicate the most current result # pointers to indicate the most current result
...@@ -316,6 +319,7 @@ class ShallowWaterSolver(nn.Module): ...@@ -316,6 +319,7 @@ class ShallowWaterSolver(nn.Module):
return uspec return uspec
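Putting the pieces together, a minimal driver loop; the step counts are illustrative, and the channel layout (height first, then vorticity and divergence) is inferred from gethuv above:

uspec = solver.galewsky_initial_condition()    # spectral initial state
for hour in range(6):
    uspec = solver.timestep(uspec, nsteps=24)  # advance 24 steps of length dt
    pv = solver.potential_vorticity(uspec)     # grid-space diagnostic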
def integrate_grid(self, ugrid, dimensionless=False, polar_opt=0): def integrate_grid(self, ugrid, dimensionless=False, polar_opt=0):
"""Integrate the solution on the grid."""
dlon = 2 * torch.pi / self.nlon dlon = 2 * torch.pi / self.nlon
radius = 1 if dimensionless else self.radius radius = 1 if dimensionless else self.radius
if polar_opt > 0: if polar_opt > 0:
...@@ -326,9 +330,7 @@ class ShallowWaterSolver(nn.Module): ...@@ -326,9 +330,7 @@ class ShallowWaterSolver(nn.Module):
def plot_griddata(self, data, fig, cmap='twilight_shifted', vmax=None, vmin=None, projection='3d', title=None, antialiased=False): def plot_griddata(self, data, fig, cmap='twilight_shifted', vmax=None, vmin=None, projection='3d', title=None, antialiased=False):
""" """Plotting routine for data on the grid. Requires cartopy for 3d plots."""
plotting routine for data on the grid. Requires cartopy for 3d plots.
"""
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
lons = self.lons.squeeze() - torch.pi lons = self.lons.squeeze() - torch.pi
......
...@@ -58,8 +58,22 @@ class Stanford2D3DSDownloader: ...@@ -58,8 +58,22 @@ class Stanford2D3DSDownloader:
""" """
Convenience class for downloading the 2d3ds dataset [1]. Convenience class for downloading the 2d3ds dataset [1].
Parameters
----------
base_url : str, optional
Base URL for downloading the dataset, by default DEFAULT_BASE_URL
local_dir : str, optional
Local directory to store downloaded files, by default "data"
The extracted directory names and semantic class labels are obtained via download_dataset; see its Returns section below.
References References
----------- ----------
.. [1] Armeni, I., Sax, S., Zamir, A. R., Savarese, S.; .. [1] Armeni, I., Sax, S., Zamir, A. R., Savarese, S.;
"Joint 2D-3D-Semantic Data for Indoor Scene Understanding" (2017). "Joint 2D-3D-Semantic Data for Indoor Scene Understanding" (2017).
https://arxiv.org/abs/1702.01105. https://arxiv.org/abs/1702.01105.
...@@ -106,6 +120,7 @@ class Stanford2D3DSDownloader: ...@@ -106,6 +120,7 @@ class Stanford2D3DSDownloader:
return local_path return local_path
def _extract_tar(self, tar_path): def _extract_tar(self, tar_path):
import tarfile import tarfile
with tarfile.open(tar_path) as tar: with tarfile.open(tar_path) as tar:
...@@ -116,7 +131,20 @@ class Stanford2D3DSDownloader: ...@@ -116,7 +131,20 @@ class Stanford2D3DSDownloader:
return extracted_dir return extracted_dir
def download_dataset(self, file_extracted_directory_pairs=DEFAULT_TAR_FILE_PAIRS): def download_dataset(self, file_extracted_directory_pairs=DEFAULT_TAR_FILE_PAIRS):
"""
Download and extract the complete dataset.
Parameters
----------
file_extracted_directory_pairs : list, optional
List of (filename, extracted_folder_name) pairs, by default DEFAULT_TAR_FILE_PAIRS
Returns
-------
tuple
(data_folders, class_labels) where data_folders is a list of extracted directory names
and class_labels is the semantic label mapping
"""
import requests import requests
data_folders = [] data_folders = []
...@@ -133,6 +161,7 @@ class Stanford2D3DSDownloader: ...@@ -133,6 +161,7 @@ class Stanford2D3DSDownloader:
return data_folders, class_labels return data_folders, class_labels
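A minimal end-to-end sketch of the downloader, using the constructor arguments documented above (note the dataset is large):

downloader = Stanford2D3DSDownloader(local_dir="data")  # base_url defaults to DEFAULT_BASE_URL
data_folders, class_labels = downloader.download_dataset()
print(len(data_folders), "areas,", len(class_labels), "semantic classes")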
def _rgb_to_id(self, img, class_labels_map, class_labels_indices): def _rgb_to_id(self, img, class_labels_map, class_labels_indices):
# Convert to int32 first to avoid overflow # Convert to int32 first to avoid overflow
r = img[..., 0].astype(np.int32) r = img[..., 0].astype(np.int32)
g = img[..., 1].astype(np.int32) g = img[..., 1].astype(np.int32)
...@@ -167,7 +196,35 @@ class Stanford2D3DSDownloader: ...@@ -167,7 +196,35 @@ class Stanford2D3DSDownloader:
downsampling_factor: int = 16, downsampling_factor: int = 16,
remove_alpha_channel: bool = True, remove_alpha_channel: bool = True,
): ):
"""
Convert the downloaded dataset to HDF5 format for efficient loading.
Parameters
----------
data_folders : list
List of extracted data folder names
class_labels : list
List of semantic class labels
rgb_path : str, optional
Relative path to RGB images within each data folder, by default "pano/rgb"
semantic_path : str, optional
Relative path to semantic labels within each data folder, by default "pano/semantic"
depth_path : str, optional
Relative path to depth images within each data folder, by default "pano/depth"
output_filename : str, optional
Suffix for semantic label files, by default "semantic"
dataset_file : str, optional
Output HDF5 filename, by default "stanford_2d3ds_dataset.h5"
downsampling_factor : int, optional
Factor by which to downsample images, by default 16
remove_alpha_channel : bool, optional
Whether to remove alpha channel from RGB images, by default True
Returns
-------
str
Path to the created HDF5 dataset file
"""
converted_dataset_path = os.path.join(self.local_dir, dataset_file) converted_dataset_path = os.path.join(self.local_dir, dataset_file)
from PIL import Image from PIL import Image
...@@ -391,8 +448,24 @@ class StanfordSegmentationDataset(Dataset): ...@@ -391,8 +448,24 @@ class StanfordSegmentationDataset(Dataset):
""" """
Spherical segmentation dataset from [1]. Spherical segmentation dataset from [1].
Parameters
----------
dataset_file : str
Path to the HDF5 dataset file
ignore_alpha_channel : bool, optional
Whether to ignore the alpha channel in the RGB images, by default True
log_depth : bool, optional
Whether to log the depth values, by default False
exclude_polar_fraction : float, optional
Fraction of polar points to exclude, by default 0.0
References References
----------- ----------
.. [1] Armeni, I., Sax, S., Zamir, A. R., Savarese, S.; .. [1] Armeni, I., Sax, S., Zamir, A. R., Savarese, S.;
"Joint 2D-3D-Semantic Data for Indoor Scene Understanding" (2017). "Joint 2D-3D-Semantic Data for Indoor Scene Understanding" (2017).
https://arxiv.org/abs/1702.01105. https://arxiv.org/abs/1702.01105.
...@@ -528,8 +601,19 @@ class StanfordDepthDataset(Dataset): ...@@ -528,8 +601,19 @@ class StanfordDepthDataset(Dataset):
""" """
Spherical depth estimation dataset from [1]. Spherical depth estimation dataset from [1].
Parameters
----------
dataset_file : str
Path to the HDF5 dataset file
ignore_alpha_channel : bool, optional
Whether to ignore the alpha channel in the RGB images, by default True
log_depth : bool, optional
Whether to log the depth values, by default False
exclude_polar_fraction : float, optional
Fraction of polar points to exclude, by default 0.0
References References
----------- ----------
.. [1] Armeni, I., Sax, S., Zamir, A. R., Savarese, S.; .. [1] Armeni, I., Sax, S., Zamir, A. R., Savarese, S.;
"Joint 2D-3D-Semantic Data for Indoor Scene Understanding" (2017). "Joint 2D-3D-Semantic Data for Indoor Scene Understanding" (2017).
https://arxiv.org/abs/1702.01105. https://arxiv.org/abs/1702.01105.
...@@ -620,7 +704,8 @@ class StanfordDepthDataset(Dataset): ...@@ -620,7 +704,8 @@ class StanfordDepthDataset(Dataset):
def compute_stats_s2(dataset: Dataset, normalize_target: bool = False): def compute_stats_s2(dataset: Dataset, normalize_target: bool = False):
""" """
Compute stats using parallel welford reduction and quadrature on the sphere. The parallel welford reduction follows this article (parallel algorithm): https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance Compute stats using a parallel Welford reduction and quadrature on the sphere.
The parallel Welford reduction follows the parallel algorithm described in https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
""" """
nexamples = len(dataset) nexamples = len(dataset)
......
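For reference, the parallel Welford update cited in the compute_stats_s2 docstring merges two partial aggregates (count, mean, M2) as below; this is a standalone sketch of the textbook formula, not the library routine itself:

def welford_merge(n_a, mean_a, m2_a, n_b, mean_b, m2_b):
    # merge two partial Welford aggregates (count, mean, sum of squared deviations);
    # works on Python floats or torch tensors alike
    n = n_a + n_b
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    m2 = m2_a + m2_b + delta * delta * n_a * n_b / n
    return n, mean, m2  # variance = m2 / n once all chunks are merged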
...@@ -39,24 +39,19 @@ from torch_harmonics.cache import lru_cache ...@@ -39,24 +39,19 @@ from torch_harmonics.cache import lru_cache
def _circle_dist(x1: torch.Tensor, x2: torch.Tensor): def _circle_dist(x1: torch.Tensor, x2: torch.Tensor):
"""Helper function to compute the distance on a circle"""
return torch.minimum(torch.abs(x1 - x2), torch.abs(2 * math.pi - torch.abs(x1 - x2))) return torch.minimum(torch.abs(x1 - x2), torch.abs(2 * math.pi - torch.abs(x1 - x2)))
def _log_factorial(x: torch.Tensor): def _log_factorial(x: torch.Tensor):
"""Helper function to compute the log factorial on a torch tensor"""
return torch.lgamma(x + 1) return torch.lgamma(x + 1)
def _factorial(x: torch.Tensor): def _factorial(x: torch.Tensor):
"""Helper function to compute the factorial on a torch tensor"""
return torch.exp(_log_factorial(x)) return torch.exp(_log_factorial(x))
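For example, _factorial(torch.tensor([5.0])) evaluates to 120.0 via exp(lgamma(6)); unlike math.factorial, this form broadcasts over tensors and extends to non-integer arguments through the gamma function.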
class FilterBasis(metaclass=abc.ABCMeta): class FilterBasis(metaclass=abc.ABCMeta):
""" """Abstract base class for a filter basis"""
Abstract base class for a filter basis
"""
def __init__( def __init__(
self, self,
...@@ -68,6 +63,7 @@ class FilterBasis(metaclass=abc.ABCMeta): ...@@ -68,6 +63,7 @@ class FilterBasis(metaclass=abc.ABCMeta):
@property @property
@abc.abstractmethod @abc.abstractmethod
def kernel_size(self): def kernel_size(self):
raise NotImplementedError raise NotImplementedError
# @abc.abstractmethod # @abc.abstractmethod
...@@ -79,10 +75,7 @@ class FilterBasis(metaclass=abc.ABCMeta): ...@@ -79,10 +75,7 @@ class FilterBasis(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float): def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float):
"""
Computes the index set that falls into the kernel's support and returns both indices and values.
This routine is designed for sparse evaluations of the filter basis.
"""
raise NotImplementedError raise NotImplementedError
...@@ -101,9 +94,7 @@ def get_filter_basis(kernel_shape: Union[int, Tuple[int], Tuple[int, int]], basi ...@@ -101,9 +94,7 @@ def get_filter_basis(kernel_shape: Union[int, Tuple[int], Tuple[int, int]], basi
class PiecewiseLinearFilterBasis(FilterBasis): class PiecewiseLinearFilterBasis(FilterBasis):
""" """Tensor-product basis on a disk constructed from piecewise linear basis functions."""
Tensor-product basis on a disk constructed from piecewise linear basis functions.
"""
def __init__( def __init__(
self, self,
...@@ -121,12 +112,10 @@ class PiecewiseLinearFilterBasis(FilterBasis): ...@@ -121,12 +112,10 @@ class PiecewiseLinearFilterBasis(FilterBasis):
@property @property
def kernel_size(self): def kernel_size(self):
"""Compute the number of basis functions in the kernel."""
return (self.kernel_shape[0] // 2) * self.kernel_shape[1] + self.kernel_shape[0] % 2 return (self.kernel_shape[0] // 2) * self.kernel_shape[1] + self.kernel_shape[0] % 2
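As a worked example, kernel_shape = (3, 4) gives (3 // 2) * 4 + 3 % 2 = 5 basis functions, loosely one isotropic function at the center plus one ring of four anisotropic functions.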
def _compute_support_vals_isotropic(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float): def _compute_support_vals_isotropic(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float):
"""
Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
"""
# enumerator for basis function # enumerator for basis function
ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1) ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1)
...@@ -148,9 +137,6 @@ class PiecewiseLinearFilterBasis(FilterBasis): ...@@ -148,9 +137,6 @@ class PiecewiseLinearFilterBasis(FilterBasis):
return iidx, vals return iidx, vals
def _compute_support_vals_anisotropic(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float): def _compute_support_vals_anisotropic(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float):
"""
Computes the index set that falls into the anisotropic kernel's support and returns both indices and values.
"""
# enumerator for basis function # enumerator for basis function
ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1) ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1)
...@@ -209,6 +195,7 @@ class PiecewiseLinearFilterBasis(FilterBasis): ...@@ -209,6 +195,7 @@ class PiecewiseLinearFilterBasis(FilterBasis):
return iidx, vals return iidx, vals
def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float): def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float):
"""Computes the index set that falls into the kernel's support and returns both indices and values."""
if self.kernel_shape[1] > 1: if self.kernel_shape[1] > 1:
return self._compute_support_vals_anisotropic(r, phi, r_cutoff=r_cutoff) return self._compute_support_vals_anisotropic(r, phi, r_cutoff=r_cutoff)
...@@ -217,9 +204,7 @@ class PiecewiseLinearFilterBasis(FilterBasis): ...@@ -217,9 +204,7 @@ class PiecewiseLinearFilterBasis(FilterBasis):
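A hedged usage sketch of the sparse evaluation path; the second argument of get_filter_basis is truncated in this diff, so the keyword basis_type and the string "piecewise linear" are assumptions:

import math
import torch

basis = get_filter_basis(kernel_shape=(3, 4), basis_type="piecewise linear")  # assumed keyword/value
r = torch.rand(8, 64)                   # radial distances, normalized to the cutoff
phi = 2 * math.pi * torch.rand(8, 64)   # azimuthal angles
iidx, vals = basis.compute_support_vals(r, phi, r_cutoff=0.5)  # sparse indices and values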
class MorletFilterBasis(FilterBasis): class MorletFilterBasis(FilterBasis):
""" """Morlet-style filter basis on the disk. A Gaussian is multiplied with a Fourier basis in x and y directions."""
Morlet-style filter basis on the disk. A Gaussian is multiplied with a Fourier basis in x and y directions
"""
def __init__( def __init__(
self, self,
...@@ -235,18 +220,18 @@ class MorletFilterBasis(FilterBasis): ...@@ -235,18 +220,18 @@ class MorletFilterBasis(FilterBasis):
@property @property
def kernel_size(self): def kernel_size(self):
return self.kernel_shape[0] * self.kernel_shape[1] return self.kernel_shape[0] * self.kernel_shape[1]
def gaussian_window(self, r: torch.Tensor, width: float = 1.0): def gaussian_window(self, r: torch.Tensor, width: float = 1.0):
return 1 / (2 * math.pi * width**2) * torch.exp(-0.5 * r**2 / (width**2)) return 1 / (2 * math.pi * width**2) * torch.exp(-0.5 * r**2 / (width**2))
def hann_window(self, r: torch.Tensor, width: float = 1.0): def hann_window(self, r: torch.Tensor, width: float = 1.0):
return torch.cos(0.5 * torch.pi * r / width) ** 2 return torch.cos(0.5 * torch.pi * r / width) ** 2
def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float, width: float = 1.0): def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float, width: float = 1.0):
"""
Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
"""
# enumerator for basis function # enumerator for basis function
ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1) ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1)
...@@ -274,9 +259,7 @@ class MorletFilterBasis(FilterBasis): ...@@ -274,9 +259,7 @@ class MorletFilterBasis(FilterBasis):
class ZernikeFilterBasis(FilterBasis): class ZernikeFilterBasis(FilterBasis):
""" """Zernike polynomials which are defined on the disk. See https://en.wikipedia.org/wiki/Zernike_polynomials"""
Zernike polynomials which are defined on the disk. See https://en.wikipedia.org/wiki/Zernike_polynomials
"""
def __init__( def __init__(
self, self,
...@@ -292,9 +275,11 @@ class ZernikeFilterBasis(FilterBasis): ...@@ -292,9 +275,11 @@ class ZernikeFilterBasis(FilterBasis):
@property @property
def kernel_size(self): def kernel_size(self):
return (self.kernel_shape * (self.kernel_shape + 1)) // 2 return (self.kernel_shape * (self.kernel_shape + 1)) // 2
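For kernel_shape = 4 this evaluates to (4 * 5) // 2 = 10 basis functions, the number of Zernike (n, l) pairs with 0 <= l <= n < 4.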
def zernikeradial(self, r: torch.Tensor, n: torch.Tensor, m: torch.Tensor): def zernikeradial(self, r: torch.Tensor, n: torch.Tensor, m: torch.Tensor):
out = torch.zeros_like(r) out = torch.zeros_like(r)
bound = (n - m) // 2 + 1 bound = (n - m) // 2 + 1
max_bound = bound.max().item() max_bound = bound.max().item()
...@@ -307,13 +292,11 @@ class ZernikeFilterBasis(FilterBasis): ...@@ -307,13 +292,11 @@ class ZernikeFilterBasis(FilterBasis):
return out return out
def zernikepoly(self, r: torch.Tensor, phi: torch.Tensor, n: torch.Tensor, l: torch.Tensor): def zernikepoly(self, r: torch.Tensor, phi: torch.Tensor, n: torch.Tensor, l: torch.Tensor):
m = 2 * l - n m = 2 * l - n
return torch.where(m < 0, self.zernikeradial(r, n, -m) * torch.sin(m * phi), self.zernikeradial(r, n, m) * torch.cos(m * phi)) return torch.where(m < 0, self.zernikeradial(r, n, -m) * torch.sin(m * phi), self.zernikeradial(r, n, m) * torch.cos(m * phi))
def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float, width: float = 0.25): def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float, width: float = 0.25):
"""
Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
"""
# enumerator for basis function # enumerator for basis function
ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1) ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1)
......
...@@ -37,18 +37,37 @@ from torch_harmonics.cache import lru_cache ...@@ -37,18 +37,37 @@ from torch_harmonics.cache import lru_cache
def clm(l: int, m: int) -> float: def clm(l: int, m: int) -> float:
""" """Defines the normalization factor to orthonormalize the Spherical Harmonics."""
defines the normalization factor to orthonormalize the Spherical Harmonics
"""
return math.sqrt((2*l + 1) / 4 / math.pi) * math.sqrt(math.factorial(l-m) / math.factorial(l+m)) return math.sqrt((2*l + 1) / 4 / math.pi) * math.sqrt(math.factorial(l-m) / math.factorial(l+m))
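Sanity check: clm(0, 0) = sqrt(1 / (4 * pi)) ≈ 0.2821, the familiar normalization of the constant harmonic Y_0^0.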
def legpoly(mmax: int, lmax: int, x: torch.Tensor, norm: Optional[str]="ortho", inverse: Optional[bool]=False, csphase: Optional[bool]=True) -> torch.Tensor: def legpoly(mmax: int, lmax: int, x: torch.Tensor, norm: Optional[str]="ortho", inverse: Optional[bool]=False, csphase: Optional[bool]=True) -> torch.Tensor:
r""" """
Computes the values of (-1)^m c^l_m P^l_m(x) at the positions specified by x. Computes the values of (-1)^m c^l_m P^l_m(x) at the positions specified by x.
The resulting tensor has shape (mmax, lmax, len(x)). The Condon-Shortley Phase (-1)^m The resulting tensor has shape (mmax, lmax, len(x)). The Condon-Shortley Phase (-1)^m
can be turned off optionally. can be turned off optionally.
method of computation follows Parameters
----------
mmax: int
Maximum order of the spherical harmonics
lmax: int
Maximum degree of the spherical harmonics
x: torch.Tensor
Tensor of positions at which to evaluate the Legendre polynomials
norm: Optional[str]
Normalization of the Legendre polynomials
inverse: Optional[bool]
Whether to compute the inverse Legendre polynomials
csphase: Optional[bool]
Whether to apply the Condon-Shortley phase (-1)^m
Returns
-------
out: torch.Tensor
Tensor of Legendre polynomial values
References
----------
[1] Schaeffer, N.; Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems. [1] Schaeffer, N.; Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
[2] Rapp, R.H.; A Fortran Program for the Computation of Gravimetric Quantities from High Degree Spherical Harmonic Expansions, Ohio State University Columbus; report; 1982; [2] Rapp, R.H.; A Fortran Program for the Computation of Gravimetric Quantities from High Degree Spherical Harmonic Expansions, Ohio State University Columbus; report; 1982;
https://apps.dtic.mil/sti/citations/ADA123406 https://apps.dtic.mil/sti/citations/ADA123406
...@@ -94,12 +113,33 @@ def legpoly(mmax: int, lmax: int, x: torch.Tensor, norm: Optional[str]="ortho", ...@@ -94,12 +113,33 @@ def legpoly(mmax: int, lmax: int, x: torch.Tensor, norm: Optional[str]="ortho",
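A minimal evaluation sketch, with shapes per the docstring above:

import torch

x = torch.linspace(-1.0, 1.0, 10, dtype=torch.float64)
pct = legpoly(mmax=4, lmax=4, x=x)  # shape (mmax, lmax, len(x)) = (4, 4, 10)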
@lru_cache(typed=True, copy=True) @lru_cache(typed=True, copy=True)
def _precompute_legpoly(mmax: int , lmax: int, t: torch.Tensor, def _precompute_legpoly(mmax: int , lmax: int, t: torch.Tensor,
norm: Optional[str]="ortho", inverse: Optional[bool]=False, csphase: Optional[bool]=True) -> torch.Tensor: norm: Optional[str]="ortho", inverse: Optional[bool]=False, csphase: Optional[bool]=True) -> torch.Tensor:
r""" """
Computes the values of (-1)^m c^l_m P^l_m(\cos \theta) at the positions specified by t (theta). Computes the values of (-1)^m c^l_m P^l_m(\cos \theta) at the positions specified by t (theta).
The resulting tensor has shape (mmax, lmax, len(x)). The Condon-Shortley Phase (-1)^m The resulting tensor has shape (mmax, lmax, len(x)). The Condon-Shortley Phase (-1)^m
can be turned off optionally. can be turned off optionally.
method of computation follows Parameters
----------
mmax: int
Maximum order of the spherical harmonics
lmax: int
Maximum degree of the spherical harmonics
t: torch.Tensor
Tensor of positions at which to evaluate the Legendre polynomials
norm: Optional[str]
Normalization of the Legendre polynomials
inverse: Optional[bool]
Whether to compute the inverse Legendre polynomials
csphase: Optional[bool]
Whether to apply the Condon-Shortley phase (-1)^m
Returns
-------
out: torch.Tensor
Tensor of Legendre polynomial values
References
----------
[1] Schaeffer, N.; Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems. [1] Schaeffer, N.; Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
[2] Rapp, R.H.; A Fortran Program for the Computation of Gravimetric Quantities from High Degree Spherical Harmonic Expansions, Ohio State University Columbus; report; 1982; [2] Rapp, R.H.; A Fortran Program for the Computation of Gravimetric Quantities from High Degree Spherical Harmonic Expansions, Ohio State University Columbus; report; 1982;
https://apps.dtic.mil/sti/citations/ADA123406 https://apps.dtic.mil/sti/citations/ADA123406
...@@ -111,13 +151,34 @@ def _precompute_legpoly(mmax: int , lmax: int, t: torch.Tensor, ...@@ -111,13 +151,34 @@ def _precompute_legpoly(mmax: int , lmax: int, t: torch.Tensor,
@lru_cache(typed=True, copy=True) @lru_cache(typed=True, copy=True)
def _precompute_dlegpoly(mmax: int, lmax: int, t: torch.Tensor, def _precompute_dlegpoly(mmax: int, lmax: int, t: torch.Tensor,
norm: Optional[str]="ortho", inverse: Optional[bool]=False, csphase: Optional[bool]=True) -> torch.Tensor: norm: Optional[str]="ortho", inverse: Optional[bool]=False, csphase: Optional[bool]=True) -> torch.Tensor:
r""" """
Computes the values of the derivatives $\frac{d}{d \theta} P^m_l(\cos \theta)$ Computes the values of the derivatives $\frac{d}{d \theta} P^m_l(\cos \theta)$
at the positions specified by t (theta), as well as $\frac{1}{\sin \theta} P^m_l(\cos \theta)$, at the positions specified by t (theta), as well as $\frac{1}{\sin \theta} P^m_l(\cos \theta)$,
needed for the computation of the vector spherical harmonics. The resulting tensor has shape needed for the computation of the vector spherical harmonics. The resulting tensor has shape
(2, mmax, lmax, len(t)). (2, mmax, lmax, len(t)).
computation follows Parameters
----------
mmax: int
Maximum order of the spherical harmonics
lmax: int
Maximum degree of the spherical harmonics
t: torch.Tensor
Tensor of positions at which to evaluate the Legendre polynomials
norm: Optional[str]
Normalization of the Legendre polynomials
inverse: Optional[bool]
Whether to compute the inverse Legendre polynomials
csphase: Optional[bool]
Whether to apply the Condon-Shortley phase (-1)^m
Returns
-------
out: torch.Tensor
Tensor of Legendre polynomial values
References
----------
[2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math. [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
""" """
......
...@@ -58,6 +58,28 @@ def get_projection( ...@@ -58,6 +58,28 @@ def get_projection(
central_latitude=0, central_latitude=0,
central_longitude=0, central_longitude=0,
): ):
"""
Get a cartopy projection object for map plotting.
Parameters
----------
projection : str
Projection type ("orthographic", "robinson", "platecarree", "mollweide")
central_latitude : float, optional
Central latitude for the projection, by default 0
central_longitude : float, optional
Central longitude for the projection, by default 0
Returns
-------
cartopy.crs.Projection
Cartopy projection object
Raises
------
ValueError
If projection type is not supported
"""
if projection == "orthographic": if projection == "orthographic":
proj = ccrs.Orthographic(central_latitude=central_latitude, central_longitude=central_longitude) proj = ccrs.Orthographic(central_latitude=central_latitude, central_longitude=central_longitude)
elif projection == "robinson": elif projection == "robinson":
...@@ -77,6 +99,40 @@ def plot_sphere( ...@@ -77,6 +99,40 @@ def plot_sphere(
): ):
""" """
Plots a function defined on the sphere using pcolormesh Plots a function defined on the sphere using pcolormesh
Parameters
----------
data : numpy.ndarray or torch.Tensor
Data to plot with shape (nlat, nlon)
fig : matplotlib.figure.Figure, optional
Figure to plot on, by default None (creates new figure)
projection : str, optional
Map projection type, by default "robinson"
cmap : str, optional
Colormap name, by default "RdBu"
title : str, optional
Plot title, by default None
colorbar : bool, optional
Whether to add a colorbar, by default False
coastlines : bool, optional
Whether to add coastlines, by default False
gridlines : bool, optional
Whether to add gridlines, by default False
central_latitude : float, optional
Central latitude for projection, by default 0
central_longitude : float, optional
Central longitude for projection, by default 0
lon : numpy.ndarray, optional
Longitude coordinates, by default None (auto-generated)
lat : numpy.ndarray, optional
Latitude coordinates, by default None (auto-generated)
**kwargs
Additional arguments passed to pcolormesh
Returns
-------
matplotlib.collections.QuadMesh
The plotted image object
""" """
# make sure cartopy exists # make sure cartopy exists
...@@ -126,6 +182,28 @@ def plot_sphere( ...@@ -126,6 +182,28 @@ def plot_sphere(
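A minimal plotting sketch, assuming cartopy and matplotlib are installed; the data values are random and purely illustrative:

import torch

data = torch.randn(181, 360)  # (nlat, nlon); lon/lat coordinates are auto-generated
im = plot_sphere(data, projection="robinson", cmap="RdBu", colorbar=True, coastlines=True, title="random field")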
def imshow_sphere(data, fig=None, projection="robinson", title=None, central_latitude=0, central_longitude=0, **kwargs): def imshow_sphere(data, fig=None, projection="robinson", title=None, central_latitude=0, central_longitude=0, **kwargs):
""" """
Displays an image on the sphere Displays an image on the sphere
Parameters
----------
data : numpy.ndarray or torch.Tensor
Data to display with shape (nlat, nlon)
fig : matplotlib.figure.Figure, optional
Figure to plot on, by default None (creates new figure)
projection : str, optional
Map projection type, by default "robinson"
title : str, optional
Plot title, by default None
central_latitude : float, optional
Central latitude for projection, by default 0
central_longitude : float, optional
Central longitude for projection, by default 0
**kwargs
Additional arguments passed to imshow
Returns
-------
matplotlib.image.AxesImage
The displayed image object
""" """
# make sure cartopy exists # make sure cartopy exists
......
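And analogously for imshow_sphere, whose extra keyword arguments are forwarded to imshow:

img = torch.rand(181, 360)  # scalar image on the (nlat, nlon) grid
ax_img = imshow_sphere(img, projection="mollweide", title="image on the sphere")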