Commit e4879676 authored by apaaris's avatar apaaris Committed by Boris Bonev
Browse files

Added docstrings to many methods

parent b5c410c0
...@@ -54,6 +54,46 @@ def _compute_cutoff_radius(nlat, kernel_shape, basis_type): ...@@ -54,6 +54,46 @@ def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
class DownsamplingBlock(nn.Module): class DownsamplingBlock(nn.Module):
"""
Downsampling block for spherical U-Net architecture.
This block performs convolution operations followed by downsampling on spherical data,
using discrete-continuous convolutions to maintain spectral properties.
Parameters
-----------
in_shape : tuple
Input shape (nlat, nlon)
out_shape : tuple
Output shape (nlat, nlon)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
nrep : int, optional
Number of convolution repetitions, by default 1
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
activation : nn.Module, optional
Activation function, by default nn.ReLU
transform_skip : bool, optional
Whether to transform skip connection, by default False
drop_conv_rate : float, optional
Dropout rate for convolutions, by default 0.0
drop_path_rate : float, optional
Drop path rate, by default 0.0
drop_dense_rate : float, optional
Dropout rate for dense layers, by default 0.0
downsampling_mode : str, optional
Downsampling mode ("bilinear", "conv"), by default "bilinear"
"""
def __init__( def __init__(
self, self,
in_shape, in_shape,
...@@ -154,12 +194,33 @@ class DownsamplingBlock(nn.Module): ...@@ -154,12 +194,33 @@ class DownsamplingBlock(nn.Module):
self.apply(self._init_weights) self.apply(self._init_weights)
def _init_weights(self, m): def _init_weights(self, m):
"""
Initialize weights for the module.
Parameters
-----------
m : nn.Module
Module to initialize
"""
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02) nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None: if m.bias is not None:
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the downsampling block.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Downsampled tensor
"""
# skip connection # skip connection
residual = x residual = x
if hasattr(self, "transform_skip"): if hasattr(self, "transform_skip"):
...@@ -178,6 +239,46 @@ class DownsamplingBlock(nn.Module): ...@@ -178,6 +239,46 @@ class DownsamplingBlock(nn.Module):
class UpsamplingBlock(nn.Module): class UpsamplingBlock(nn.Module):
"""
Upsampling block for spherical U-Net architecture.
This block performs upsampling followed by convolution operations on spherical data,
using discrete-continuous convolutions to maintain spectral properties.
Parameters
-----------
in_shape : tuple
Input shape (nlat, nlon)
out_shape : tuple
Output shape (nlat, nlon)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
nrep : int, optional
Number of convolution repetitions, by default 1
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
activation : nn.Module, optional
Activation function, by default nn.ReLU
transform_skip : bool, optional
Whether to transform skip connection, by default False
drop_conv_rate : float, optional
Dropout rate for convolutions, by default 0.0
drop_path_rate : float, optional
Drop path rate, by default 0.0
drop_dense_rate : float, optional
Dropout rate for dense layers, by default 0.0
upsampling_mode : str, optional
Upsampling mode ("bilinear", "conv"), by default "bilinear"
"""
def __init__( def __init__(
self, self,
in_shape, in_shape,
...@@ -496,6 +597,14 @@ class SphericalUNet(nn.Module): ...@@ -496,6 +597,14 @@ class SphericalUNet(nn.Module):
self.apply(self._init_weights) self.apply(self._init_weights)
def _init_weights(self, m): def _init_weights(self, m):
"""
Initialize weights for the module.
Parameters
-----------
m : nn.Module
Module to initialize
"""
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02) nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None: if m.bias is not None:
...@@ -505,7 +614,19 @@ class SphericalUNet(nn.Module): ...@@ -505,7 +614,19 @@ class SphericalUNet(nn.Module):
nn.init.constant_(m.weight, 1.0) nn.init.constant_(m.weight, 1.0)
def forward(self, x): def forward(self, x):
"""
Forward pass through the complete spherical U-Net model.
Parameters
-----------
x : torch.Tensor
Input tensor of shape (batch_size, in_chans, height, width)
Returns
-------
torch.Tensor
Output tensor of shape (batch_size, out_chans, height, width)
"""
# encoder: # encoder:
features = [] features = []
feat = x feat = x
......
...@@ -118,9 +118,20 @@ class SphericalFourierNeuralOperatorBlock(nn.Module): ...@@ -118,9 +118,20 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
else: else:
raise ValueError(f"Unknown skip connection type {outer_skip}") raise ValueError(f"Unknown skip connection type {outer_skip}")
def forward(self, x): def forward(self, x):
"""
Forward pass through the SFNO block.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Output tensor after processing through the block
"""
x, residual = self.global_conv(x) x, residual = self.global_conv(x)
x = self.norm(x) x = self.norm(x)
...@@ -147,8 +158,12 @@ class SphericalFourierNeuralOperator(nn.Module): ...@@ -147,8 +158,12 @@ class SphericalFourierNeuralOperator(nn.Module):
Parameters Parameters
---------- ----------
img_shape : tuple, optional img_size : tuple, optional
Shape of the input channels, by default (128, 256) Shape of the input channels, by default (128, 256)
grid : str, optional
Input grid type, by default "equiangular"
grid_internal : str, optional
Internal grid type for computations, by default "legendre-gauss"
scale_factor : int, optional scale_factor : int, optional
Scale factor to use, by default 3 Scale factor to use, by default 3
in_chans : int, optional in_chans : int, optional
...@@ -172,20 +187,20 @@ class SphericalFourierNeuralOperator(nn.Module): ...@@ -172,20 +187,20 @@ class SphericalFourierNeuralOperator(nn.Module):
drop_path_rate : float, optional drop_path_rate : float, optional
Dropout path rate, by default 0.0 Dropout path rate, by default 0.0
normalization_layer : str, optional normalization_layer : str, optional
Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm" Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "none"
hard_thresholding_fraction : float, optional hard_thresholding_fraction : float, optional
Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0 Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0
residual_prediction : bool, optional residual_prediction : bool, optional
Whether to add a single large skip connection, by default True Whether to add a single large skip connection, by default False
pos_embed : bool, optional pos_embed : str, optional
Whether to use positional embedding, by default True Type of positional embedding to use, by default "none"
bias : bool, optional bias : bool, optional
Whether to use a bias, by default False Whether to use a bias, by default False
Example: Example:
-------- --------
>>> model = SphericalFourierNeuralOperator( >>> model = SphericalFourierNeuralOperator(
... img_shape=(128, 256), ... img_size=(128, 256),
... scale_factor=4, ... scale_factor=4,
... in_chans=2, ... in_chans=2,
... out_chans=2, ... out_chans=2,
...@@ -355,10 +370,30 @@ class SphericalFourierNeuralOperator(nn.Module): ...@@ -355,10 +370,30 @@ class SphericalFourierNeuralOperator(nn.Module):
@torch.jit.ignore @torch.jit.ignore
def no_weight_decay(self): def no_weight_decay(self):
return {"pos_embed", "cls_token"} """
Return a set of parameter names that should not be decayed.
Returns
-------
set
Set of parameter names to exclude from weight decay
"""
return {"pos_embed.pos_embed"}
def forward_features(self, x): def forward_features(self, x):
"""
Forward pass through the feature extraction layers.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Features after processing through the network
"""
x = self.pos_drop(x) x = self.pos_drop(x)
for blk in self.blocks: for blk in self.blocks:
...@@ -367,7 +402,19 @@ class SphericalFourierNeuralOperator(nn.Module): ...@@ -367,7 +402,19 @@ class SphericalFourierNeuralOperator(nn.Module):
return x return x
def forward(self, x): def forward(self, x):
"""
Forward pass through the complete SFNO model.
Parameters
-----------
x : torch.Tensor
Input tensor of shape (batch_size, in_chans, height, width)
Returns
-------
torch.Tensor
Output tensor of shape (batch_size, out_chans, height, width)
"""
if self.residual_prediction: if self.residual_prediction:
residual = x residual = x
......
...@@ -42,7 +42,27 @@ import numpy as np ...@@ -42,7 +42,27 @@ import numpy as np
class SphereSolver(nn.Module): class SphereSolver(nn.Module):
""" """
Solver class on the sphere. Can solve the following PDEs: Solver class on the sphere. Can solve the following PDEs:
- Allen-Cahn eq - Allen-Cahn equation
- Ginzburg-Landau equation
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
dt : float
Time step size
lmax : int, optional
Maximum l mode for spherical harmonics, by default None
mmax : int, optional
Maximum m mode for spherical harmonics, by default None
grid : str, optional
Grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
radius : float, optional
Radius of the sphere, by default 1.0
coeff : float, optional
Coefficient for the PDE, by default 0.001
""" """
def __init__(self, nlat, nlon, dt, lmax=None, mmax=None, grid="equiangular", radius=1.0, coeff=0.001): def __init__(self, nlat, nlon, dt, lmax=None, mmax=None, grid="equiangular", radius=1.0, coeff=0.001):
...@@ -97,17 +117,58 @@ class SphereSolver(nn.Module): ...@@ -97,17 +117,58 @@ class SphereSolver(nn.Module):
self.register_buffer('invlap', invlap) self.register_buffer('invlap', invlap)
def grid2spec(self, u): def grid2spec(self, u):
"""spectral coefficients from spatial data""" """
Convert spatial data to spectral coefficients.
Parameters
-----------
u : torch.Tensor
Spatial data tensor
Returns
-------
torch.Tensor
Spectral coefficients
"""
return self.sht(u) return self.sht(u)
def spec2grid(self, uspec): def spec2grid(self, uspec):
"""spatial data from spectral coefficients""" """
Convert spectral coefficients to spatial data.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients tensor
Returns
-------
torch.Tensor
Spatial data
"""
return self.isht(uspec) return self.isht(uspec)
def dudtspec(self, uspec, pde='allen-cahn'): def dudtspec(self, uspec, pde='allen-cahn'):
"""
Compute the time derivative of spectral coefficients for different PDEs.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients
pde : str, optional
PDE type ("allen-cahn", "ginzburg-landau"), by default "allen-cahn"
Returns
-------
torch.Tensor
Time derivative of spectral coefficients
Raises
------
NotImplementedError
If PDE type is not supported
"""
if pde == 'allen-cahn': if pde == 'allen-cahn':
ugrid = self.spec2grid(uspec) ugrid = self.spec2grid(uspec)
u3spec = self.grid2spec(ugrid**3) u3spec = self.grid2spec(ugrid**3)
...@@ -117,20 +178,55 @@ class SphereSolver(nn.Module): ...@@ -117,20 +178,55 @@ class SphereSolver(nn.Module):
u3spec = self.grid2spec(ugrid**3) u3spec = self.grid2spec(ugrid**3)
dudtspec = uspec + (1. + 2.j)*self.coeff*self.lap*uspec - (1. + 2.j)*u3spec dudtspec = uspec + (1. + 2.j)*self.coeff*self.lap*uspec - (1. + 2.j)*u3spec
else: else:
NotImplementedError raise NotImplementedError(f"PDE type {pde} not implemented")
return dudtspec return dudtspec
def randspec(self): def randspec(self):
"""random data on the sphere""" """
Generate random spectral data on the sphere.
Returns
-------
torch.Tensor
Random spectral coefficients
"""
rspec = torch.randn_like(self.lap) / 4 / torch.pi rspec = torch.randn_like(self.lap) / 4 / torch.pi
return rspec return rspec
def plot_griddata(self, data, fig, cmap='twilight_shifted', vmax=None, vmin=None, projection='3d', title=None, antialiased=False): def plot_griddata(self, data, fig, cmap='twilight_shifted', vmax=None, vmin=None, projection='3d', title=None, antialiased=False):
""" """
plotting routine for data on the grid. Requires cartopy for 3d plots. Plot data on the sphere grid. Requires cartopy for 3d plots.
Parameters
-----------
data : torch.Tensor
Data to plot
fig : matplotlib.figure.Figure
Figure to plot on
cmap : str, optional
Colormap name, by default 'twilight_shifted'
vmax : float, optional
Maximum value for color scaling, by default None
vmin : float, optional
Minimum value for color scaling, by default None
projection : str, optional
Projection type ("mollweide", "3d"), by default "3d"
title : str, optional
Plot title, by default None
antialiased : bool, optional
Whether to use antialiasing, by default False
Returns
-------
matplotlib.collections.QuadMesh
The plotted image object
Raises
------
NotImplementedError
If projection type is not supported
""" """
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
...@@ -172,9 +268,26 @@ class SphereSolver(nn.Module): ...@@ -172,9 +268,26 @@ class SphereSolver(nn.Module):
plt.title(title, y=1.05) plt.title(title, y=1.05)
else: else:
raise NotImplementedError raise NotImplementedError(f"Projection {projection} not implemented")
return im return im
def plot_specdata(self, data, fig, **kwargs): def plot_specdata(self, data, fig, **kwargs):
"""
Plot spectral data by converting to spatial data first.
Parameters
-----------
data : torch.Tensor
Spectral data to plot
fig : matplotlib.figure.Figure
Figure to plot on
**kwargs
Additional arguments passed to plot_griddata
Returns
-------
matplotlib.collections.QuadMesh
The plotted image object
"""
return self.plot_griddata(self.isht(data), fig, **kwargs) return self.plot_griddata(self.isht(data), fig, **kwargs)
...@@ -41,7 +41,35 @@ import numpy as np ...@@ -41,7 +41,35 @@ import numpy as np
class ShallowWaterSolver(nn.Module): class ShallowWaterSolver(nn.Module):
""" """
SWE solver class. Interface inspired bu pyspharm and SHTns Shallow Water Equations (SWE) solver class for spherical geometry.
Interface inspired by pyspharm and SHTns. Solves the shallow water equations
on a rotating sphere using spectral methods.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
dt : float
Time step size
lmax : int, optional
Maximum l mode for spherical harmonics, by default None
mmax : int, optional
Maximum m mode for spherical harmonics, by default None
grid : str, optional
Grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
radius : float, optional
Radius of the sphere in meters, by default 6.37122E6 (Earth radius)
omega : float, optional
Angular velocity of rotation in rad/s, by default 7.292E-5 (Earth)
gravity : float, optional
Gravitational acceleration in m/s², by default 9.80616
havg : float, optional
Average height in meters, by default 10.e3
hamp : float, optional
Height amplitude in meters, by default 120.
""" """
def __init__(self, nlat, nlon, dt, lmax=None, mmax=None, grid="equiangular", radius=6.37122E6, \ def __init__(self, nlat, nlon, dt, lmax=None, mmax=None, grid="equiangular", radius=6.37122E6, \
...@@ -115,30 +143,82 @@ class ShallowWaterSolver(nn.Module): ...@@ -115,30 +143,82 @@ class ShallowWaterSolver(nn.Module):
def grid2spec(self, ugrid): def grid2spec(self, ugrid):
""" """
spectral coefficients from spatial data Convert spatial data to spectral coefficients.
Parameters
-----------
ugrid : torch.Tensor
Spatial data tensor
Returns
-------
torch.Tensor
Spectral coefficients
""" """
return self.sht(ugrid) return self.sht(ugrid)
def spec2grid(self, uspec): def spec2grid(self, uspec):
""" """
spatial data from spectral coefficients Convert spectral coefficients to spatial data.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients tensor
Returns
-------
torch.Tensor
Spatial data
""" """
return self.isht(uspec) return self.isht(uspec)
def vrtdivspec(self, ugrid): def vrtdivspec(self, ugrid):
"""spatial data from spectral coefficients""" """
Compute vorticity and divergence from velocity field.
Parameters
-----------
ugrid : torch.Tensor
Velocity field in spatial coordinates
Returns
-------
torch.Tensor
Spectral coefficients of vorticity and divergence
"""
vrtdivspec = self.lap * self.radius * self.vsht(ugrid) vrtdivspec = self.lap * self.radius * self.vsht(ugrid)
return vrtdivspec return vrtdivspec
def getuv(self, vrtdivspec): def getuv(self, vrtdivspec):
""" """
compute wind vector from spectral coeffs of vorticity and divergence Compute wind vector from spectral coefficients of vorticity and divergence.
Parameters
-----------
vrtdivspec : torch.Tensor
Spectral coefficients of vorticity and divergence
Returns
-------
torch.Tensor
Wind vector field in spatial coordinates
""" """
return self.ivsht( self.invlap * vrtdivspec / self.radius) return self.ivsht( self.invlap * vrtdivspec / self.radius)
def gethuv(self, uspec): def gethuv(self, uspec):
""" """
compute wind vector from spectral coeffs of vorticity and divergence Compute height and wind vector from spectral coefficients.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients [height, vorticity, divergence]
Returns
-------
torch.Tensor
Combined height and wind vector field
""" """
hgrid = self.spec2grid(uspec[:1]) hgrid = self.spec2grid(uspec[:1])
uvgrid = self.getuv(uspec[1:]) uvgrid = self.getuv(uspec[1:])
...@@ -146,7 +226,17 @@ class ShallowWaterSolver(nn.Module): ...@@ -146,7 +226,17 @@ class ShallowWaterSolver(nn.Module):
def potential_vorticity(self, uspec): def potential_vorticity(self, uspec):
""" """
Compute potential vorticity Compute potential vorticity from spectral coefficients.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients [height, vorticity, divergence]
Returns
-------
torch.Tensor
Potential vorticity field
""" """
ugrid = self.spec2grid(uspec) ugrid = self.spec2grid(uspec)
pvrt = (0.5 * self.havg * self.gravity / self.omega) * (ugrid[1] + self.coriolis) / ugrid[0] pvrt = (0.5 * self.havg * self.gravity / self.omega) * (ugrid[1] + self.coriolis) / ugrid[0]
...@@ -154,7 +244,17 @@ class ShallowWaterSolver(nn.Module): ...@@ -154,7 +244,17 @@ class ShallowWaterSolver(nn.Module):
def dimensionless(self, uspec): def dimensionless(self, uspec):
""" """
Remove dimensions from variables Remove dimensions from variables for dimensionless analysis.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients with dimensions
Returns
-------
torch.Tensor
Dimensionless spectral coefficients
""" """
uspec[0] = (uspec[0] - self.havg * self.gravity) / self.hamp / self.gravity uspec[0] = (uspec[0] - self.havg * self.gravity) / self.hamp / self.gravity
# vorticity is measured in 1/s so we normalize using sqrt(g h) / r # vorticity is measured in 1/s so we normalize using sqrt(g h) / r
...@@ -163,9 +263,18 @@ class ShallowWaterSolver(nn.Module): ...@@ -163,9 +263,18 @@ class ShallowWaterSolver(nn.Module):
def dudtspec(self, uspec): def dudtspec(self, uspec):
""" """
Compute time derivatives from solution represented in spectral coefficients Compute time derivatives from solution represented in spectral coefficients.
"""
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients [height, vorticity, divergence]
Returns
-------
torch.Tensor
Time derivatives of spectral coefficients
"""
dudtspec = torch.zeros_like(uspec) dudtspec = torch.zeros_like(uspec)
# compute the derivatives - this should be incorporated into the solver: # compute the derivatives - this should be incorporated into the solver:
...@@ -191,10 +300,15 @@ class ShallowWaterSolver(nn.Module): ...@@ -191,10 +300,15 @@ class ShallowWaterSolver(nn.Module):
def galewsky_initial_condition(self): def galewsky_initial_condition(self):
""" """
Initializes non-linear barotropically unstable shallow water test case of Galewsky et al. (2004, Tellus, 56A, 429-440). Initialize non-linear barotropically unstable shallow water test case of Galewsky et al. (2004, Tellus, 56A, 429-440).
[1] Galewsky; An initial-value problem for testing numerical models of the global shallow-water equations; [1] Galewsky; An initial-value problem for testing numerical models of the global shallow-water equations;
DOI: 10.1111/j.1600-0870.2004.00071.x; http://www-vortex.mcs.st-and.ac.uk/~rks/reprints/galewsky_etal_tellus_2004.pdf DOI: 10.1111/j.1600-0870.2004.00071.x; http://www-vortex.mcs.st-and.ac.uk/~rks/reprints/galewsky_etal_tellus_2004.pdf
Returns
-------
torch.Tensor
Initial spectral coefficients for the Galewsky test case
""" """
device = self.lap.device device = self.lap.device
...@@ -234,7 +348,17 @@ class ShallowWaterSolver(nn.Module): ...@@ -234,7 +348,17 @@ class ShallowWaterSolver(nn.Module):
def random_initial_condition(self, mach=0.1) -> torch.Tensor: def random_initial_condition(self, mach=0.1) -> torch.Tensor:
""" """
random initial condition on the sphere Generate random initial condition on the sphere.
Parameters
-----------
mach : float, optional
Mach number for scaling the random perturbations, by default 0.1
Returns
-------
torch.Tensor
Random initial spectral coefficients
""" """
device = self.lap.device device = self.lap.device
ctype = torch.complex128 if self.lap.dtype == torch.float64 else torch.complex64 ctype = torch.complex128 if self.lap.dtype == torch.float64 else torch.complex64
......
...@@ -66,13 +66,34 @@ class Stanford2D3DSDownloader: ...@@ -66,13 +66,34 @@ class Stanford2D3DSDownloader:
""" """
def __init__(self, base_url: str = DEFAULT_BASE_URL, local_dir: str = "data"): def __init__(self, base_url: str = DEFAULT_BASE_URL, local_dir: str = "data"):
"""
Initialize the Stanford 2D3DS dataset downloader.
Parameters
-----------
base_url : str, optional
Base URL for downloading the dataset, by default DEFAULT_BASE_URL
local_dir : str, optional
Local directory to store downloaded files, by default "data"
"""
self.base_url = base_url self.base_url = base_url
self.local_dir = local_dir self.local_dir = local_dir
os.makedirs(self.local_dir, exist_ok=True) os.makedirs(self.local_dir, exist_ok=True)
def _download_file(self, filename): def _download_file(self, filename):
"""
Download a single file with progress bar and resume capability.
Parameters
-----------
filename : str
Name of the file to download
Returns
-------
str
Local path to the downloaded file
"""
import requests import requests
from tqdm import tqdm from tqdm import tqdm
...@@ -106,6 +127,19 @@ class Stanford2D3DSDownloader: ...@@ -106,6 +127,19 @@ class Stanford2D3DSDownloader:
return local_path return local_path
def _extract_tar(self, tar_path): def _extract_tar(self, tar_path):
"""
Extract a tar file and return the extracted directory name.
Parameters
-----------
tar_path : str
Path to the tar file to extract
Returns
-------
str
Name of the extracted directory
"""
import tarfile import tarfile
with tarfile.open(tar_path) as tar: with tarfile.open(tar_path) as tar:
...@@ -116,7 +150,20 @@ class Stanford2D3DSDownloader: ...@@ -116,7 +150,20 @@ class Stanford2D3DSDownloader:
return extracted_dir return extracted_dir
def download_dataset(self, file_extracted_directory_pairs=DEFAULT_TAR_FILE_PAIRS): def download_dataset(self, file_extracted_directory_pairs=DEFAULT_TAR_FILE_PAIRS):
"""
Download and extract the complete dataset.
Parameters
-----------
file_extracted_directory_pairs : list, optional
List of (filename, extracted_folder_name) pairs, by default DEFAULT_TAR_FILE_PAIRS
Returns
-------
tuple
(data_folders, class_labels) where data_folders is a list of extracted directory names
and class_labels is the semantic label mapping
"""
import requests import requests
data_folders = [] data_folders = []
...@@ -133,6 +180,23 @@ class Stanford2D3DSDownloader: ...@@ -133,6 +180,23 @@ class Stanford2D3DSDownloader:
return data_folders, class_labels return data_folders, class_labels
def _rgb_to_id(self, img, class_labels_map, class_labels_indices): def _rgb_to_id(self, img, class_labels_map, class_labels_indices):
"""
Convert RGB image to class ID using color mapping.
Parameters
-----------
img : numpy.ndarray
RGB image array
class_labels_map : list
Mapping from color values to class labels
class_labels_indices : list
List of class label indices
Returns
-------
numpy.ndarray
Class ID array
"""
# Convert to int32 first to avoid overflow # Convert to int32 first to avoid overflow
r = img[..., 0].astype(np.int32) r = img[..., 0].astype(np.int32)
g = img[..., 1].astype(np.int32) g = img[..., 1].astype(np.int32)
...@@ -167,7 +231,35 @@ class Stanford2D3DSDownloader: ...@@ -167,7 +231,35 @@ class Stanford2D3DSDownloader:
downsampling_factor: int = 16, downsampling_factor: int = 16,
remove_alpha_channel: bool = True, remove_alpha_channel: bool = True,
): ):
"""
Convert the downloaded dataset to HDF5 format for efficient loading.
Parameters
-----------
data_folders : list
List of extracted data folder names
class_labels : list
List of semantic class labels
rgb_path : str, optional
Relative path to RGB images within each data folder, by default "pano/rgb"
semantic_path : str, optional
Relative path to semantic labels within each data folder, by default "pano/semantic"
depth_path : str, optional
Relative path to depth images within each data folder, by default "pano/depth"
output_filename : str, optional
Suffix for semantic label files, by default "semantic"
dataset_file : str, optional
Output HDF5 filename, by default "stanford_2d3ds_dataset.h5"
downsampling_factor : int, optional
Factor by which to downsample images, by default 16
remove_alpha_channel : bool, optional
Whether to remove alpha channel from RGB images, by default True
Returns
-------
str
Path to the created HDF5 dataset file
"""
converted_dataset_path = os.path.join(self.local_dir, dataset_file) converted_dataset_path = os.path.join(self.local_dir, dataset_file)
from PIL import Image from PIL import Image
......
...@@ -62,12 +62,23 @@ class FilterBasis(metaclass=abc.ABCMeta): ...@@ -62,12 +62,23 @@ class FilterBasis(metaclass=abc.ABCMeta):
self, self,
kernel_shape: Union[int, Tuple[int], Tuple[int, int]], kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
): ):
"""
Initialize the filter basis.
Parameters:
kernel_shape: shape of the kernel, can be an integer or tuple of integers
"""
self.kernel_shape = kernel_shape self.kernel_shape = kernel_shape
@property @property
@abc.abstractmethod @abc.abstractmethod
def kernel_size(self): def kernel_size(self):
"""
Abstract property that should return the size of the kernel.
Returns:
int: the kernel size
"""
raise NotImplementedError raise NotImplementedError
# @abc.abstractmethod # @abc.abstractmethod
...@@ -109,7 +120,12 @@ class PiecewiseLinearFilterBasis(FilterBasis): ...@@ -109,7 +120,12 @@ class PiecewiseLinearFilterBasis(FilterBasis):
self, self,
kernel_shape: Union[int, Tuple[int], Tuple[int, int]], kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
): ):
"""
Initialize the piecewise linear filter basis.
Parameters:
kernel_shape: shape of the kernel, can be an integer or tuple of integers
"""
if isinstance(kernel_shape, int): if isinstance(kernel_shape, int):
kernel_shape = [kernel_shape] kernel_shape = [kernel_shape]
if len(kernel_shape) == 1: if len(kernel_shape) == 1:
...@@ -121,6 +137,12 @@ class PiecewiseLinearFilterBasis(FilterBasis): ...@@ -121,6 +137,12 @@ class PiecewiseLinearFilterBasis(FilterBasis):
@property @property
def kernel_size(self): def kernel_size(self):
"""
Compute the kernel size for piecewise linear basis.
Returns:
int: the kernel size
"""
return (self.kernel_shape[0] // 2) * self.kernel_shape[1] + self.kernel_shape[0] % 2 return (self.kernel_shape[0] // 2) * self.kernel_shape[1] + self.kernel_shape[0] % 2
def _compute_support_vals_isotropic(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float): def _compute_support_vals_isotropic(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float):
...@@ -225,7 +247,12 @@ class MorletFilterBasis(FilterBasis): ...@@ -225,7 +247,12 @@ class MorletFilterBasis(FilterBasis):
self, self,
kernel_shape: Union[int, Tuple[int], Tuple[int, int]], kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
): ):
"""
Initialize the Morlet filter basis.
Parameters:
kernel_shape: shape of the kernel, can be an integer or tuple of integers
"""
if isinstance(kernel_shape, int): if isinstance(kernel_shape, int):
kernel_shape = [kernel_shape, kernel_shape] kernel_shape = [kernel_shape, kernel_shape]
if len(kernel_shape) != 2: if len(kernel_shape) != 2:
...@@ -235,12 +262,38 @@ class MorletFilterBasis(FilterBasis): ...@@ -235,12 +262,38 @@ class MorletFilterBasis(FilterBasis):
@property @property
def kernel_size(self): def kernel_size(self):
"""
Compute the kernel size for Morlet basis.
Returns:
int: the kernel size
"""
return self.kernel_shape[0] * self.kernel_shape[1] return self.kernel_shape[0] * self.kernel_shape[1]
def gaussian_window(self, r: torch.Tensor, width: float = 1.0): def gaussian_window(self, r: torch.Tensor, width: float = 1.0):
"""
Compute Gaussian window function.
Parameters:
r: radial distance tensor
width: width parameter of the Gaussian
Returns:
torch.Tensor: Gaussian window values
"""
return 1 / (2 * math.pi * width**2) * torch.exp(-0.5 * r**2 / (width**2)) return 1 / (2 * math.pi * width**2) * torch.exp(-0.5 * r**2 / (width**2))
def hann_window(self, r: torch.Tensor, width: float = 1.0): def hann_window(self, r: torch.Tensor, width: float = 1.0):
"""
Compute Hann window function.
Parameters:
r: radial distance tensor
width: width parameter of the Hann window
Returns:
torch.Tensor: Hann window values
"""
return torch.cos(0.5 * torch.pi * r / width) ** 2 return torch.cos(0.5 * torch.pi * r / width) ** 2
def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float, width: float = 1.0): def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float, width: float = 1.0):
...@@ -282,7 +335,12 @@ class ZernikeFilterBasis(FilterBasis): ...@@ -282,7 +335,12 @@ class ZernikeFilterBasis(FilterBasis):
self, self,
kernel_shape: Union[int, Tuple[int]], kernel_shape: Union[int, Tuple[int]],
): ):
"""
Initialize the Zernike filter basis.
Parameters:
kernel_shape: shape of the kernel, can be an integer or tuple of integers
"""
if isinstance(kernel_shape, tuple) or isinstance(kernel_shape, list): if isinstance(kernel_shape, tuple) or isinstance(kernel_shape, list):
kernel_shape = kernel_shape[0] kernel_shape = kernel_shape[0]
if not isinstance(kernel_shape, int): if not isinstance(kernel_shape, int):
...@@ -292,9 +350,26 @@ class ZernikeFilterBasis(FilterBasis): ...@@ -292,9 +350,26 @@ class ZernikeFilterBasis(FilterBasis):
@property @property
def kernel_size(self): def kernel_size(self):
"""
Compute the kernel size for Zernike basis.
Returns:
int: the kernel size
"""
return (self.kernel_shape * (self.kernel_shape + 1)) // 2 return (self.kernel_shape * (self.kernel_shape + 1)) // 2
def zernikeradial(self, r: torch.Tensor, n: torch.Tensor, m: torch.Tensor): def zernikeradial(self, r: torch.Tensor, n: torch.Tensor, m: torch.Tensor):
"""
Compute radial Zernike polynomials.
Parameters:
r: radial distance tensor
n: principal quantum number
m: azimuthal quantum number
Returns:
torch.Tensor: radial Zernike polynomial values
"""
out = torch.zeros_like(r) out = torch.zeros_like(r)
bound = (n - m) // 2 + 1 bound = (n - m) // 2 + 1
max_bound = bound.max().item() max_bound = bound.max().item()
...@@ -307,6 +382,18 @@ class ZernikeFilterBasis(FilterBasis): ...@@ -307,6 +382,18 @@ class ZernikeFilterBasis(FilterBasis):
return out return out
def zernikepoly(self, r: torch.Tensor, phi: torch.Tensor, n: torch.Tensor, l: torch.Tensor): def zernikepoly(self, r: torch.Tensor, phi: torch.Tensor, n: torch.Tensor, l: torch.Tensor):
"""
Compute Zernike polynomials.
Parameters:
r: radial distance tensor
phi: azimuthal angle tensor
n: principal quantum number
l: azimuthal quantum number
Returns:
torch.Tensor: Zernike polynomial values
"""
m = 2 * l - n m = 2 * l - n
return torch.where(m < 0, self.zernikeradial(r, n, -m) * torch.sin(m * phi), self.zernikeradial(r, n, m) * torch.cos(m * phi)) return torch.where(m < 0, self.zernikeradial(r, n, -m) * torch.sin(m * phi), self.zernikeradial(r, n, m) * torch.cos(m * phi))
......
...@@ -47,6 +47,14 @@ except ImportError as err: ...@@ -47,6 +47,14 @@ except ImportError as err:
def check_plotting_dependencies(): def check_plotting_dependencies():
"""
Check if required plotting dependencies (matplotlib and cartopy) are available.
Raises
------
ImportError
If matplotlib or cartopy is not installed
"""
if plt is None: if plt is None:
raise ImportError("matplotlib is required for plotting functions. Install it with 'pip install matplotlib'") raise ImportError("matplotlib is required for plotting functions. Install it with 'pip install matplotlib'")
if cartopy is None: if cartopy is None:
...@@ -58,6 +66,28 @@ def get_projection( ...@@ -58,6 +66,28 @@ def get_projection(
central_latitude=0, central_latitude=0,
central_longitude=0, central_longitude=0,
): ):
"""
Get a cartopy projection object for map plotting.
Parameters
-----------
projection : str
Projection type ("orthographic", "robinson", "platecarree", "mollweide")
central_latitude : float, optional
Central latitude for the projection, by default 0
central_longitude : float, optional
Central longitude for the projection, by default 0
Returns
-------
cartopy.crs.Projection
Cartopy projection object
Raises
------
ValueError
If projection type is not supported
"""
if projection == "orthographic": if projection == "orthographic":
proj = ccrs.Orthographic(central_latitude=central_latitude, central_longitude=central_longitude) proj = ccrs.Orthographic(central_latitude=central_latitude, central_longitude=central_longitude)
elif projection == "robinson": elif projection == "robinson":
...@@ -77,6 +107,40 @@ def plot_sphere( ...@@ -77,6 +107,40 @@ def plot_sphere(
): ):
""" """
Plots a function defined on the sphere using pcolormesh Plots a function defined on the sphere using pcolormesh
Parameters
-----------
data : numpy.ndarray or torch.Tensor
Data to plot with shape (nlat, nlon)
fig : matplotlib.figure.Figure, optional
Figure to plot on, by default None (creates new figure)
projection : str, optional
Map projection type, by default "robinson"
cmap : str, optional
Colormap name, by default "RdBu"
title : str, optional
Plot title, by default None
colorbar : bool, optional
Whether to add a colorbar, by default False
coastlines : bool, optional
Whether to add coastlines, by default False
gridlines : bool, optional
Whether to add gridlines, by default False
central_latitude : float, optional
Central latitude for projection, by default 0
central_longitude : float, optional
Central longitude for projection, by default 0
lon : numpy.ndarray, optional
Longitude coordinates, by default None (auto-generated)
lat : numpy.ndarray, optional
Latitude coordinates, by default None (auto-generated)
**kwargs
Additional arguments passed to pcolormesh
Returns
-------
matplotlib.collections.QuadMesh
The plotted image object
""" """
# make sure cartopy exist # make sure cartopy exist
...@@ -126,6 +190,28 @@ def plot_sphere( ...@@ -126,6 +190,28 @@ def plot_sphere(
def imshow_sphere(data, fig=None, projection="robinson", title=None, central_latitude=0, central_longitude=0, **kwargs): def imshow_sphere(data, fig=None, projection="robinson", title=None, central_latitude=0, central_longitude=0, **kwargs):
""" """
Displays an image on the sphere Displays an image on the sphere
Parameters
-----------
data : numpy.ndarray or torch.Tensor
Data to display with shape (nlat, nlon)
fig : matplotlib.figure.Figure, optional
Figure to plot on, by default None (creates new figure)
projection : str, optional
Map projection type, by default "robinson"
title : str, optional
Plot title, by default None
central_latitude : float, optional
Central latitude for projection, by default 0
central_longitude : float, optional
Central longitude for projection, by default 0
**kwargs
Additional arguments passed to imshow
Returns
-------
matplotlib.image.AxesImage
The displayed image object
""" """
# make sure cartopy exist # make sure cartopy exist
......
...@@ -37,6 +37,32 @@ import torch ...@@ -37,6 +37,32 @@ import torch
def _precompute_grid(n: int, grid: Optional[str]="equidistant", a: Optional[float]=0.0, b: Optional[float]=1.0, def _precompute_grid(n: int, grid: Optional[str]="equidistant", a: Optional[float]=0.0, b: Optional[float]=1.0,
periodic: Optional[bool]=False) -> Tuple[torch.Tensor, torch.Tensor]: periodic: Optional[bool]=False) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Precompute grid points and weights for various quadrature rules.
Parameters
-----------
n : int
Number of grid points
grid : str, optional
Grid type ("equidistant", "legendre-gauss", "lobatto", "equiangular"), by default "equidistant"
a : float, optional
Lower bound of interval, by default 0.0
b : float, optional
Upper bound of interval, by default 1.0
periodic : bool, optional
Whether the grid is periodic (only for equidistant), by default False
Returns
-------
Tuple[torch.Tensor, torch.Tensor]
Grid points and weights
Raises
------
ValueError
If periodic is True for non-equidistant grids or unknown grid type
"""
if (grid != "equidistant") and periodic: if (grid != "equidistant") and periodic:
raise ValueError(f"Periodic grid is only supported on equidistant grids.") raise ValueError(f"Periodic grid is only supported on equidistant grids.")
......
...@@ -40,6 +40,30 @@ from torch_harmonics.quadrature import _precompute_latitudes, _precompute_longit ...@@ -40,6 +40,30 @@ from torch_harmonics.quadrature import _precompute_latitudes, _precompute_longit
class ResampleS2(nn.Module): class ResampleS2(nn.Module):
"""
Resampling module for signals on the 2-sphere.
This module provides functionality to resample spherical signals between different
grid resolutions and grid types using bilinear interpolation.
Parameters
-----------
nlat_in : int
Number of latitude points in the input grid
nlon_in : int
Number of longitude points in the input grid
nlat_out : int
Number of latitude points in the output grid
nlon_out : int
Number of longitude points in the output grid
grid_in : str, optional
Input grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
grid_out : str, optional
Output grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
mode : str, optional
Interpolation mode ("bilinear", "bilinear-spherical"), by default "bilinear"
"""
def __init__( def __init__(
self, self,
nlat_in: int, nlat_in: int,
...@@ -119,6 +143,19 @@ class ResampleS2(nn.Module): ...@@ -119,6 +143,19 @@ class ResampleS2(nn.Module):
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}" return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}"
def _upscale_longitudes(self, x: torch.Tensor): def _upscale_longitudes(self, x: torch.Tensor):
"""
Interpolate the input tensor along the longitude dimension.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (..., nlat, nlon)
Returns
-------
torch.Tensor
Interpolated tensor along longitude dimension
"""
# do the interpolation in precision of x # do the interpolation in precision of x
lwgt = self.lon_weights.to(x.dtype) lwgt = self.lon_weights.to(x.dtype)
if self.mode == "bilinear": if self.mode == "bilinear":
...@@ -133,6 +170,19 @@ class ResampleS2(nn.Module): ...@@ -133,6 +170,19 @@ class ResampleS2(nn.Module):
return x return x
def _expand_poles(self, x: torch.Tensor): def _expand_poles(self, x: torch.Tensor):
"""
Expand the input tensor to include pole points for interpolation.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (..., nlat, nlon)
Returns
-------
torch.Tensor
Expanded tensor with pole points added
"""
x_north = x[..., 0, :].mean(dim=-1, keepdims=True) x_north = x[..., 0, :].mean(dim=-1, keepdims=True)
x_south = x[..., -1, :].mean(dim=-1, keepdims=True) x_south = x[..., -1, :].mean(dim=-1, keepdims=True)
x = nn.functional.pad(x, pad=[0, 0, 1, 1], mode='constant') x = nn.functional.pad(x, pad=[0, 0, 1, 1], mode='constant')
...@@ -142,6 +192,19 @@ class ResampleS2(nn.Module): ...@@ -142,6 +192,19 @@ class ResampleS2(nn.Module):
return x return x
def _upscale_latitudes(self, x: torch.Tensor): def _upscale_latitudes(self, x: torch.Tensor):
"""
Interpolate the input tensor along the latitude dimension.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (..., nlat, nlon)
Returns
-------
torch.Tensor
Interpolated tensor along latitude dimension
"""
# do the interpolation in precision of x # do the interpolation in precision of x
lwgt = self.lat_weights.to(x.dtype) lwgt = self.lat_weights.to(x.dtype)
if self.mode == "bilinear": if self.mode == "bilinear":
...@@ -156,6 +219,19 @@ class ResampleS2(nn.Module): ...@@ -156,6 +219,19 @@ class ResampleS2(nn.Module):
return x return x
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
"""
Forward pass of the resampling module.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (..., nlat_in, nlon_in)
Returns
-------
torch.Tensor
Resampled tensor with shape (..., nlat_out, nlon_out)
"""
if self.skip_resampling: if self.skip_resampling:
return x return x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment