Commit e4879676 authored by apaaris's avatar apaaris Committed by Boris Bonev
Browse files

Added docstrings to many methods

parent b5c410c0
......@@ -54,6 +54,46 @@ def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
class DownsamplingBlock(nn.Module):
"""
Downsampling block for spherical U-Net architecture.
This block performs convolution operations followed by downsampling on spherical data,
using discrete-continuous convolutions to maintain spectral properties.
Parameters
-----------
in_shape : tuple
Input shape (nlat, nlon)
out_shape : tuple
Output shape (nlat, nlon)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
nrep : int, optional
Number of convolution repetitions, by default 1
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
activation : nn.Module, optional
Activation function, by default nn.ReLU
transform_skip : bool, optional
Whether to transform skip connection, by default False
drop_conv_rate : float, optional
Dropout rate for convolutions, by default 0.0
drop_path_rate : float, optional
Drop path rate, by default 0.0
drop_dense_rate : float, optional
Dropout rate for dense layers, by default 0.0
downsampling_mode : str, optional
Downsampling mode ("bilinear", "conv"), by default "bilinear"
"""
def __init__(
self,
in_shape,
......@@ -154,12 +194,33 @@ class DownsamplingBlock(nn.Module):
self.apply(self._init_weights)
def _init_weights(self, m):
"""
Initialize weights for the module.
Parameters
-----------
m : nn.Module
Module to initialize
"""
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the downsampling block.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Downsampled tensor
"""
# skip connection
residual = x
if hasattr(self, "transform_skip"):
......@@ -178,6 +239,46 @@ class DownsamplingBlock(nn.Module):
class UpsamplingBlock(nn.Module):
"""
Upsampling block for spherical U-Net architecture.
This block performs upsampling followed by convolution operations on spherical data,
using discrete-continuous convolutions to maintain spectral properties.
Parameters
-----------
in_shape : tuple
Input shape (nlat, nlon)
out_shape : tuple
Output shape (nlat, nlon)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
nrep : int, optional
Number of convolution repetitions, by default 1
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
activation : nn.Module, optional
Activation function, by default nn.ReLU
transform_skip : bool, optional
Whether to transform skip connection, by default False
drop_conv_rate : float, optional
Dropout rate for convolutions, by default 0.0
drop_path_rate : float, optional
Drop path rate, by default 0.0
drop_dense_rate : float, optional
Dropout rate for dense layers, by default 0.0
upsampling_mode : str, optional
Upsampling mode ("bilinear", "conv"), by default "bilinear"
"""
def __init__(
self,
in_shape,
......@@ -496,6 +597,14 @@ class SphericalUNet(nn.Module):
self.apply(self._init_weights)
def _init_weights(self, m):
"""
Initialize weights for the module.
Parameters
-----------
m : nn.Module
Module to initialize
"""
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
......@@ -505,7 +614,19 @@ class SphericalUNet(nn.Module):
nn.init.constant_(m.weight, 1.0)
def forward(self, x):
"""
Forward pass through the complete spherical U-Net model.
Parameters
-----------
x : torch.Tensor
Input tensor of shape (batch_size, in_chans, height, width)
Returns
-------
torch.Tensor
Output tensor of shape (batch_size, out_chans, height, width)
"""
# encoder:
features = []
feat = x
......
......@@ -118,9 +118,20 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
else:
raise ValueError(f"Unknown skip connection type {outer_skip}")
def forward(self, x):
"""
Forward pass through the SFNO block.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Output tensor after processing through the block
"""
x, residual = self.global_conv(x)
x = self.norm(x)
......@@ -147,8 +158,12 @@ class SphericalFourierNeuralOperator(nn.Module):
Parameters
----------
img_shape : tuple, optional
img_size : tuple, optional
Shape of the input channels, by default (128, 256)
grid : str, optional
Input grid type, by default "equiangular"
grid_internal : str, optional
Internal grid type for computations, by default "legendre-gauss"
scale_factor : int, optional
Scale factor to use, by default 3
in_chans : int, optional
......@@ -172,20 +187,20 @@ class SphericalFourierNeuralOperator(nn.Module):
drop_path_rate : float, optional
Dropout path rate, by default 0.0
normalization_layer : str, optional
Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"
Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "none"
hard_thresholding_fraction : float, optional
Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0
residual_prediction : bool, optional
Whether to add a single large skip connection, by default True
pos_embed : bool, optional
Whether to use positional embedding, by default True
Whether to add a single large skip connection, by default False
pos_embed : str, optional
Type of positional embedding to use, by default "none"
bias : bool, optional
Whether to use a bias, by default False
Example:
--------
>>> model = SphericalFourierNeuralOperator(
... img_shape=(128, 256),
... img_size=(128, 256),
... scale_factor=4,
... in_chans=2,
... out_chans=2,
......@@ -355,10 +370,30 @@ class SphericalFourierNeuralOperator(nn.Module):
@torch.jit.ignore
def no_weight_decay(self):
return {"pos_embed", "cls_token"}
"""
Return a set of parameter names that should not be decayed.
Returns
-------
set
Set of parameter names to exclude from weight decay
"""
return {"pos_embed.pos_embed"}
def forward_features(self, x):
"""
Forward pass through the feature extraction layers.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Features after processing through the network
"""
x = self.pos_drop(x)
for blk in self.blocks:
......@@ -367,7 +402,19 @@ class SphericalFourierNeuralOperator(nn.Module):
return x
def forward(self, x):
"""
Forward pass through the complete SFNO model.
Parameters
-----------
x : torch.Tensor
Input tensor of shape (batch_size, in_chans, height, width)
Returns
-------
torch.Tensor
Output tensor of shape (batch_size, out_chans, height, width)
"""
if self.residual_prediction:
residual = x
......
......@@ -42,7 +42,27 @@ import numpy as np
class SphereSolver(nn.Module):
"""
Solver class on the sphere. Can solve the following PDEs:
- Allen-Cahn eq
- Allen-Cahn equation
- Ginzburg-Landau equation
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
dt : float
Time step size
lmax : int, optional
Maximum l mode for spherical harmonics, by default None
mmax : int, optional
Maximum m mode for spherical harmonics, by default None
grid : str, optional
Grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
radius : float, optional
Radius of the sphere, by default 1.0
coeff : float, optional
Coefficient for the PDE, by default 0.001
"""
def __init__(self, nlat, nlon, dt, lmax=None, mmax=None, grid="equiangular", radius=1.0, coeff=0.001):
......@@ -97,17 +117,58 @@ class SphereSolver(nn.Module):
self.register_buffer('invlap', invlap)
def grid2spec(self, u):
"""spectral coefficients from spatial data"""
"""
Convert spatial data to spectral coefficients.
Parameters
-----------
u : torch.Tensor
Spatial data tensor
Returns
-------
torch.Tensor
Spectral coefficients
"""
return self.sht(u)
def spec2grid(self, uspec):
"""spatial data from spectral coefficients"""
"""
Convert spectral coefficients to spatial data.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients tensor
Returns
-------
torch.Tensor
Spatial data
"""
return self.isht(uspec)
def dudtspec(self, uspec, pde='allen-cahn'):
"""
Compute the time derivative of spectral coefficients for different PDEs.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients
pde : str, optional
PDE type ("allen-cahn", "ginzburg-landau"), by default "allen-cahn"
Returns
-------
torch.Tensor
Time derivative of spectral coefficients
Raises
------
NotImplementedError
If PDE type is not supported
"""
if pde == 'allen-cahn':
ugrid = self.spec2grid(uspec)
u3spec = self.grid2spec(ugrid**3)
......@@ -117,20 +178,55 @@ class SphereSolver(nn.Module):
u3spec = self.grid2spec(ugrid**3)
dudtspec = uspec + (1. + 2.j)*self.coeff*self.lap*uspec - (1. + 2.j)*u3spec
else:
NotImplementedError
raise NotImplementedError(f"PDE type {pde} not implemented")
return dudtspec
def randspec(self):
"""random data on the sphere"""
"""
Generate random spectral data on the sphere.
Returns
-------
torch.Tensor
Random spectral coefficients
"""
rspec = torch.randn_like(self.lap) / 4 / torch.pi
return rspec
def plot_griddata(self, data, fig, cmap='twilight_shifted', vmax=None, vmin=None, projection='3d', title=None, antialiased=False):
"""
plotting routine for data on the grid. Requires cartopy for 3d plots.
Plot data on the sphere grid. Requires cartopy for 3d plots.
Parameters
-----------
data : torch.Tensor
Data to plot
fig : matplotlib.figure.Figure
Figure to plot on
cmap : str, optional
Colormap name, by default 'twilight_shifted'
vmax : float, optional
Maximum value for color scaling, by default None
vmin : float, optional
Minimum value for color scaling, by default None
projection : str, optional
Projection type ("mollweide", "3d"), by default "3d"
title : str, optional
Plot title, by default None
antialiased : bool, optional
Whether to use antialiasing, by default False
Returns
-------
matplotlib.collections.QuadMesh
The plotted image object
Raises
------
NotImplementedError
If projection type is not supported
"""
import matplotlib.pyplot as plt
......@@ -172,9 +268,26 @@ class SphereSolver(nn.Module):
plt.title(title, y=1.05)
else:
raise NotImplementedError
raise NotImplementedError(f"Projection {projection} not implemented")
return im
def plot_specdata(self, data, fig, **kwargs):
"""
Plot spectral data by converting to spatial data first.
Parameters
-----------
data : torch.Tensor
Spectral data to plot
fig : matplotlib.figure.Figure
Figure to plot on
**kwargs
Additional arguments passed to plot_griddata
Returns
-------
matplotlib.collections.QuadMesh
The plotted image object
"""
return self.plot_griddata(self.isht(data), fig, **kwargs)
......@@ -41,7 +41,35 @@ import numpy as np
class ShallowWaterSolver(nn.Module):
"""
SWE solver class. Interface inspired bu pyspharm and SHTns
Shallow Water Equations (SWE) solver class for spherical geometry.
Interface inspired by pyspharm and SHTns. Solves the shallow water equations
on a rotating sphere using spectral methods.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
dt : float
Time step size
lmax : int, optional
Maximum l mode for spherical harmonics, by default None
mmax : int, optional
Maximum m mode for spherical harmonics, by default None
grid : str, optional
Grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
radius : float, optional
Radius of the sphere in meters, by default 6.37122E6 (Earth radius)
omega : float, optional
Angular velocity of rotation in rad/s, by default 7.292E-5 (Earth)
gravity : float, optional
Gravitational acceleration in m/s², by default 9.80616
havg : float, optional
Average height in meters, by default 10.e3
hamp : float, optional
Height amplitude in meters, by default 120.
"""
def __init__(self, nlat, nlon, dt, lmax=None, mmax=None, grid="equiangular", radius=6.37122E6, \
......@@ -115,30 +143,82 @@ class ShallowWaterSolver(nn.Module):
def grid2spec(self, ugrid):
"""
spectral coefficients from spatial data
Convert spatial data to spectral coefficients.
Parameters
-----------
ugrid : torch.Tensor
Spatial data tensor
Returns
-------
torch.Tensor
Spectral coefficients
"""
return self.sht(ugrid)
def spec2grid(self, uspec):
"""
spatial data from spectral coefficients
Convert spectral coefficients to spatial data.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients tensor
Returns
-------
torch.Tensor
Spatial data
"""
return self.isht(uspec)
def vrtdivspec(self, ugrid):
"""spatial data from spectral coefficients"""
"""
Compute vorticity and divergence from velocity field.
Parameters
-----------
ugrid : torch.Tensor
Velocity field in spatial coordinates
Returns
-------
torch.Tensor
Spectral coefficients of vorticity and divergence
"""
vrtdivspec = self.lap * self.radius * self.vsht(ugrid)
return vrtdivspec
def getuv(self, vrtdivspec):
"""
compute wind vector from spectral coeffs of vorticity and divergence
Compute wind vector from spectral coefficients of vorticity and divergence.
Parameters
-----------
vrtdivspec : torch.Tensor
Spectral coefficients of vorticity and divergence
Returns
-------
torch.Tensor
Wind vector field in spatial coordinates
"""
return self.ivsht( self.invlap * vrtdivspec / self.radius)
def gethuv(self, uspec):
"""
compute wind vector from spectral coeffs of vorticity and divergence
Compute height and wind vector from spectral coefficients.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients [height, vorticity, divergence]
Returns
-------
torch.Tensor
Combined height and wind vector field
"""
hgrid = self.spec2grid(uspec[:1])
uvgrid = self.getuv(uspec[1:])
......@@ -146,7 +226,17 @@ class ShallowWaterSolver(nn.Module):
def potential_vorticity(self, uspec):
"""
Compute potential vorticity
Compute potential vorticity from spectral coefficients.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients [height, vorticity, divergence]
Returns
-------
torch.Tensor
Potential vorticity field
"""
ugrid = self.spec2grid(uspec)
pvrt = (0.5 * self.havg * self.gravity / self.omega) * (ugrid[1] + self.coriolis) / ugrid[0]
......@@ -154,7 +244,17 @@ class ShallowWaterSolver(nn.Module):
def dimensionless(self, uspec):
"""
Remove dimensions from variables
Remove dimensions from variables for dimensionless analysis.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients with dimensions
Returns
-------
torch.Tensor
Dimensionless spectral coefficients
"""
uspec[0] = (uspec[0] - self.havg * self.gravity) / self.hamp / self.gravity
# vorticity is measured in 1/s so we normalize using sqrt(g h) / r
......@@ -163,9 +263,18 @@ class ShallowWaterSolver(nn.Module):
def dudtspec(self, uspec):
"""
Compute time derivatives from solution represented in spectral coefficients
Compute time derivatives from solution represented in spectral coefficients.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients [height, vorticity, divergence]
Returns
-------
torch.Tensor
Time derivatives of spectral coefficients
"""
dudtspec = torch.zeros_like(uspec)
# compute the derivatives - this should be incorporated into the solver:
......@@ -191,10 +300,15 @@ class ShallowWaterSolver(nn.Module):
def galewsky_initial_condition(self):
"""
Initializes non-linear barotropically unstable shallow water test case of Galewsky et al. (2004, Tellus, 56A, 429-440).
Initialize non-linear barotropically unstable shallow water test case of Galewsky et al. (2004, Tellus, 56A, 429-440).
[1] Galewsky; An initial-value problem for testing numerical models of the global shallow-water equations;
DOI: 10.1111/j.1600-0870.2004.00071.x; http://www-vortex.mcs.st-and.ac.uk/~rks/reprints/galewsky_etal_tellus_2004.pdf
Returns
-------
torch.Tensor
Initial spectral coefficients for the Galewsky test case
"""
device = self.lap.device
......@@ -234,7 +348,17 @@ class ShallowWaterSolver(nn.Module):
def random_initial_condition(self, mach=0.1) -> torch.Tensor:
"""
random initial condition on the sphere
Generate random initial condition on the sphere.
Parameters
-----------
mach : float, optional
Mach number for scaling the random perturbations, by default 0.1
Returns
-------
torch.Tensor
Random initial spectral coefficients
"""
device = self.lap.device
ctype = torch.complex128 if self.lap.dtype == torch.float64 else torch.complex64
......
......@@ -66,13 +66,34 @@ class Stanford2D3DSDownloader:
"""
def __init__(self, base_url: str = DEFAULT_BASE_URL, local_dir: str = "data"):
"""
Initialize the Stanford 2D3DS dataset downloader.
Parameters
-----------
base_url : str, optional
Base URL for downloading the dataset, by default DEFAULT_BASE_URL
local_dir : str, optional
Local directory to store downloaded files, by default "data"
"""
self.base_url = base_url
self.local_dir = local_dir
os.makedirs(self.local_dir, exist_ok=True)
def _download_file(self, filename):
"""
Download a single file with progress bar and resume capability.
Parameters
-----------
filename : str
Name of the file to download
Returns
-------
str
Local path to the downloaded file
"""
import requests
from tqdm import tqdm
......@@ -106,6 +127,19 @@ class Stanford2D3DSDownloader:
return local_path
def _extract_tar(self, tar_path):
"""
Extract a tar file and return the extracted directory name.
Parameters
-----------
tar_path : str
Path to the tar file to extract
Returns
-------
str
Name of the extracted directory
"""
import tarfile
with tarfile.open(tar_path) as tar:
......@@ -116,7 +150,20 @@ class Stanford2D3DSDownloader:
return extracted_dir
def download_dataset(self, file_extracted_directory_pairs=DEFAULT_TAR_FILE_PAIRS):
"""
Download and extract the complete dataset.
Parameters
-----------
file_extracted_directory_pairs : list, optional
List of (filename, extracted_folder_name) pairs, by default DEFAULT_TAR_FILE_PAIRS
Returns
-------
tuple
(data_folders, class_labels) where data_folders is a list of extracted directory names
and class_labels is the semantic label mapping
"""
import requests
data_folders = []
......@@ -133,6 +180,23 @@ class Stanford2D3DSDownloader:
return data_folders, class_labels
def _rgb_to_id(self, img, class_labels_map, class_labels_indices):
"""
Convert RGB image to class ID using color mapping.
Parameters
-----------
img : numpy.ndarray
RGB image array
class_labels_map : list
Mapping from color values to class labels
class_labels_indices : list
List of class label indices
Returns
-------
numpy.ndarray
Class ID array
"""
# Convert to int32 first to avoid overflow
r = img[..., 0].astype(np.int32)
g = img[..., 1].astype(np.int32)
......@@ -167,7 +231,35 @@ class Stanford2D3DSDownloader:
downsampling_factor: int = 16,
remove_alpha_channel: bool = True,
):
"""
Convert the downloaded dataset to HDF5 format for efficient loading.
Parameters
-----------
data_folders : list
List of extracted data folder names
class_labels : list
List of semantic class labels
rgb_path : str, optional
Relative path to RGB images within each data folder, by default "pano/rgb"
semantic_path : str, optional
Relative path to semantic labels within each data folder, by default "pano/semantic"
depth_path : str, optional
Relative path to depth images within each data folder, by default "pano/depth"
output_filename : str, optional
Suffix for semantic label files, by default "semantic"
dataset_file : str, optional
Output HDF5 filename, by default "stanford_2d3ds_dataset.h5"
downsampling_factor : int, optional
Factor by which to downsample images, by default 16
remove_alpha_channel : bool, optional
Whether to remove alpha channel from RGB images, by default True
Returns
-------
str
Path to the created HDF5 dataset file
"""
converted_dataset_path = os.path.join(self.local_dir, dataset_file)
from PIL import Image
......
......@@ -62,12 +62,23 @@ class FilterBasis(metaclass=abc.ABCMeta):
self,
kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
):
"""
Initialize the filter basis.
Parameters:
kernel_shape: shape of the kernel, can be an integer or tuple of integers
"""
self.kernel_shape = kernel_shape
@property
@abc.abstractmethod
def kernel_size(self):
"""
Abstract property that should return the size of the kernel.
Returns:
int: the kernel size
"""
raise NotImplementedError
# @abc.abstractmethod
......@@ -109,7 +120,12 @@ class PiecewiseLinearFilterBasis(FilterBasis):
self,
kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
):
"""
Initialize the piecewise linear filter basis.
Parameters:
kernel_shape: shape of the kernel, can be an integer or tuple of integers
"""
if isinstance(kernel_shape, int):
kernel_shape = [kernel_shape]
if len(kernel_shape) == 1:
......@@ -121,6 +137,12 @@ class PiecewiseLinearFilterBasis(FilterBasis):
@property
def kernel_size(self):
"""
Compute the kernel size for piecewise linear basis.
Returns:
int: the kernel size
"""
return (self.kernel_shape[0] // 2) * self.kernel_shape[1] + self.kernel_shape[0] % 2
def _compute_support_vals_isotropic(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float):
......@@ -225,7 +247,12 @@ class MorletFilterBasis(FilterBasis):
self,
kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
):
"""
Initialize the Morlet filter basis.
Parameters:
kernel_shape: shape of the kernel, can be an integer or tuple of integers
"""
if isinstance(kernel_shape, int):
kernel_shape = [kernel_shape, kernel_shape]
if len(kernel_shape) != 2:
......@@ -235,12 +262,38 @@ class MorletFilterBasis(FilterBasis):
@property
def kernel_size(self):
"""
Compute the kernel size for Morlet basis.
Returns:
int: the kernel size
"""
return self.kernel_shape[0] * self.kernel_shape[1]
def gaussian_window(self, r: torch.Tensor, width: float = 1.0):
"""
Compute Gaussian window function.
Parameters:
r: radial distance tensor
width: width parameter of the Gaussian
Returns:
torch.Tensor: Gaussian window values
"""
return 1 / (2 * math.pi * width**2) * torch.exp(-0.5 * r**2 / (width**2))
def hann_window(self, r: torch.Tensor, width: float = 1.0):
"""
Compute Hann window function.
Parameters:
r: radial distance tensor
width: width parameter of the Hann window
Returns:
torch.Tensor: Hann window values
"""
return torch.cos(0.5 * torch.pi * r / width) ** 2
def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float, width: float = 1.0):
......@@ -282,7 +335,12 @@ class ZernikeFilterBasis(FilterBasis):
self,
kernel_shape: Union[int, Tuple[int]],
):
"""
Initialize the Zernike filter basis.
Parameters:
kernel_shape: shape of the kernel, can be an integer or tuple of integers
"""
if isinstance(kernel_shape, tuple) or isinstance(kernel_shape, list):
kernel_shape = kernel_shape[0]
if not isinstance(kernel_shape, int):
......@@ -292,9 +350,26 @@ class ZernikeFilterBasis(FilterBasis):
@property
def kernel_size(self):
"""
Compute the kernel size for Zernike basis.
Returns:
int: the kernel size
"""
return (self.kernel_shape * (self.kernel_shape + 1)) // 2
def zernikeradial(self, r: torch.Tensor, n: torch.Tensor, m: torch.Tensor):
"""
Compute radial Zernike polynomials.
Parameters:
r: radial distance tensor
n: principal quantum number
m: azimuthal quantum number
Returns:
torch.Tensor: radial Zernike polynomial values
"""
out = torch.zeros_like(r)
bound = (n - m) // 2 + 1
max_bound = bound.max().item()
......@@ -307,6 +382,18 @@ class ZernikeFilterBasis(FilterBasis):
return out
def zernikepoly(self, r: torch.Tensor, phi: torch.Tensor, n: torch.Tensor, l: torch.Tensor):
"""
Compute Zernike polynomials.
Parameters:
r: radial distance tensor
phi: azimuthal angle tensor
n: principal quantum number
l: azimuthal quantum number
Returns:
torch.Tensor: Zernike polynomial values
"""
m = 2 * l - n
return torch.where(m < 0, self.zernikeradial(r, n, -m) * torch.sin(m * phi), self.zernikeradial(r, n, m) * torch.cos(m * phi))
......
......@@ -47,6 +47,14 @@ except ImportError as err:
def check_plotting_dependencies():
"""
Check if required plotting dependencies (matplotlib and cartopy) are available.
Raises
------
ImportError
If matplotlib or cartopy is not installed
"""
if plt is None:
raise ImportError("matplotlib is required for plotting functions. Install it with 'pip install matplotlib'")
if cartopy is None:
......@@ -58,6 +66,28 @@ def get_projection(
central_latitude=0,
central_longitude=0,
):
"""
Get a cartopy projection object for map plotting.
Parameters
-----------
projection : str
Projection type ("orthographic", "robinson", "platecarree", "mollweide")
central_latitude : float, optional
Central latitude for the projection, by default 0
central_longitude : float, optional
Central longitude for the projection, by default 0
Returns
-------
cartopy.crs.Projection
Cartopy projection object
Raises
------
ValueError
If projection type is not supported
"""
if projection == "orthographic":
proj = ccrs.Orthographic(central_latitude=central_latitude, central_longitude=central_longitude)
elif projection == "robinson":
......@@ -77,6 +107,40 @@ def plot_sphere(
):
"""
Plots a function defined on the sphere using pcolormesh
Parameters
-----------
data : numpy.ndarray or torch.Tensor
Data to plot with shape (nlat, nlon)
fig : matplotlib.figure.Figure, optional
Figure to plot on, by default None (creates new figure)
projection : str, optional
Map projection type, by default "robinson"
cmap : str, optional
Colormap name, by default "RdBu"
title : str, optional
Plot title, by default None
colorbar : bool, optional
Whether to add a colorbar, by default False
coastlines : bool, optional
Whether to add coastlines, by default False
gridlines : bool, optional
Whether to add gridlines, by default False
central_latitude : float, optional
Central latitude for projection, by default 0
central_longitude : float, optional
Central longitude for projection, by default 0
lon : numpy.ndarray, optional
Longitude coordinates, by default None (auto-generated)
lat : numpy.ndarray, optional
Latitude coordinates, by default None (auto-generated)
**kwargs
Additional arguments passed to pcolormesh
Returns
-------
matplotlib.collections.QuadMesh
The plotted image object
"""
# make sure cartopy exist
......@@ -126,6 +190,28 @@ def plot_sphere(
def imshow_sphere(data, fig=None, projection="robinson", title=None, central_latitude=0, central_longitude=0, **kwargs):
"""
Displays an image on the sphere
Parameters
-----------
data : numpy.ndarray or torch.Tensor
Data to display with shape (nlat, nlon)
fig : matplotlib.figure.Figure, optional
Figure to plot on, by default None (creates new figure)
projection : str, optional
Map projection type, by default "robinson"
title : str, optional
Plot title, by default None
central_latitude : float, optional
Central latitude for projection, by default 0
central_longitude : float, optional
Central longitude for projection, by default 0
**kwargs
Additional arguments passed to imshow
Returns
-------
matplotlib.image.AxesImage
The displayed image object
"""
# make sure cartopy exist
......
......@@ -37,6 +37,32 @@ import torch
def _precompute_grid(n: int, grid: Optional[str]="equidistant", a: Optional[float]=0.0, b: Optional[float]=1.0,
periodic: Optional[bool]=False) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Precompute grid points and weights for various quadrature rules.
Parameters
-----------
n : int
Number of grid points
grid : str, optional
Grid type ("equidistant", "legendre-gauss", "lobatto", "equiangular"), by default "equidistant"
a : float, optional
Lower bound of interval, by default 0.0
b : float, optional
Upper bound of interval, by default 1.0
periodic : bool, optional
Whether the grid is periodic (only for equidistant), by default False
Returns
-------
Tuple[torch.Tensor, torch.Tensor]
Grid points and weights
Raises
------
ValueError
If periodic is True for non-equidistant grids or unknown grid type
"""
if (grid != "equidistant") and periodic:
raise ValueError(f"Periodic grid is only supported on equidistant grids.")
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment