Commit e4879676 authored by apaaris's avatar apaaris Committed by Boris Bonev

Added docstrings to many methods

parent b5c410c0
......@@ -41,6 +41,24 @@ from functools import partial
class OverlapPatchMerging(nn.Module):
"""
OverlapPatchMerging layer for merging patches.
Parameters
-----------
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
kernel_shape : tuple
Kernel shape for convolution
bias : bool, optional
Whether to use bias, by default False
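Example
-------
Illustrative sketch (hypothetical shapes and channel counts):
>>> layer = OverlapPatchMerging(in_shape=(721, 1440), out_shape=(361, 720),
...                             in_channels=3, out_channels=64, kernel_shape=(3, 3))
>>> y = layer(torch.randn(1, 3, 721, 1440))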
"""
def __init__(
self,
in_shape=(721, 1440),
......@@ -72,11 +90,32 @@ class OverlapPatchMerging(nn.Module):
self.apply(self._init_weights)
def _init_weights(self, m):
"""
Initialize weights for the module.
Parameters
-----------
m : torch.nn.Module
Module to initialize weights for
"""
if isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x):
"""
Forward pass through the OverlapPatchMerging layer.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Output tensor after patch merging
"""
x = self.conv(x)
# permute
......@@ -88,6 +127,30 @@ class OverlapPatchMerging(nn.Module):
class MixFFN(nn.Module):
"""
MixFFN module combining MLP and depthwise convolution.
Parameters
-----------
shape : tuple
Input shape (height, width)
inout_channels : int
Number of input/output channels
hidden_channels : int
Number of hidden channels in MLP
mlp_bias : bool, optional
Whether to use bias in MLP layers, by default True
kernel_shape : tuple, optional
Kernel shape for depthwise convolution, by default (3, 3)
conv_bias : bool, optional
Whether to use bias in convolution, by default False
activation : callable, optional
Activation function, by default nn.GELU
use_mlp : bool, optional
Whether to use MLP instead of linear layers, by default False
drop_path : float, optional
Drop path rate, by default 0.0
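Example
-------
Illustrative sketch (hypothetical shape and channel counts; a channels-first input is assumed here):
>>> ffn = MixFFN(shape=(45, 90), inout_channels=64, hidden_channels=128)
>>> y = ffn(torch.randn(1, 64, 45, 90))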
"""
def __init__(
self,
shape,
......@@ -124,6 +187,14 @@ class MixFFN(nn.Module):
self.apply(self._init_weights)
def _init_weights(self, m):
"""
Initialize weights for the module.
Parameters
-----------
m : torch.nn.Module
Module to initialize weights for
"""
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
......@@ -133,7 +204,19 @@ class MixFFN(nn.Module):
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the MixFFN module.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Output tensor after processing
"""
residual = x
# norm
......@@ -162,6 +245,17 @@ class GlobalAttention(nn.Module):
Input shape: (B, H, W, C)
Output shape: (B, H, W, C) with residual skip.
Parameters
-----------
chans : int
Number of channels
num_heads : int, optional
Number of attention heads, by default 8
dropout : float, optional
Dropout rate, by default 0.0
bias : bool, optional
Whether to use bias, by default True
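Example
-------
Illustrative sketch (channels-last input, matching the shape unpacked in forward):
>>> attn = GlobalAttention(chans=64, num_heads=8)
>>> y = attn(torch.randn(1, 45, 90, 64))  # (B, H, W, C)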
"""
def __init__(self, chans, num_heads=8, dropout=0.0, bias=True):
......@@ -169,6 +263,19 @@ class GlobalAttention(nn.Module):
self.attn = nn.MultiheadAttention(embed_dim=chans, num_heads=num_heads, dropout=dropout, batch_first=True, bias=bias)
def forward(self, x):
"""
Forward pass through the GlobalAttention module.
Parameters
-----------
x : torch.Tensor
Input tensor of shape (B, H, W, C)
Returns
-------
torch.Tensor
Output tensor of shape (B, H, W, C)
"""
# x: B, H, W, C
B, H, W, C = x.shape
# flatten spatial dims
......@@ -181,6 +288,30 @@ class GlobalAttention(nn.Module):
class AttentionWrapper(nn.Module):
"""
Wrapper for different attention mechanisms.
Parameters
-----------
channels : int
Number of channels
shape : tuple
Input shape (height, width)
heads : int
Number of attention heads
pre_norm : bool, optional
Whether to apply normalization before attention, by default False
attention_drop_rate : float, optional
Attention dropout rate, by default 0.0
drop_path : float, optional
Drop path rate, by default 0.0
attention_mode : str, optional
Attention mode ("neighborhood", "global"), by default "neighborhood"
kernel_shape : tuple, optional
Kernel shape for neighborhood attention, by default (7, 7)
bias : bool, optional
Whether to use bias, by default True
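Example
-------
Illustrative sketch using global attention (the neighborhood mode requires natten):
>>> wrapper = AttentionWrapper(channels=64, shape=(45, 90), heads=4,
...                            attention_mode="global")
>>> y = wrapper(torch.randn(1, 64, 45, 90))  # (B, C, H, W), permuted internally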
"""
def __init__(self, channels, shape, heads, pre_norm=False, attention_drop_rate=0.0, drop_path=0.0, attention_mode="neighborhood", kernel_shape=(7, 7), bias=True):
super().__init__()
......@@ -203,11 +334,32 @@ class AttentionWrapper(nn.Module):
self.apply(self._init_weights)
def _init_weights(self, m):
"""
Initialize weights for the module.
Parameters
-----------
m : torch.nn.Module
Module to initialize weights for
"""
if isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the AttentionWrapper.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Output tensor with residual connection
"""
residual = x
x = x.permute(0, 2, 3, 1)
if self.norm is not None:
......
......@@ -57,11 +57,46 @@ class Encoder(nn.Module):
self.conv = nn.Conv2d(in_chans, out_chans, kernel_size=kernel_shape, bias=bias, stride=(stride_h, stride_w), padding=(pad_h, pad_w), groups=groups)
def forward(self, x):
"""
Forward pass through the Encoder layer.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Output tensor after encoding
"""
x = self.conv(x)
return x
class Decoder(nn.Module):
"""
Decoder module for upsampling and feature processing.
Parameters
-----------
in_shape : tuple, optional
Input shape (height, width), by default (480, 960)
out_shape : tuple, optional
Output shape (height, width), by default (721, 1440)
in_chans : int, optional
Number of input channels, by default 2
out_chans : int, optional
Number of output channels, by default 2
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
groups : int, optional
Number of groups for convolution, by default 1
bias : bool, optional
Whether to use bias, by default False
upsampling_method : str, optional
Upsampling method ("conv", "pixel_shuffle"), by default "conv"
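Example
-------
Illustrative sketch (default shapes, hypothetical channel counts):
>>> dec = Decoder(in_shape=(480, 960), out_shape=(721, 1440), in_chans=32, out_chans=2)
>>> y = dec(torch.randn(1, 32, 480, 960))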
"""
def __init__(self, in_shape=(480, 960), out_shape=(721, 1440), in_chans=2, out_chans=2, kernel_shape=(3, 3), groups=1, bias=False, upsampling_method="conv"):
super().__init__()
self.out_shape = out_shape
......@@ -87,6 +122,19 @@ class Decoder(nn.Module):
raise ValueError(f"Unknown upsampling method {upsampling_method}")
def forward(self, x):
"""
Forward pass through the Decoder layer.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Output tensor after decoding and upsampling
"""
x = self.upsample(x)
return x
......@@ -97,6 +145,17 @@ class GlobalAttention(nn.Module):
Input shape: (B, H, W, C)
Output shape: (B, H, W, C) with residual skip.
Parameters
-----------
chans : int
Number of channels
num_heads : int, optional
Number of attention heads, by default 8
dropout : float, optional
Dropout rate, by default 0.0
bias : bool, optional
Whether to use bias, by default True
"""
def __init__(self, chans, num_heads=8, dropout=0.0, bias=True):
......@@ -104,6 +163,19 @@ class GlobalAttention(nn.Module):
self.attn = nn.MultiheadAttention(embed_dim=chans, num_heads=num_heads, dropout=dropout, batch_first=True, bias=bias)
def forward(self, x):
"""
Forward pass through the GlobalAttention module.
Parameters
-----------
x : torch.Tensor
Input tensor of shape (B, H, W, C)
Returns
-------
torch.Tensor
Output tensor of shape (B, H, W, C)
"""
# x: B, H, W, C
B, H, W, C = x.shape
# flatten spatial dims
......@@ -118,8 +190,36 @@ class GlobalAttention(nn.Module):
class AttentionBlock(nn.Module):
"""
Neighborhood attention block based on Natten.
"""
Parameters
-----------
in_shape : tuple, optional
Input shape (height, width), by default (480, 960)
out_shape : tuple, optional
Output shape (height, width), by default (480, 960)
chans : int, optional
Number of channels, by default 2
num_heads : int, optional
Number of attention heads, by default 1
mlp_ratio : float, optional
Ratio of MLP hidden dim to input dim, by default 2.0
drop_rate : float, optional
Dropout rate, by default 0.0
drop_path : float, optional
Drop path rate, by default 0.0
act_layer : callable, optional
Activation function, by default nn.GELU
norm_layer : str, optional
Normalization layer type, by default "none"
use_mlp : bool, optional
Whether to use MLP, by default True
bias : bool, optional
Whether to use bias, by default True
attention_mode : str, optional
Attention mode ("neighborhood", "global"), by default "neighborhood"
attn_kernel_shape : tuple, optional
Kernel shape for neighborhood attention, by default (7, 7)
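Example
-------
Illustrative sketch using global attention (hypothetical shapes):
>>> block = AttentionBlock(in_shape=(45, 90), out_shape=(45, 90), chans=32,
...                        num_heads=4, attention_mode="global")
>>> y = block(torch.randn(1, 32, 45, 90))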
"""
def __init__(
self,
in_shape=(480, 960),
......@@ -186,7 +286,19 @@ class AttentionBlock(nn.Module):
self.skip1 = nn.Identity()
def forward(self, x):
"""
Forward pass through the AttentionBlock.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Output tensor with residual connections
"""
residual = x
x = self.norm0(x)
......
......@@ -140,12 +140,33 @@ class DownsamplingBlock(nn.Module):
self.apply(self._init_weights)
def _init_weights(self, m):
"""
Initialize weights for the module.
Parameters
-----------
m : torch.nn.Module
Module to initialize weights for
"""
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the DownsamplingBlock.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Output tensor after downsampling
"""
# skip connection
residual = x
if hasattr(self, "transform_skip"):
......@@ -166,6 +187,36 @@ class DownsamplingBlock(nn.Module):
class UpsamplingBlock(nn.Module):
"""
Upsampling block for UNet architecture.
Parameters
-----------
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
nrep : int, optional
Number of repetitions of conv blocks, by default 1
kernel_shape : tuple, optional
Kernel shape for convolutions, by default (3, 3)
activation : callable, optional
Activation function, by default nn.ReLU
transform_skip : bool, optional
Whether to transform skip connections, by default False
drop_conv_rate : float, optional
Dropout rate for convolutions, by default 0.0
drop_path_rate : float, optional
Drop path rate, by default 0.0
drop_dense_rate : float, optional
Dropout rate for dense layers, by default 0.0
upsampling_mode : str, optional
Upsampling mode ("bilinear", "conv"), by default "bilinear"
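Example
-------
Illustrative sketch (hypothetical shapes and channel counts):
>>> up = UpsamplingBlock(in_shape=(32, 64), out_shape=(64, 128),
...                      in_channels=64, out_channels=32)
>>> y = up(torch.randn(1, 64, 32, 64))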
"""
def __init__(
self,
in_shape,
......@@ -274,12 +325,33 @@ class UpsamplingBlock(nn.Module):
self.apply(self._init_weights)
def _init_weights(self, m):
"""
Initialize weights for the module.
Parameters
-----------
m : torch.nn.Module
Module to initialize weights for
"""
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the UpsamplingBlock.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Output tensor after upsampling
"""
# skip connection
residual = x
if hasattr(self, "transform_skip"):
......@@ -304,6 +376,7 @@ class UNet(nn.Module):
img_shape : tuple, optional
Shape of the input image (height, width), by default (128, 256)
kernel_shape : tuple or int
Kernel shape for convolutions
scale_factor: int, optional
Scale factor to use, by default 2
in_chans : int, optional
......@@ -336,11 +409,12 @@ class UNet(nn.Module):
... scale_factor=4,
... in_chans=2,
... num_classes=2,
... embed_dims=[16, 32, 64, 128],
... depths=[2, 2, 2, 2],
... use_mlp=True,)
>>> model(torch.randn(1, 2, 128, 256)).shape
torch.Size([1, 2, 128, 256])
"""
def __init__(
self,
img_shape=(128, 256),
......@@ -440,6 +514,14 @@ class UNet(nn.Module):
self.apply(self._init_weights)
def _init_weights(self, m):
"""
Initialize weights for the module.
Parameters
-----------
m : torch.nn.Module
Module to initialize weights for
"""
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=.02)
if m.bias is not None:
......@@ -450,7 +532,19 @@ class UNet(nn.Module):
def forward(self, x):
"""
Forward pass through the UNet model.
Parameters
-----------
x : torch.Tensor
Input tensor of shape (batch_size, in_chans, height, width)
Returns
-------
torch.Tensor
Output tensor of shape (batch_size, num_classes, height, width)
"""
# encoder:
features = []
feat = x
......
......@@ -63,13 +63,37 @@ import wandb
# helper routine for counting the number of parameters in a model
def count_parameters(model):
"""
Count the number of trainable parameters in a model.
Parameters
-----------
model : torch.nn.Module
The model to count parameters for
Returns
-------
int
Total number of trainable parameters
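Example
-------
A minimal, self-contained check:
>>> import torch.nn as nn
>>> count_parameters(nn.Linear(10, 10))  # 10*10 weights + 10 biases
110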
"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
# convenience function for logging weights and gradients
def log_weights_and_grads(exp_dir, model, iters=1):
"""
Helper routine intended for debugging purposes.
Saves model weights and gradients to a file for analysis.
Parameters
-----------
exp_dir : str
Experiment directory to save logs in
model : torch.nn.Module
Model whose weights and gradients to log
iters : int, optional
Current iteration number, by default 1
"""
log_path = os.path.join(exp_dir, "weights_and_grads")
if not os.path.isdir(log_path):
......@@ -97,7 +121,36 @@ def validate_model(
logging=True,
device=torch.device("cpu"),
):
"""
Validate a model on a dataset and compute metrics.
Parameters
-----------
model : torch.nn.Module
Model to validate
dataloader : torch.utils.data.DataLoader
DataLoader for validation data
loss_fn : callable
Loss function
metrics_fns : dict
Dictionary of metric functions to compute
path_root : str
Root path for saving validation outputs
normalization_in : callable, optional
Normalization function to apply to inputs, by default None
normalization_out : callable, optional
Normalization function to apply to targets, by default None
logging : bool, optional
Whether to save validation plots, by default True
device : torch.device, optional
Device to run validation on, by default torch.device("cpu")
Returns
-------
tuple
(losses, metrics) where losses is a tensor of per-sample losses
and metrics is a dict of per-sample metric values
"""
model.eval()
num_examples = len(dataloader)
......@@ -190,7 +243,50 @@ def train_model(
logging=True,
device=torch.device("cpu"),
):
"""
Train a model with the given parameters.
Parameters
-----------
model : torch.nn.Module
Model to train
train_dataloader : torch.utils.data.DataLoader
DataLoader for training data
train_sampler : torch.utils.data.Sampler
Sampler for training data
test_dataloader : torch.utils.data.DataLoader
DataLoader for test data
test_sampler : torch.utils.data.Sampler
Sampler for test data
loss_fn : callable
Loss function
metrics_fns : dict
Dictionary of metric functions to compute
optimizer : torch.optim.Optimizer
Optimizer for training
gscaler : torch.cuda.amp.GradScaler
Gradient scaler for mixed precision training
scheduler : torch.optim.lr_scheduler._LRScheduler, optional
Learning rate scheduler, by default None
normalization_in : callable, optional
Normalization function to apply to inputs, by default None
normalization_out : callable, optional
Normalization function to apply to targets, by default None
augmentation : bool, optional
Whether to apply data augmentation, by default False
nepochs : int, optional
Number of training epochs, by default 20
amp_mode : str, optional
Mixed precision mode ("none", "fp16", "bf16"), by default "none"
log_grads : int, optional
Frequency of gradient logging (0 for no logging), by default 0
exp_dir : str, optional
Experiment directory for logging, by default None
logging : bool, optional
Whether to enable logging, by default True
device : torch.device, optional
Device to train on, by default torch.device("cpu")
"""
train_start = time.time()
# set AMP type
......
......@@ -39,6 +39,76 @@ from baseline_models import Transformer, UNet, Segformer
from torch_harmonics.examples.models import SphericalFourierNeuralOperator, LocalSphericalNeuralOperator, SphericalTransformer, SphericalUNet, SphericalSegformer
def get_baseline_models(img_size=(128, 256), in_chans=3, out_chans=3, residual_prediction=False, drop_path_rate=0., grid="equiangular"):
"""
Get a registry of baseline models for spherical and planar neural networks.
This function returns a dictionary containing pre-configured model architectures
for various tasks including spherical Fourier neural operators (SFNO), local
spherical neural operators (LSNO), spherical transformers, U-Nets, and Segformers.
Each model is configured with specific hyperparameters optimized for different
computational budgets and performance requirements.
Parameters
-----------
img_size : tuple, optional
Input image size as (height, width), by default (128, 256)
in_chans : int, optional
Number of input channels, by default 3
out_chans : int, optional
Number of output channels, by default 3
residual_prediction : bool, optional
Whether to use residual prediction (add input to output), by default False
drop_path_rate : float, optional
Dropout path rate for regularization, by default 0.0
grid : str, optional
Grid type for spherical models ("equiangular", "legendre-gauss", etc.), by default "equiangular"
Returns
-------
dict
Dictionary mapping model names to partial functions that can be called
to instantiate the corresponding model with the specified parameters.
Available models include:
**Spherical Models:**
- sfno_sc2_layers4_e32: Spherical Fourier Neural Operator (small)
- lsno_sc2_layers4_e32: Local Spherical Neural Operator (small)
- s2unet_sc2_layers4_e128: Spherical U-Net (medium)
- s2transformer_sc2_layers4_e128: Spherical Transformer (global attention, medium)
- s2transformer_sc2_layers4_e256: Spherical Transformer (global attention, large)
- s2ntransformer_sc2_layers4_e128: Spherical Transformer (neighborhood attention, medium)
- s2ntransformer_sc2_layers4_e256: Spherical Transformer (neighborhood attention, large)
- s2segformer_sc2_layers4_e128: Spherical Segformer (global attention, medium)
- s2segformer_sc2_layers4_e256: Spherical Segformer (global attention, large)
- s2nsegformer_sc2_layers4_e128: Spherical Segformer (neighborhood attention, medium)
- s2nsegformer_sc2_layers4_e256: Spherical Segformer (neighborhood attention, large)
**Planar Models:**
- transformer_sc2_layers4_e128: Planar Transformer (global attention, medium)
- transformer_sc2_layers4_e256: Planar Transformer (global attention, large)
- ntransformer_sc2_layers4_e128: Planar Transformer (neighborhood attention, medium)
- ntransformer_sc2_layers4_e256: Planar Transformer (neighborhood attention, large)
- segformer_sc2_layers4_e128: Planar Segformer (global attention, medium)
- segformer_sc2_layers4_e256: Planar Segformer (global attention, large)
- nsegformer_sc2_layers4_e128: Planar Segformer (neighborhood attention, medium)
- nsegformer_sc2_layers4_e256: Planar Segformer (neighborhood attention, large)
- vit_sc2_layers4_e128: Vision Transformer variant (medium)
Examples
--------
>>> model_registry = get_baseline_models(img_size=(64, 128), in_chans=2, out_chans=1)
>>> model = model_registry['sfno_sc2_layers4_e32']()
>>> print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
Notes
-----
- Model names follow the pattern: {model_type}_sc{scale_factor}_layers{num_layers}_e{embed_dim}
- 'sc2' indicates scale factor of 2 (downsampling by 2)
- 'e32', 'e128', 'e256' indicate embedding dimensions
- 'n' prefix indicates neighborhood attention instead of global attention
- All models use GELU activation and instance normalization by default
"""
# prepare dicts containing models and corresponding metrics
model_registry = dict(
......
......@@ -68,13 +68,37 @@ import wandb
# helper routine for counting the number of parameters in a model
def count_parameters(model):
"""
Count the number of trainable parameters in a model.
Parameters
-----------
model : torch.nn.Module
The model to count parameters for
Returns
-------
int
Total number of trainable parameters
"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
# convenience function for logging weights and gradients
def log_weights_and_grads(exp_dir, model, iters=1):
"""
Helper routine intended for debugging purposes.
Saves model weights and gradients to a file for analysis.
Parameters
-----------
exp_dir : str
Experiment directory to save logs in
model : torch.nn.Module
Model whose weights and gradients to log
iters : int, optional
Current iteration number, by default 1
"""
log_path = os.path.join(exp_dir, "weights_and_grads")
if not os.path.isdir(log_path):
......@@ -92,7 +116,34 @@ def log_weights_and_grads(exp_dir, model, iters=1):
# rolls out the FNO and compares to the classical solver
def validate_model(model, dataloader, loss_fn, metrics_fns, path_root, normalization=None, logging=True, device=torch.device("cpu")):
"""
Validate a model on a dataset and compute metrics.
Parameters
-----------
model : torch.nn.Module
Model to validate
dataloader : torch.utils.data.DataLoader
DataLoader for validation data
loss_fn : callable
Loss function
metrics_fns : dict
Dictionary of metric functions to compute
path_root : str
Root path for saving validation outputs
normalization : callable, optional
Normalization function to apply to inputs, by default None
logging : bool, optional
Whether to save validation plots, by default True
device : torch.device, optional
Device to run validation on, by default torch.device("cpu")
Returns
-------
tuple
(losses, metrics) where losses is a tensor of per-sample losses
and metrics is a dict of per-sample metric values
"""
model.eval()
num_examples = len(dataloader)
......@@ -178,7 +229,50 @@ def train_model(
logging=True,
device=torch.device("cpu"),
):
"""
Train a model with the given parameters.
Parameters
-----------
model : torch.nn.Module
Model to train
train_dataloader : torch.utils.data.DataLoader
DataLoader for training data
train_sampler : torch.utils.data.Sampler
Sampler for training data
test_dataloader : torch.utils.data.DataLoader
DataLoader for test data
test_sampler : torch.utils.data.Sampler
Sampler for test data
loss_fn : callable
Loss function
metrics_fns : dict
Dictionary of metric functions to compute
optimizer : torch.optim.Optimizer
Optimizer for training
gscaler : torch.cuda.amp.GradScaler
Gradient scaler for mixed precision training
scheduler : torch.optim.lr_scheduler._LRScheduler, optional
Learning rate scheduler, by default None
max_grad_norm : float, optional
Maximum gradient norm for clipping, by default 0.0
normalization : callable, optional
Normalization function to apply to inputs, by default None
augmentation : callable, optional
Augmentation function to apply to inputs, by default None
nepochs : int, optional
Number of training epochs, by default 20
amp_mode : str, optional
Mixed precision mode ("none", "fp16", "bf16"), by default "none"
log_grads : int, optional
Frequency of gradient logging (0 for no logging), by default 0
exp_dir : str, optional
Experiment directory for logging, by default None
logging : bool, optional
Whether to enable logging, by default True
device : torch.device, optional
Device to train on, by default torch.device("cpu")
"""
train_start = time.time()
# set AMP type
......
......@@ -63,13 +63,35 @@ except:
# helper routine for counting the number of parameters in a model
def count_parameters(model):
"""
Count the number of trainable parameters in a model.
Parameters
-----------
model : torch.nn.Module
The model to count parameters for
Returns
-------
int
Total number of trainable parameters
"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
# convenience function for logging weights and gradients
def log_weights_and_grads(model, iters=1):
"""
Helper routine intended for debugging purposes.
Saves model weights and gradients to a file for analysis.
Parameters
-----------
model : torch.nn.Module
Model whose weights and gradients to log
iters : int, optional
Current iteration number, by default 1
"""
root_path = os.path.join(os.path.dirname(__file__), "weights_and_grads")
......@@ -97,7 +119,40 @@ def autoregressive_inference(
nics=50,
device=torch.device("cpu"),
):
"""
Perform autoregressive inference with a trained model and compare it to the classical solver.
Parameters
-----------
model : torch.nn.Module
Trained model to evaluate
dataset : torch.utils.data.Dataset
Dataset containing solver and normalization parameters
loss_fn : callable
Loss function for evaluation
metrics_fns : dict
Dictionary of metric functions to compute
path_root : str
Root path for saving inference outputs
nsteps : int
Number of solver steps per autoregressive step
autoreg_steps : int, optional
Number of autoregressive steps, by default 10
nskip : int, optional
Skip interval for plotting, by default 1
plot_channel : int, optional
Channel to plot, by default 0
nics : int, optional
Number of initial conditions to test, by default 50
device : torch.device, optional
Device to run inference on, by default torch.device("cpu")
Returns
-------
tuple
(losses, metrics, model_times, solver_times) where losses and metrics are tensors
of per-sample values, and model_times and solver_times are timing information
"""
model.eval()
# make output
......@@ -238,7 +293,42 @@ def train_model(
logging=True,
device=torch.device("cpu"),
):
"""
Train a model with the given parameters.
Parameters
-----------
model : torch.nn.Module
Model to train
dataloader : torch.utils.data.DataLoader
DataLoader for training data
loss_fn : callable
Loss function
metrics_fns : dict
Dictionary of metric functions to compute
optimizer : torch.optim.Optimizer
Optimizer for training
gscaler : torch.cuda.amp.GradScaler
Gradient scaler for mixed precision training
scheduler : torch.optim.lr_scheduler._LRScheduler, optional
Learning rate scheduler, by default None
nepochs : int, optional
Number of training epochs, by default 20
nfuture : int, optional
Number of future steps to predict, by default 0
num_examples : int, optional
Number of examples per epoch, by default 256
num_valid : int, optional
Number of validation examples, by default 8
amp_mode : str, optional
Mixed precision mode ("none", "fp16", "bf16"), by default "none"
log_grads : int, optional
Frequency of gradient logging (0 for no logging), by default 0
logging : bool, optional
Whether to enable logging, by default True
device : torch.device, optional
Device to train on, by default torch.device("cpu")
"""
train_start = time.time()
# set AMP type
......
......@@ -54,7 +54,22 @@ except (ImportError, TypeError, AssertionError, AttributeError) as e:
warnings.warn(f"building custom extensions skipped: {e}")
def get_compile_args(module_name):
"""If user runs build with TORCH_HARMONICS_DEBUG=1 set, it will use debugging flags to build"""
"""
Get compilation arguments based on environment variables.
If user runs build with TORCH_HARMONICS_DEBUG=1 set, it will use debugging flags to build.
If TORCH_HARMONICS_PROFILE=1 is set, it will include profiling flags.
Parameters
-----------
module_name : str
Name of the module being compiled
Returns
-------
dict
Dictionary containing compilation flags for 'cxx' and 'nvcc' compilers
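Example
-------
Illustrative sketch ("disco" is a hypothetical module name):
>>> args = get_compile_args("disco")
>>> "cxx" in args
True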
"""
debug_mode = os.environ.get('TORCH_HARMONICS_DEBUG', '0') == '1'
profile_mode = os.environ.get('TORCH_HARMONICS_PROFILE', '0') == '1'
......@@ -77,7 +92,15 @@ def get_compile_args(module_name):
}
def get_ext_modules():
"""
Get list of extension modules to compile.
Returns
-------
tuple
(ext_modules, cmdclass) where ext_modules is a list of extension modules
and cmdclass is a dictionary of build commands
"""
ext_modules = []
cmdclass = {}
......
......@@ -38,6 +38,17 @@ import torch
class TestCacheConsistency(unittest.TestCase):
def test_consistency(self, verbose=False):
"""
Test that cached values are not modified externally.
This test verifies that the LRU cache decorator properly handles
deep copying to prevent unintended modifications to cached objects.
Parameters
-----------
verbose : bool, optional
Whether to print verbose output, by default False
"""
if verbose:
print("Testing that cache values does not get modified externally")
from torch_harmonics.legendre import _precompute_legpoly
......
......@@ -62,11 +62,31 @@ def _get_psi(kernel_size: int, psi_idx: torch.Tensor, psi_vals: torch.Tensor, nl
class _DiscoS2ContractionCuda(torch.autograd.Function):
r"""
CUDA implementation of the discrete-continuous convolution contraction on the sphere.
This class provides the forward and backward passes for efficient GPU computation
of the S2 convolution operation using custom CUDA kernels.
"""
@staticmethod
@custom_fwd(device_type="cuda")
def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
kernel_size: int, nlat_out: int, nlon_out: int):
r"""
Forward pass for CUDA S2 convolution contraction.
Parameters:
x: input tensor
roff_idx: row offset indices for sparse computation
ker_idx: kernel indices
row_idx: row indices for sparse computation
col_idx: column indices for sparse computation
vals: values for sparse computation
kernel_size: size of the kernel
nlat_out: number of output latitude points
nlon_out: number of output longitude points
"""
ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
ctx.kernel_size = kernel_size
ctx.nlat_in = x.shape[-2]
......@@ -81,6 +101,15 @@ class _DiscoS2ContractionCuda(torch.autograd.Function):
@staticmethod
@custom_bwd(device_type="cuda")
def backward(ctx, grad_output):
r"""
Backward pass for CUDA S2 convolution contraction.
Parameters:
grad_output: gradient of the output
Returns:
gradient of the input
"""
roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
gtype = grad_output.dtype
grad_output = grad_output.to(torch.float32).contiguous()
......@@ -92,11 +121,31 @@ class _DiscoS2ContractionCuda(torch.autograd.Function):
class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
r"""
CUDA implementation of the transpose discrete-continuous convolution contraction on the sphere.
This class provides the forward and backward passes for efficient GPU computation
of the transpose S2 convolution operation using custom CUDA kernels.
"""
@staticmethod
@custom_fwd(device_type="cuda")
def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
kernel_size: int, nlat_out: int, nlon_out: int):
r"""
Forward pass for CUDA transpose S2 convolution contraction.
Parameters:
x: input tensor
roff_idx: row offset indices for sparse computation
ker_idx: kernel indices
row_idx: row indices for sparse computation
col_idx: column indices for sparse computation
vals: values for sparse computation
kernel_size: size of the kernel
nlat_out: number of output latitude points
nlon_out: number of output longitude points
"""
ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
ctx.kernel_size = kernel_size
ctx.nlat_in = x.shape[-2]
......@@ -111,6 +160,15 @@ class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
@staticmethod
@custom_bwd(device_type="cuda")
def backward(ctx, grad_output):
r"""
Backward pass for CUDA transpose S2 convolution contraction.
Parameters:
grad_output: gradient of the output
Returns:
gradient of the input
"""
roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
gtype = grad_output.dtype
grad_output = grad_output.to(torch.float32).contiguous()
......
......@@ -50,7 +50,40 @@ except ImportError as err:
def _neighborhood_attention_s2_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor,
quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
"""
Forward pass implementation of neighborhood attention on the sphere (S2).
This function computes the neighborhood attention operation using sparse tensor
operations. It implements the attention mechanism with softmax normalization
and quadrature weights for spherical integration.
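Schematically (a sketch; w denotes quad_weights and N(o) the neighborhood
of output location o encoded by col_idx/row_off):
    alpha(o, i) = w_i * exp(<q_o, k_i>) / sum_{j in N(o)} w_j * exp(<q_o, k_j>)
    y(o) = sum_{i in N(o)} alpha(o, i) * v_i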
Parameters
-----------
kx : torch.Tensor
Key tensor with shape (B, C, Hi, Wi) where B is batch size, C is channels,
Hi is input height (latitude), Wi is input width (longitude)
vx : torch.Tensor
Value tensor with shape (B, C, Hi, Wi)
qy : torch.Tensor
Query tensor with shape (B, C, Ho, Wo) where Ho is output height, Wo is output width
quad_weights : torch.Tensor
Quadrature weights for spherical integration with shape (Hi,)
col_idx : torch.Tensor
Column indices for sparse computation
row_off : torch.Tensor
Row offsets for sparse computation
nlon_in : int
Number of input longitude points
nlat_out : int
Number of output latitude points
nlon_out : int
Number of output longitude points
Returns
-------
torch.Tensor
Output tensor with shape (B, C, Ho, Wo) after neighborhood attention computation
"""
# prepare result tensor
y = torch.zeros_like(qy)
......@@ -102,6 +135,41 @@ def _neighborhood_attention_s2_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy:
def _neighborhood_attention_s2_bwd_dv_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
nlon_in: int, nlat_out: int, nlon_out: int):
"""
Backward pass implementation for value gradients in neighborhood attention on S2.
This function computes the gradient with respect to the value tensor (vx), given the gradient of the output.
It implements the backward pass for the neighborhood attention operation using
sparse tensor operations and quadrature weights for spherical integration.
Parameters
-----------
kx : torch.Tensor
Key tensor with shape (B, C, Hi, Wi)
vx : torch.Tensor
Value tensor with shape (B, C, Hi, Wi)
qy : torch.Tensor
Query tensor with shape (B, C, Ho, Wo)
dy : torch.Tensor
Gradient of the output with shape (B, C, Ho, Wo)
quad_weights : torch.Tensor
Quadrature weights for spherical integration with shape (Hi,)
col_idx : torch.Tensor
Column indices for sparse computation
row_off : torch.Tensor
Row offsets for sparse computation
nlon_in : int
Number of input longitude points
nlat_out : int
Number of output latitude points
nlon_out : int
Number of output longitude points
Returns
-------
torch.Tensor
Gradient of the value tensor with shape (B, C, Hi, Wi)
"""
# shapes:
# input
......@@ -170,6 +238,41 @@ def _neighborhood_attention_s2_bwd_dv_torch(kx: torch.Tensor, vx: torch.Tensor,
def _neighborhood_attention_s2_bwd_dk_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
nlon_in: int, nlat_out: int, nlon_out: int):
"""
Backward pass implementation for key gradients in neighborhood attention on S2.
This function computes the gradient with respect to the key tensor (kx), given the gradient of the output.
It implements the backward pass for the neighborhood attention operation using
sparse tensor operations and quadrature weights for spherical integration.
Parameters
-----------
kx : torch.Tensor
Key tensor with shape (B, C, Hi, Wi)
vx : torch.Tensor
Value tensor with shape (B, C, Hi, Wi)
qy : torch.Tensor
Query tensor with shape (B, C, Ho, Wo)
dy : torch.Tensor
Gradient of the output with shape (B, C, Ho, Wo)
quad_weights : torch.Tensor
Quadrature weights for spherical integration with shape (Hi,)
col_idx : torch.Tensor
Column indices for sparse computation
row_off : torch.Tensor
Row offsets for sparse computation
nlon_in : int
Number of input longitude points
nlat_out : int
Number of output latitude points
nlon_out : int
Number of output longitude points
Returns
-------
torch.Tensor
Gradient of the key tensor with shape (B, C, Hi, Wi)
"""
# shapes:
# input
......@@ -251,6 +354,41 @@ def _neighborhood_attention_s2_bwd_dk_torch(kx: torch.Tensor, vx: torch.Tensor,
def _neighborhood_attention_s2_bwd_dq_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
nlon_in: int, nlat_out: int, nlon_out: int):
"""
Backward pass implementation for query gradients in neighborhood attention on S2.
This function computes the gradient with respect to the query tensor (qy), given the gradient of the output.
It implements the backward pass for the neighborhood attention operation using
sparse tensor operations and quadrature weights for spherical integration.
Parameters
-----------
kx : torch.Tensor
Key tensor with shape (B, C, Hi, Wi)
vx : torch.Tensor
Value tensor with shape (B, C, Hi, Wi)
qy : torch.Tensor
Query tensor with shape (B, C, Ho, Wo)
dy : torch.Tensor
Gradient of the output with shape (B, C, Ho, Wo)
quad_weights : torch.Tensor
Quadrature weights for spherical integration with shape (Hi,)
col_idx : torch.Tensor
Column indices for sparse computation
row_off : torch.Tensor
Row offsets for sparse computation
nlon_in : int
Number of input longitude points
nlat_out : int
Number of output latitude points
nlon_out : int
Number of output longitude points
Returns
-------
torch.Tensor
Gradient of the query tensor with shape (B, C, Ho, Wo)
"""
# shapes:
# input
......@@ -321,6 +459,11 @@ def _neighborhood_attention_s2_bwd_dq_torch(kx: torch.Tensor, vx: torch.Tensor,
return dqy
class _NeighborhoodAttentionS2(torch.autograd.Function):
r"""
CPU implementation of neighborhood attention on the sphere (S2).
This class provides the forward and backward passes for efficient CPU computation
of neighborhood attention operations using sparse tensor operations.
"""
@staticmethod
@custom_fwd(device_type="cpu")
......@@ -329,7 +472,27 @@ class _NeighborhoodAttentionS2(torch.autograd.Function):
bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None],
quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
nh: int, nlon_in: int, nlat_out: int, nlon_out: int):
r"""
Forward pass for CPU neighborhood attention on S2.
Parameters:
k: key tensor
v: value tensor
q: query tensor
wk: key weight tensor
wv: value weight tensor
wq: query weight tensor
bk: key bias tensor (optional)
bv: value bias tensor (optional)
bq: query bias tensor (optional)
quad_weights: quadrature weights for spherical integration
col_idx: column indices for sparse computation
row_off: row offsets for sparse computation
nh: number of attention heads
nlon_in: number of input longitude points
nlat_out: number of output latitude points
nlon_out: number of output longitude points
"""
ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq)
ctx.nh = nh
ctx.nlon_in = nlon_in
......@@ -364,6 +527,15 @@ class _NeighborhoodAttentionS2(torch.autograd.Function):
@staticmethod
@custom_bwd(device_type="cpu")
def backward(ctx, grad_output):
r"""
Backward pass for CPU neighborhood attention on S2.
Parameters:
grad_output: gradient of the output
Returns:
gradients for all input tensors and parameters
"""
col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
nh = ctx.nh
nlon_in = ctx.nlon_in
......@@ -443,13 +615,63 @@ def _neighborhood_attention_s2_torch(k: torch.Tensor, v: torch.Tensor, q: torch.
bq: Union[torch.Tensor, None], quad_weights: torch.Tensor,
col_idx: torch.Tensor, row_off: torch.Tensor,
nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
"""
Torch implementation of neighborhood attention on the sphere (S2).
This function provides a wrapper around the CPU autograd function for
neighborhood attention operations using sparse tensor computations.
Parameters
-----------
k : torch.Tensor
Key tensor
v : torch.Tensor
Value tensor
q : torch.Tensor
Query tensor
wk : torch.Tensor
Key weight tensor
wv : torch.Tensor
Value weight tensor
wq : torch.Tensor
Query weight tensor
bk : torch.Tensor or None
Key bias tensor (optional)
bv : torch.Tensor or None
Value bias tensor (optional)
bq : torch.Tensor or None
Query bias tensor (optional)
quad_weights : torch.Tensor
Quadrature weights for spherical integration
col_idx : torch.Tensor
Column indices for sparse computation
row_off : torch.Tensor
Row offsets for sparse computation
nh : int
Number of attention heads
nlon_in : int
Number of input longitude points
nlat_out : int
Number of output latitude points
nlon_out : int
Number of output longitude points
Returns
-------
torch.Tensor
Output tensor after neighborhood attention computation
"""
return _NeighborhoodAttentionS2.apply(k, v, q, wk, wv, wq, bk, bv, bq,
quad_weights, col_idx, row_off,
nh, nlon_in, nlat_out, nlon_out)
class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
r"""
CUDA implementation of neighborhood attention on the sphere (S2).
This class provides the forward and backward passes for efficient GPU computation
of neighborhood attention operations using custom CUDA kernels.
"""
@staticmethod
@custom_fwd(device_type="cuda")
......@@ -458,7 +680,28 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None],
quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int):
r"""
Forward pass for CUDA neighborhood attention on S2.
Parameters:
k: key tensor
v: value tensor
q: query tensor
wk: key weight tensor
wv: value weight tensor
wq: query weight tensor
bk: key bias tensor (optional)
bv: value bias tensor (optional)
bq: query bias tensor (optional)
quad_weights: quadrature weights for spherical integration
col_idx: column indices for sparse computation
row_off: row offsets for sparse computation
max_psi_nnz: maximum number of non-zero elements in sparse tensor
nh: number of attention heads
nlon_in: number of input longitude points
nlat_out: number of output latitude points
nlon_out: number of output longitude points
"""
ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq)
ctx.nh = nh
ctx.max_psi_nnz = max_psi_nnz
......@@ -499,6 +742,15 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
@staticmethod
@custom_bwd(device_type="cuda")
def backward(ctx, grad_output):
r"""
Backward pass for CUDA neighborhood attention on S2.
Parameters:
grad_output: gradient of the output
Returns:
gradients for all input tensors and parameters
"""
col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
nh = ctx.nh
max_psi_nnz = ctx.max_psi_nnz
......@@ -584,7 +836,54 @@ def _neighborhood_attention_s2_cuda(k: torch.Tensor, v: torch.Tensor, q: torch.T
bq: Union[torch.Tensor, None], quad_weights: torch.Tensor,
col_idx: torch.Tensor, row_off: torch.Tensor, max_psi_nnz: int,
nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
"""
CUDA implementation of neighborhood attention on the sphere (S2).
This function provides a wrapper around the CUDA autograd function for
neighborhood attention operations using custom CUDA kernels for efficient GPU computation.
Parameters
-----------
k : torch.Tensor
Key tensor
v : torch.Tensor
Value tensor
q : torch.Tensor
Query tensor
wk : torch.Tensor
Key weight tensor
wv : torch.Tensor
Value weight tensor
wq : torch.Tensor
Query weight tensor
bk : torch.Tensor or None
Key bias tensor (optional)
bv : torch.Tensor or None
Value bias tensor (optional)
bq : torch.Tensor or None
Query bias tensor (optional)
quad_weights : torch.Tensor
Quadrature weights for spherical integration
col_idx : torch.Tensor
Column indices for sparse computation
row_off : torch.Tensor
Row offsets for sparse computation
max_psi_nnz : int
Maximum number of non-zero elements in sparse tensor
nh : int
Number of attention heads
nlon_in : int
Number of input longitude points
nlat_out : int
Number of output latitude points
nlon_out : int
Number of output longitude points
Returns
-------
torch.Tensor
Output tensor after neighborhood attention computation
"""
return _NeighborhoodAttentionS2Cuda.apply(k, v, q, wk, wv, wq, bk, bv, bq,
quad_weights, col_idx, row_off, max_psi_nnz,
nh, nlon_in, nlat_out, nlon_out)
......@@ -35,6 +35,32 @@ from copy import deepcopy
# copying LRU cache decorator a la:
# https://stackoverflow.com/questions/54909357/how-to-get-functools-lru-cache-to-return-new-instances
def lru_cache(maxsize=20, typed=False, copy=False):
"""
Least Recently Used (LRU) cache decorator with optional deep copying.
This is a wrapper around functools.lru_cache that adds the ability to return
deep copies of cached results to prevent unintended modifications to cached objects.
Parameters
-----------
maxsize : int, optional
Maximum number of items to cache, by default 20
typed : bool, optional
Whether to cache different types separately, by default False
copy : bool, optional
Whether to return deep copies of cached results, by default False
Returns
-------
function
Decorated function with LRU caching
Example
-------
>>> @lru_cache(maxsize=10, copy=True)
... def expensive_function(x):
... return [x, x*2, x*3]
"""
def decorator(f):
cached_func = functools.lru_cache(maxsize=maxsize, typed=typed)(f)
def wrapper(*args, **kwargs):
......
......@@ -43,6 +43,32 @@ from torch_harmonics.distributed import compute_split_shapes
class DistributedResampleS2(nn.Module):
r"""
Distributed resampling module for spherical data on the 2-sphere.
This module performs distributed resampling of spherical data across multiple processes,
supporting both upscaling and downscaling operations. The data is distributed across
polar and azimuthal directions, and the module handles the necessary communication
and interpolation operations.
Parameters
-----------
nlat_in : int
Number of input latitude points
nlon_in : int
Number of input longitude points
nlat_out : int
Number of output latitude points
nlon_out : int
Number of output longitude points
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
mode : str, optional
Interpolation mode ("bilinear" or "bilinear-spherical"), by default "bilinear"
"""
def __init__(
self,
nlat_in: int,
......@@ -133,6 +159,19 @@ class DistributedResampleS2(nn.Module):
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}"
def _upscale_longitudes(self, x: torch.Tensor):
"""
Upscale the longitude dimension using interpolation.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (..., nlat, nlon)
Returns
-------
torch.Tensor
Upscaled tensor in the longitude dimension
"""
# do the interpolation
lwgt = self.lon_weights.to(x.dtype)
if self.mode == "bilinear":
......@@ -147,6 +186,19 @@ class DistributedResampleS2(nn.Module):
return x
def _expand_poles(self, x: torch.Tensor):
"""
Expand the data to include pole values for interpolation.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (..., nlat, nlon)
Returns
-------
torch.Tensor
Tensor with expanded pole values
"""
x_north = x[..., 0, :].sum(dim=-1, keepdims=True)
x_south = x[..., -1, :].sum(dim=-1, keepdims=True)
x_count = torch.tensor([x.shape[-1]], dtype=torch.long, device=x.device, requires_grad=False)
......@@ -169,6 +221,19 @@ class DistributedResampleS2(nn.Module):
return x
def _upscale_latitudes(self, x: torch.Tensor):
"""
Upscale the latitude dimension using interpolation.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (..., nlat, nlon)
Returns
-------
torch.Tensor
Upscaled tensor in the latitude dimension
"""
# do the interpolation
lwgt = self.lat_weights.to(x.dtype)
if self.mode == "bilinear":
......@@ -183,6 +248,19 @@ class DistributedResampleS2(nn.Module):
return x
def forward(self, x: torch.Tensor):
"""
Forward pass for distributed resampling.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (batch, channels, nlat, nlon)
Returns
-------
torch.Tensor
Resampled tensor with shape (batch, channels, nlat_out, nlon_out)
"""
if self.skip_resampling:
return x
......
......@@ -95,10 +95,23 @@ def _transpose(tensor, dim0, dim1, dim1_split_sizes, group=None, async_op=False)
class distributed_transpose_azimuth(torch.autograd.Function):
r"""
Distributed transpose operation for azimuthal dimension.
This class provides the forward and backward passes for distributed
tensor transposition along the azimuthal dimension.
"""
@staticmethod
@custom_fwd(device_type="cuda")
def forward(ctx, x, dims, dim1_split_sizes):
r"""
Forward pass for distributed azimuthal transpose.
Parameters:
x: input tensor
dims: dimensions to transpose
dim1_split_sizes: split sizes for dimension 1
"""
# WAR for a potential contig check torch bug for channels last contig tensors
xlist, dim0_split_sizes, _ = _transpose(x, dims[0], dims[1], dim1_split_sizes, group=azimuth_group())
x = torch.cat(xlist, dim=dims[1])
......@@ -110,6 +123,15 @@ class distributed_transpose_azimuth(torch.autograd.Function):
@staticmethod
@custom_bwd(device_type="cuda")
def backward(ctx, go):
r"""
Backward pass for distributed azimuthal transpose.
Parameters:
go: gradient of the output
Returns:
gradient of the input
"""
dims = ctx.dims
dim0_split_sizes = ctx.dim0_split_sizes
# WAR for a potential contig check torch bug for channels last contig tensors
......@@ -120,10 +142,23 @@ class distributed_transpose_azimuth(torch.autograd.Function):
class distributed_transpose_polar(torch.autograd.Function):
r"""
Distributed transpose operation for polar dimension.
This class provides the forward and backward passes for distributed
tensor transposition along the polar dimension.
"""
@staticmethod
@custom_fwd(device_type="cuda")
def forward(ctx, x, dim, dim1_split_sizes):
r"""
Forward pass for distributed polar transpose.
Parameters:
x: input tensor
dim: dimensions to transpose
dim1_split_sizes: split sizes for dimension 1
"""
# WAR for a potential contig check torch bug for channels last contig tensors
xlist, dim0_split_sizes, _ = _transpose(x, dim[0], dim[1], dim1_split_sizes, group=polar_group())
x = torch.cat(xlist, dim=dim[1])
......@@ -134,6 +169,15 @@ class distributed_transpose_polar(torch.autograd.Function):
@staticmethod
@custom_bwd(device_type="cuda")
def backward(ctx, go):
r"""
Backward pass for distributed polar transpose.
Parameters:
go: gradient of the output
Returns:
gradient of the input
"""
dim = ctx.dim
dim0_split_sizes = ctx.dim0_split_sizes
# WAR for a potential contig check torch bug for channels last contig tensors
......@@ -244,7 +288,11 @@ def _reduce_scatter(input_, dim_, use_fp32=True, group=None):
class _CopyToPolarRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chunk to the rank."""
r"""
Copy tensor to polar region for distributed computation.
This class provides the forward and backward passes for copying
tensors to the polar region in distributed settings.
"""
@staticmethod
def symbolic(graph, input_):
......@@ -253,11 +301,29 @@ class _CopyToPolarRegion(torch.autograd.Function):
@staticmethod
@custom_fwd(device_type="cuda")
def forward(ctx, input_):
r"""
Forward pass for copying to polar region.
Parameters:
input_: input tensor
Returns:
input tensor (no-op in forward pass)
"""
return input_
@staticmethod
@custom_bwd(device_type="cuda")
def backward(ctx, grad_output):
r"""
Backward pass for copying to polar region.
Parameters:
grad_output: gradient of the output
Returns:
gradient of the input
"""
if is_distributed_polar():
return _reduce(grad_output, group=polar_group())
else:
......@@ -265,7 +331,11 @@ class _CopyToPolarRegion(torch.autograd.Function):
class _CopyToAzimuthRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chunk to the rank."""
r"""
Copy tensor to azimuth region for distributed computation.
This class provides the forward and backward passes for copying
tensors to the azimuth region in distributed settings.
"""
@staticmethod
def symbolic(graph, input_):
......@@ -274,11 +344,29 @@ class _CopyToAzimuthRegion(torch.autograd.Function):
@staticmethod
@custom_fwd(device_type="cuda")
def forward(ctx, input_):
r"""
Forward pass for copying to azimuth region.
Parameters:
input_: input tensor
Returns:
input tensor (no-op in forward pass)
"""
return input_
@staticmethod
@custom_bwd(device_type="cuda")
def backward(ctx, grad_output):
r"""
Backward pass for copying to azimuth region.
Parameters:
grad_output: gradient of the output
Returns:
gradient of the input
"""
if is_distributed_azimuth():
return _reduce(grad_output, group=azimuth_group())
else:
......@@ -286,7 +374,11 @@ class _CopyToAzimuthRegion(torch.autograd.Function):
class _ScatterToPolarRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chunk to the rank."""
r"""
Scatter tensor to polar region for distributed computation.
This class provides the forward and backward passes for scattering
tensors to the polar region in distributed settings.
"""
@staticmethod
def symbolic(graph, input_, dim_):
......
......@@ -40,6 +40,27 @@ from torch_harmonics.quadrature import _precompute_latitudes
def get_quadrature_weights(nlat: int, nlon: int, grid: str, tile: bool = False, normalized: bool = True) -> torch.Tensor:
"""
Get quadrature weights for spherical integration.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
grid : str
Grid type ("equiangular", "legendre-gauss", "lobatto")
tile : bool, optional
Whether to tile weights across longitude dimension, by default False
normalized : bool, optional
Whether to normalize weights to sum to 1, by default True
Returns
-------
torch.Tensor
Quadrature weights tensor
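Example
-------
Illustrative sketch; with normalized=True the weights sum to one:
>>> w = get_quadrature_weights(nlat=64, nlon=128, grid="equiangular")
>>> bool(torch.isclose(w.sum(), torch.tensor(1.0, dtype=w.dtype)))
True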
"""
# area weights
_, q = _precompute_latitudes(nlat=nlat, grid=grid)
q = q.reshape(-1, 1) * 2 * torch.pi / nlon
......@@ -55,6 +76,27 @@ def get_quadrature_weights(nlat: int, nlon: int, grid: str, tile: bool = False,
class DiceLossS2(nn.Module):
"""
Dice loss for spherical segmentation tasks.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
grid : str, optional
Grid type, by default "equiangular"
weight : torch.Tensor, optional
Class weights, by default None
smooth : float, optional
Smoothing factor, by default 0
ignore_index : int, optional
Index to ignore in loss computation, by default -100
mode : str, optional
Aggregation mode ("micro" or "macro"), by default "micro"
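Example
-------
Illustrative sketch (hypothetical grid size, 4 classes):
>>> loss_fn = DiceLossS2(nlat=32, nlon=64)
>>> prd = torch.randn(2, 4, 32, 64)         # (batch, classes, nlat, nlon)
>>> tar = torch.randint(0, 4, (2, 32, 64))  # (batch, nlat, nlon)
>>> loss = loss_fn(prd, tar)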
"""
def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, smooth: float = 0, ignore_index: int = -100, mode: str = "micro"):
super().__init__()
......@@ -73,6 +115,21 @@ class DiceLossS2(nn.Module):
self.register_buffer("weight", weight.unsqueeze(0))
def forward(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the Dice loss.
Parameters
-----------
prd : torch.Tensor
Prediction tensor with shape (batch, classes, nlat, nlon)
tar : torch.Tensor
Target tensor with shape (batch, nlat, nlon)
Returns
-------
torch.Tensor
Dice loss value
"""
prd = nn.functional.softmax(prd, dim=1)
# mask values
......@@ -113,6 +170,24 @@ class DiceLossS2(nn.Module):
class CrossEntropyLossS2(nn.Module):
"""
Cross-entropy loss for spherical classification tasks.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
grid : str, optional
Grid type, by default "equiangular"
weight : torch.Tensor, optional
Class weights, by default None
smooth : float, optional
Label smoothing factor, by default 0
ignore_index : int, optional
Index to ignore in loss computation, by default -100
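Example
-------
Illustrative sketch (hypothetical grid size, 4 classes):
>>> loss_fn = CrossEntropyLossS2(nlat=32, nlon=64)
>>> loss = loss_fn(torch.randn(2, 4, 32, 64), torch.randint(0, 4, (2, 32, 64)))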
"""
def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, smooth: float = 0, ignore_index: int = -100):
......@@ -130,6 +205,21 @@ class CrossEntropyLossS2(nn.Module):
self.register_buffer("quad_weights", q)
def forward(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the cross-entropy loss.
Parameters
-----------
prd : torch.Tensor
Prediction tensor with shape (batch, classes, nlat, nlon)
tar : torch.Tensor
Target tensor with shape (batch, nlat, nlon)
Returns
-------
torch.Tensor
Cross-entropy loss value
"""
# compute log softmax
logits = nn.functional.log_softmax(prd, dim=1)
......@@ -141,6 +231,24 @@ class CrossEntropyLossS2(nn.Module):
class FocalLossS2(nn.Module):
"""
Focal loss for spherical classification tasks.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
grid : str, optional
Grid type, by default "equiangular"
weight : torch.Tensor, optional
Class weights, by default None
smooth : float, optional
Label smoothing factor, by default 0
ignore_index : int, optional
Index to ignore in loss computation, by default -100
"""
def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, smooth: float = 0, ignore_index: int = -100):
......@@ -158,6 +266,25 @@ class FocalLossS2(nn.Module):
self.register_buffer("quad_weights", q)
def forward(self, prd: torch.Tensor, tar: torch.Tensor, alpha: float = 0.25, gamma: float = 2):
"""
Forward pass of the focal loss.
Parameters
-----------
prd : torch.Tensor
Prediction tensor with shape (batch, classes, nlat, nlon)
tar : torch.Tensor
Target tensor with shape (batch, nlat, nlon)
alpha : float, optional
Alpha parameter for focal loss, by default 0.25
gamma : float, optional
Gamma parameter for focal loss, by default 2
Returns
-------
torch.Tensor
Focal loss value
"""
# compute logits
logits = nn.functional.log_softmax(prd, dim=1)
......@@ -232,22 +359,101 @@ class SphericalLossBase(nn.Module, ABC):
class SquaredL2LossS2(SphericalLossBase):
"""
Squared L2 loss for spherical regression tasks.
Computes the squared difference between prediction and target tensors.
"""
def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
"""
Compute squared L2 loss term.
Parameters
-----------
prd : torch.Tensor
Prediction tensor
tar : torch.Tensor
Target tensor
Returns
-------
torch.Tensor
Squared difference between prediction and target
"""
return torch.square(prd - tar)
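# Usage sketch (illustrative, not part of this commit): the concrete losses here
# appear to share the SphericalLossBase constructor (nlat, nlon, grid), as the
# W11LossS2 __init__ below suggests; torch is imported at the top of this module.
loss_fn = SquaredL2LossS2(nlat=32, nlon=64, grid="equiangular")
prd = torch.randn(4, 3, 32, 64)  # (batch, channels, nlat, nlon)
tar = torch.randn(4, 3, 32, 64)
loss = loss_fn(prd, tar)  # pointwise squared error, integrated with quadrature weights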
class L1LossS2(SphericalLossBase):
"""
L1 loss for spherical regression tasks.
Computes the absolute difference between prediction and target tensors.
"""
def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
"""
Compute L1 loss term.
Parameters
-----------
prd : torch.Tensor
Prediction tensor
tar : torch.Tensor
Target tensor
Returns
-------
torch.Tensor
Absolute difference between prediction and target
"""
return torch.abs(prd - tar)
class L2LossS2(SquaredL2LossS2):
"""
L2 loss for spherical regression tasks.
Computes the square root of the squared L2 loss.
"""
def _post_integration_hook(self, loss: torch.Tensor) -> torch.Tensor:
"""
Apply square root to get L2 norm.
Parameters
-----------
loss : torch.Tensor
Integrated squared loss
Returns
-------
torch.Tensor
Square root of the loss (L2 norm)
"""
return torch.sqrt(loss)
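# Illustrative note: SquaredL2LossS2 integrates the squared error over the sphere
# with quadrature weights w_i, so L2LossS2 returns the L2 norm of the error:
# loss = sqrt(sum_i w_i * (prd_i - tar_i)**2)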
class W11LossS2(SphericalLossBase):
"""
W11 loss for spherical regression tasks.
Computes the L1 norm of the gradient differences between prediction and target.
"""
def __init__(self, nlat: int, nlon: int, grid: str = "equiangular"):
"""
Initialize W11 loss.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
grid : str, optional
Grid type, by default "equiangular"
"""
super().__init__(nlat=nlat, nlon=nlon, grid=grid)
# Set up grid and domain for FFT
l_phi = 2 * torch.pi # domain size
......@@ -305,31 +511,70 @@ class NormalLossS2(SphericalLossBase):
self.register_buffer("k_theta_mesh", k_theta_mesh)
def compute_gradients(self, x):
"""
Compute gradients of the input tensor using FFT.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (batch, nlat, nlon) or (nlat, nlon)
Returns
-------
tuple
Tuple of (grad_phi, grad_theta) gradients
"""
# Make sure x is reshaped to have a batch dimension if it's missing
if x.dim() == 2:
x = x.unsqueeze(0) # Add batch dimension
# Compute gradients using FFT
grad_phi = torch.fft.ifft2(1j * self.k_phi_mesh * torch.fft.fft2(x)).real
grad_theta = torch.fft.ifft2(1j * self.k_theta_mesh * torch.fft.fft2(x)).real
return grad_phi, grad_theta
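# Sanity check for FFT-based differentiation (illustrative, standalone snippet):
# on a periodic 2*pi domain with integer wavenumbers, d/dphi sin(phi) = cos(phi).
import torch
n = 64
phi = torch.linspace(0, 2 * torch.pi, n + 1)[:-1]
f = torch.sin(phi)
k = torch.fft.fftfreq(n, d=1.0 / n)  # integer wavenumbers 0, 1, ..., -1
df = torch.fft.ifft(1j * k * torch.fft.fft(f)).real
assert torch.allclose(df, torch.cos(phi), atol=1e-4)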
def compute_normals(self, x):
"""
Compute surface normals from the input tensor.

Parameters
-----------
x : torch.Tensor
Input tensor with shape (batch, nlat, nlon) or (nlat, nlon)

Returns
-------
torch.Tensor
Normal vectors with shape (batch, 3, nlat, nlon)
"""
x = x.to(torch.float32)
# Ensure x has a batch dimension
if x.dim() == 2:
x = x.unsqueeze(0)
grad_phi, grad_theta = self.compute_gradients(x)
# Construct normal vectors: (-grad_theta, -grad_phi, 1)
normals = torch.stack([-grad_theta, -grad_phi, torch.ones_like(x)], dim=1)
# Normalize along component dimension
normals = F.normalize(normals, p=2, dim=1)
return normals
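# Sanity check (illustrative; `normal_loss_fn` is a hypothetical NormalLossS2
# instance): a constant field has zero gradients, so every normal is (0, 0, 1).
x = torch.zeros(1, 32, 64)
normals = normal_loss_fn.compute_normals(x)
assert torch.allclose(normals[:, 2], torch.ones_like(x))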
def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
"""
Compute combined L1 and normal consistency loss.
Parameters
-----------
prd : torch.Tensor
Prediction tensor
tar : torch.Tensor
Target tensor
Returns
-------
torch.Tensor
Combined loss term
"""
# Handle dimensions for both prediction and target
# Ensure we have at least a batch dimension
if prd.dim() == 2:
......@@ -337,15 +582,18 @@ class NormalLossS2(SphericalLossBase):
if tar.dim() == 2:
tar = tar.unsqueeze(0)
# For 4D tensors (batch, channel, height, width), remove channel if it's 1
if prd.dim() == 4 and prd.size(1) == 1:
prd = prd.squeeze(1)
if tar.dim() == 4 and tar.size(1) == 1:
tar = tar.squeeze(1)
# L1 loss term
l1_loss = torch.abs(prd - tar)
# Normal consistency loss
prd_normals = self.compute_normals(prd)
tar_normals = self.compute_normals(tar)
# Cosine similarity between normals
cos_sim = torch.sum(prd_normals * tar_normals, dim=1)
normal_loss = 1 - cos_sim
# Combine losses (equal weighting)
combined_loss = l1_loss + normal_loss.unsqueeze(1)
return combined_loss
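# Per grid point, the loss term computed above is
# loss = |prd - tar| + (1 - <n_prd, n_tar>),
# where n_prd and n_tar are the unit surface normals of prediction and target.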
......@@ -49,6 +49,31 @@ def _get_stats_multiclass(
quad_weights: torch.Tensor,
ignore_index: Optional[int],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute multiclass statistics (TP, FP, FN, TN) on the sphere using quadrature weights.
This function computes true positives, false positives, false negatives, and true negatives
for multiclass classification on spherical data, properly weighted by quadrature weights
to account for the spherical geometry.
Parameters
-----------
output : torch.LongTensor
Predicted class labels
target : torch.LongTensor
Ground truth class labels
num_classes : int
Number of classes in the classification task
quad_weights : torch.Tensor
Quadrature weights for spherical integration
ignore_index : Optional[int]
Index to ignore in the computation (e.g., for padding or invalid regions)
Returns
-------
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
Tuple containing (tp_count, fp_count, fn_count, tn_count) for each class
"""
batch_size, *dims = output.shape
num_elements = torch.prod(torch.tensor(dims)).long()
......@@ -88,10 +113,46 @@ def _get_stats_multiclass(
def _predict_classes(logits: torch.Tensor) -> torch.Tensor:
"""
Convert logits to class predictions using softmax and argmax.
Parameters
-----------
logits : torch.Tensor
Input logits tensor
Returns
-------
torch.Tensor
Predicted class labels
"""
return torch.argmax(torch.softmax(logits, dim=1), dim=1, keepdim=False)
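# Note (illustrative): softmax is strictly increasing per element, so it does not
# change the argmax; the softmax call above is kept for clarity only.
logits = torch.randn(2, 5, 8, 16)
assert torch.equal(_predict_classes(logits), torch.argmax(logits, dim=1))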
class BaseMetricS2(nn.Module):
"""
Base class for spherical metrics that properly handle spherical geometry.
This class provides the foundation for computing metrics on spherical data
by using quadrature weights to account for the non-uniform area distribution
on the sphere.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
grid : str, optional
Grid type ("equiangular", "legendre-gauss", etc.), by default "equiangular"
weight : torch.Tensor, optional
Class weights for weighted averaging, by default None
ignore_index : int, optional
Index to ignore in computations, by default -100
mode : str, optional
Averaging mode ("micro" or "macro"), by default "micro"
"""
def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, ignore_index: int = -100, mode: str = "micro"):
super().__init__()
......@@ -108,6 +169,21 @@ class BaseMetricS2(nn.Module):
self.register_buffer("weight", weight.unsqueeze(0))
def _forward(self, pred: torch.Tensor, truth: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute base statistics (TP, FP, FN, TN) for the given predictions and ground truth.
Parameters
-----------
pred : torch.Tensor
Predicted logits
truth : torch.Tensor
Ground truth labels
Returns
-------
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
Tuple containing (tp, fp, fn, tn) statistics
"""
# convert logits to class predictions
pred_class = _predict_classes(pred)
......@@ -138,11 +214,47 @@ class BaseMetricS2(nn.Module):
class IntersectionOverUnionS2(BaseMetricS2):
"""
Intersection over Union (IoU) metric for spherical data.
Computes the IoU score for multiclass classification on the sphere,
properly weighted by quadrature weights to account for spherical geometry.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
grid : str, optional
Grid type ("equiangular", "legendre-gauss", etc.), by default "equiangular"
weight : torch.Tensor, optional
Class weights for weighted averaging, by default None
ignore_index : int, optional
Index to ignore in computations, by default -100
mode : str, optional
Averaging mode ("micro" or "macro"), by default "micro"
"""
def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, ignore_index: int = -100, mode: str = "micro"):
super().__init__(nlat, nlon, grid, weight, ignore_index, mode)
def forward(self, pred: torch.Tensor, truth: torch.Tensor) -> torch.Tensor:
"""
Compute IoU score for the given predictions and ground truth.
Parameters
-----------
pred : torch.Tensor
Predicted logits
truth : torch.Tensor
Ground truth labels
Returns
-------
torch.Tensor
IoU score
"""
tp, fp, fn, tn = self._forward(pred, truth)
# compute score
......@@ -162,11 +274,47 @@ class IntersectionOverUnionS2(BaseMetricS2):
class AccuracyS2(BaseMetricS2):
"""
Accuracy metric for spherical data.
Computes the accuracy score for multiclass classification on the sphere,
properly weighted by quadrature weights to account for spherical geometry.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
grid : str, optional
Grid type ("equiangular", "legendre-gauss", etc.), by default "equiangular"
weight : torch.Tensor, optional
Class weights for weighted averaging, by default None
ignore_index : int, optional
Index to ignore in computations, by default -100
mode : str, optional
Averaging mode ("micro" or "macro"), by default "micro"
"""
def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, ignore_index: int = -100, mode: str = "micro"):
super().__init__(nlat, nlon, grid, weight, ignore_index, mode)
def forward(self, pred: torch.Tensor, truth: torch.Tensor) -> torch.Tensor:
"""
Compute accuracy score for the given predictions and ground truth.
Parameters
-----------
pred : torch.Tensor
Predicted logits
truth : torch.Tensor
Ground truth labels
Returns
-------
torch.Tensor
Accuracy score
"""
tp, fp, fn, tn = self._forward(pred, truth)
# compute score
......
......@@ -41,6 +41,27 @@ from torch_harmonics import InverseRealSHT
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
"""
Internal function to fill tensor with truncated normal distribution values.
Parameters
-----------
tensor : torch.Tensor
Tensor to fill with values
mean : float
Mean of the normal distribution
std : float
Standard deviation of the normal distribution
a : float
Lower bound for truncation
b : float
Upper bound for truncation
Returns
-------
torch.Tensor
The filled tensor
"""
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
......@@ -96,12 +117,28 @@ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
@torch.jit.script
def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
Parameters
-----------
x : torch.Tensor
Input tensor
drop_prob : float, optional
Dropout probability, by default 0.0
training : bool, optional
Whether in training mode, by default False
Returns
-------
torch.Tensor
Output tensor with potential drop path applied
"""
if drop_prob == 0.0 or not training:
return x
......@@ -117,14 +154,50 @@ class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None):
"""
Initialize DropPath module.
Parameters
-----------
drop_prob : float, optional
Dropout probability, by default None
"""
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
"""
Forward pass with drop path.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Output tensor with potential drop path applied
"""
return drop_path(x, self.drop_prob, self.training)
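# Usage sketch (illustrative): in training mode with drop_prob=0.2, each sample's
# branch is zeroed with probability 0.2 and scaled by 1/0.8 otherwise (the usual
# timm convention, which the drop_path above follows), keeping the output's
# expectation equal to the input.
dp = DropPath(drop_prob=0.2)
dp.train()
y = dp(torch.ones(8, 3, 4, 4))  # per-sample values are either 0.0 or 1.25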
class PatchEmbed(nn.Module):
"""
Patch embedding layer for vision transformers.
Parameters
-----------
img_size : tuple, optional
Input image size (height, width), by default (224, 224)
patch_size : tuple, optional
Patch size (height, width), by default (16, 16)
in_chans : int, optional
Number of input channels, by default 3
embed_dim : int, optional
Embedding dimension, by default 768
"""
def __init__(self, img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768):
super(PatchEmbed, self).__init__()
self.red_img_size = ((img_size[0] // patch_size[0]), (img_size[1] // patch_size[1]))
......@@ -137,6 +210,19 @@ class PatchEmbed(nn.Module):
self.proj.bias.is_shared_mp = ["spatial"]
def forward(self, x):
"""
Forward pass of patch embedding.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (batch, channels, height, width)
Returns
-------
torch.Tensor
Embedded patches with shape (batch, embed_dim, num_patches)
"""
# gather input
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
......@@ -146,6 +232,29 @@ class PatchEmbed(nn.Module):
class MLP(nn.Module):
"""
Multi-layer perceptron with optional checkpointing.
Parameters
-----------
in_features : int
Number of input features
hidden_features : int, optional
Number of hidden features, by default None (same as in_features)
out_features : int, optional
Number of output features, by default None (same as in_features)
act_layer : nn.Module, optional
Activation layer, by default nn.ReLU
output_bias : bool, optional
Whether to use bias in output layer, by default False
drop_rate : float, optional
Dropout rate, by default 0.0
checkpointing : bool, optional
Whether to use gradient checkpointing, by default False
gain : float, optional
Gain factor for output initialization, by default 1.0
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, output_bias=False, drop_rate=0.0, checkpointing=False, gain=1.0):
super(MLP, self).__init__()
self.checkpointing = checkpointing
......@@ -179,9 +288,35 @@ class MLP(nn.Module):
@torch.jit.ignore
def checkpoint_forward(self, x):
"""
Forward pass with gradient checkpointing.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Output tensor
"""
return checkpoint(self.fwd, x)
def forward(self, x):
"""
Forward pass of the MLP.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Output tensor
"""
if self.checkpointing:
return self.checkpoint_forward(x)
else:
......@@ -194,6 +329,20 @@ class RealFFT2(nn.Module):
"""
def __init__(self, nlat, nlon, lmax=None, mmax=None):
"""
Initialize RealFFT2 module.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
lmax : int, optional
Maximum l mode, by default None (same as nlat)
mmax : int, optional
Maximum m mode, by default None (nlon // 2 + 1)
"""
super(RealFFT2, self).__init__()
self.nlat = nlat
......@@ -202,6 +351,19 @@ class RealFFT2(nn.Module):
self.mmax = mmax or self.nlon // 2 + 1
def forward(self, x):
"""
Forward pass of RealFFT2.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (batch, channels, nlat, nlon)
Returns
-------
torch.Tensor
Output tensor with shape (batch, channels, nlat, mmax)
"""
y = torch.fft.rfft2(x, dim=(-2, -1), norm="ortho")
y = torch.cat((y[..., : math.ceil(self.lmax / 2), : self.mmax], y[..., -math.floor(self.lmax / 2) :, : self.mmax]), dim=-2)
return y
......@@ -209,10 +371,24 @@ class RealFFT2(nn.Module):
class InverseRealFFT2(nn.Module):
"""
Helper routine to wrap inverse FFT similarly to the SHT
"""
def __init__(self, nlat, nlon, lmax=None, mmax=None):
"""
Initialize InverseRealFFT2 module.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
lmax : int, optional
Maximum l mode, by default None (same as nlat)
mmax : int, optional
Maximum m mode, by default None (nlon // 2 + 1)
"""
super(InverseRealFFT2, self).__init__()
self.nlat = nlat
......@@ -221,6 +397,19 @@ class InverseRealFFT2(nn.Module):
self.mmax = mmax or self.nlon // 2 + 1
def forward(self, x):
"""
Forward pass of InverseRealFFT2.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (batch, channels, nlat, mmax)
Returns
-------
torch.Tensor
Output tensor with shape (batch, channels, nlat, nlon)
"""
return torch.fft.irfft2(x, dim=(-2, -1), s=(self.nlat, self.nlon), norm="ortho")
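# Round-trip sketch (illustrative): assuming the documented defaults lmax = nlat
# and mmax = nlon // 2 + 1, the full spectrum is retained and InverseRealFFT2
# inverts RealFFT2 up to floating-point error.
fft = RealFFT2(nlat=32, nlon=64)
ifft = InverseRealFFT2(nlat=32, nlon=64)
x = torch.randn(2, 3, 32, 64)
assert torch.allclose(ifft(fft(x)), x, atol=1e-5)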
......@@ -230,6 +419,24 @@ class LayerNorm(nn.Module):
"""
def __init__(self, in_channels, eps=1e-05, elementwise_affine=True, bias=True, device=None, dtype=None):
"""
Initialize LayerNorm module.
Parameters
-----------
in_channels : int
Number of input channels
eps : float, optional
Epsilon for numerical stability, by default 1e-05
elementwise_affine : bool, optional
Whether to use learnable affine parameters, by default True
bias : bool, optional
Whether to use bias, by default True
device : torch.device, optional
Device to place the module on, by default None
dtype : torch.dtype, optional
Data type, by default None
"""
super().__init__()
self.channel_dim = -3
......@@ -237,164 +444,270 @@ class LayerNorm(nn.Module):
self.norm = nn.LayerNorm(normalized_shape=in_channels, eps=eps, elementwise_affine=elementwise_affine, bias=bias, device=device, dtype=dtype)
def forward(self, x):
"""
Forward pass of LayerNorm.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Normalized tensor
"""
return self.norm(x.transpose(self.channel_dim, -1)).transpose(-1, self.channel_dim)
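# Usage sketch (illustrative): normalizes over the channel dimension of a
# channels-first (B, C, H, W) tensor by moving C to the last axis and back.
norm = LayerNorm(in_channels=16)
x = torch.randn(2, 16, 8, 8)
y = norm(x)  # mean ~0 and std ~1 across the 16 channels at every (b, h, w)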
class SpectralConvS2(nn.Module):
"""
Spectral convolution layer for spherical data.

Implements spectral convolution in the sense of Driscoll & Healy on the two-sphere S2
using the Spherical Harmonic Transforms in torch-harmonics, and supports convolutions
on the periodic domain via the RealFFT2 and InverseRealFFT2 wrappers.

Parameters
-----------
forward_transform : nn.Module
Forward transform (e.g., RealSHT)
inverse_transform : nn.Module
Inverse transform (e.g., InverseRealSHT)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
gain : float, optional
Gain factor for weight initialization, by default 2.0
operator_type : str, optional
Type of spectral operator, by default "driscoll-healy"
lr_scale_exponent : int, optional
Learning rate scale exponent, by default 0
bias : bool, optional
Whether to use bias, by default False
"""
def __init__(self, forward_transform, inverse_transform, in_channels, out_channels, gain=2.0, operator_type="driscoll-healy", lr_scale_exponent=0, bias=False):
super(SpectralConvS2, self).__init__()
self.forward_transform = forward_transform
self.inverse_transform = inverse_transform
self.modes_lat = self.inverse_transform.lmax
self.modes_lon = self.inverse_transform.mmax
self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) or (self.forward_transform.nlon != self.inverse_transform.nlon)
# remember factorization details
self.in_channels = in_channels
self.out_channels = out_channels
self.operator_type = operator_type
self.lr_scale_exponent = lr_scale_exponent
weight_shape = [out_channels, in_channels]
if self.operator_type == "diagonal":
weight_shape += [self.modes_lat, self.modes_lon]
self.contract_func = "...ilm,oilm->...olm"
elif self.operator_type == "block-diagonal":
weight_shape += [self.modes_lat, self.modes_lon, self.modes_lon]
self.contract_func = "...ilm,oilnm->...oln"
elif self.operator_type == "driscoll-healy":
weight_shape += [self.modes_lat]
self.contract_func = "...ilm,oil->...olm"
else:
raise NotImplementedError(f"Unknown operator type {self.operator_type}")
# initialize the weights
scale = math.sqrt(gain / in_channels)
self.weight = nn.Parameter(scale * torch.randn(*weight_shape, dtype=torch.complex64))
if bias:
self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
else:
self.bias = None
def forward(self, x):
"""
Forward pass of spectral convolution.

Parameters
-----------
x : torch.Tensor
Input tensor

Returns
-------
torch.Tensor
Output tensor after spectral convolution
"""
# apply forward transform
x = self.forward_transform(x)
# apply spectral convolution
x = torch.einsum(self.contract_func, x, self.weight)
# apply inverse transform
x = self.inverse_transform(x)
# add bias if present
if self.bias is not None:
x = x + self.bias
return x
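# Usage sketch (illustrative; assumes RealFFT2/InverseRealFFT2 expose lmax and
# mmax as documented above): with the FFT wrappers as transforms, SpectralConvS2
# acts like a Fourier neural operator layer on the periodic domain.
fft = RealFFT2(nlat=32, nlon=64)
ifft = InverseRealFFT2(nlat=32, nlon=64)
conv = SpectralConvS2(fft, ifft, in_channels=8, out_channels=8)
y = conv(torch.randn(2, 8, 32, 64))  # -> (2, 8, 32, 64)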
class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta):
"""
Abstract base class for position embeddings on spherical data.

Parameters
-----------
img_shape : tuple, optional
Image shape (height, width), by default (480, 960)
grid : str, optional
Grid type, by default "equiangular"
num_chans : int, optional
Number of channels, by default 1
"""
def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
super(PositionEmbedding, self).__init__()
self.img_shape = img_shape
self.grid = grid
self.num_chans = num_chans
@abc.abstractmethod
def forward(self, x: torch.Tensor):
"""
Abstract forward method for position embedding.
Parameters
-----------
x : torch.Tensor
Input tensor
"""
pass
class SequencePositionEmbedding(PositionEmbedding):
"""
Sequence-based position embedding for spherical data.

This module adds position embeddings based on the sequence of latitude and longitude
coordinates, providing spatial context to the model.

Parameters
-----------
img_shape : tuple, optional
Image shape (height, width), by default (480, 960)
grid : str, optional
Grid type, by default "equiangular"
num_chans : int, optional
Number of channels, by default 1
"""
def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)
# create position embeddings
pos_embed = torch.zeros(1, num_chans, img_shape[0], img_shape[1])
nn.init.trunc_normal_(pos_embed, std=0.02)
self.register_buffer("pos_embed", pos_embed)
def forward(self, x: torch.Tensor):
"""
Forward pass of sequence position embedding.

Parameters
-----------
x : torch.Tensor
Input tensor

Returns
-------
torch.Tensor
Tensor with position embeddings added
"""
return x + self.pos_embed
class SpectralPositionEmbedding(PositionEmbedding):
"""
Returns position embeddings for the spherical transformer
r"""
Spectral position embedding for spherical data.
This module adds position embeddings in the spectral domain using spherical harmonics,
providing spectral context to the model.
Parameters
-----------
img_shape : tuple, optional
Image shape (height, width), by default (480, 960)
grid : str, optional
Grid type, by default "equiangular"
num_chans : int, optional
Number of channels, by default 1
"""
def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
super(SpectralPositionEmbedding, self).__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)
super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)
# compute maximum required frequency and prepare isht
lmax = math.floor(math.sqrt(self.num_chans)) + 1
isht = InverseRealSHT(nlat=self.img_shape[0], nlon=self.img_shape[1], lmax=lmax, mmax=lmax, grid=grid)
# fill position embedding
with torch.no_grad():
pos_embed_freq = torch.zeros(1, self.num_chans, isht.lmax, isht.mmax, dtype=torch.complex64)
for i in range(self.num_chans):
l = math.floor(math.sqrt(i))
m = i - l**2 - l
if m < 0:
pos_embed_freq[0, i, l, -m] = 1.0j
else:
pos_embed_freq[0, i, l, m] = 1.0
# create spectral position embeddings
pos_embed = torch.zeros(1, num_chans, img_shape[0], img_shape[1] // 2 + 1, dtype=torch.cfloat)
nn.init.trunc_normal_(pos_embed.real, std=0.02)
nn.init.trunc_normal_(pos_embed.imag, std=0.02)
self.register_buffer("pos_embed", pos_embed)
# compute spatial position embeddings
pos_embed = isht(pos_embed_freq)
def forward(self, x: torch.Tensor):
"""
Forward pass of spectral position embedding.
# normalization
pos_embed = pos_embed / torch.amax(pos_embed.abs(), dim=(-1, -2), keepdim=True)
Parameters
-----------
x : torch.Tensor
Input tensor
# register tensor
self.register_buffer("position_embeddings", pos_embed)
Returns
-------
torch.Tensor
Tensor with spectral position embeddings added
"""
return x + self.pos_embed
class LearnablePositionEmbedding(PositionEmbedding):
"""
Returns position embeddings for the spherical transformer
r"""
Learnable position embedding for spherical data.
This module adds learnable position embeddings that are optimized during training,
allowing the model to learn optimal spatial representations.
Parameters
-----------
img_shape : tuple, optional
Image shape (height, width), by default (480, 960)
grid : str, optional
Grid type, by default "equiangular"
num_chans : int, optional
Number of channels, by default 1
embed_type : str, optional
Embedding type ("lat", "lon", or "both"), by default "lat"
"""
def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1, embed_type="lat"):
super(LearnablePositionEmbedding, self).__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)
self.embed_type = embed_type
if embed_type == "lat":
# latitude embedding
pos_embed = nn.Parameter(torch.zeros(1, num_chans, img_shape[0], 1))
nn.init.trunc_normal_(pos_embed, std=0.02)
self.register_parameter("pos_embed", pos_embed)
elif embed_type == "lon":
# longitude embedding
pos_embed = nn.Parameter(torch.zeros(1, num_chans, 1, img_shape[1]))
nn.init.trunc_normal_(pos_embed, std=0.02)
self.register_parameter("pos_embed", pos_embed)
elif embed_type == "latlon":
# full lat-lon embedding
pos_embed = nn.Parameter(torch.zeros(1, num_chans, img_shape[0], img_shape[1]))
nn.init.trunc_normal_(pos_embed, std=0.02)
self.register_parameter("pos_embed", pos_embed)
else:
raise ValueError(f"Unknown embedding type {embed_type}")
super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)
def forward(self, x: torch.Tensor):
"""
Forward pass of learnable position embedding.
if embed_type == "latlon":
self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_chans, self.img_shape[0], self.img_shape[1]))
elif embed_type == "lat":
self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_chans, self.img_shape[0], 1))
else:
raise ValueError(f"Unknown learnable position embedding type {embed_type}")
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Tensor with learnable position embeddings added
"""
return x + self.pos_embed
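# Illustrative: with embed_type="lat" the (1, C, nlat, 1) parameter broadcasts
# across longitude, so all longitudes at a given latitude share one learned offset.
pe = LearnablePositionEmbedding(img_shape=(32, 64), num_chans=8, embed_type="lat")
y = pe(torch.randn(2, 8, 32, 64))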
# class SpiralPositionEmbedding(PositionEmbedding):
# """
......
......@@ -50,6 +50,35 @@ def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
return (kernel_shape[0] + 1) * theta_cutoff_factor[basis_type] * math.pi / float(nlat - 1)
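# Worked example (illustrative; assumes a basis whose theta_cutoff_factor is 1.0):
# for nlat=480 and kernel_shape=(3, 3), the cutoff radius is
# (3 + 1) * 1.0 * pi / 479 ~ 0.026 rad, i.e. about four latitude spacings.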
class DiscreteContinuousEncoder(nn.Module):
r"""
Discrete-continuous encoder for spherical neural operators.
This module performs downsampling using discrete-continuous convolutions on the sphere,
reducing the spatial resolution while maintaining the spectral properties of the data.
Parameters
-----------
in_shape : tuple, optional
Input shape (nlat, nlon), by default (721, 1440)
out_shape : tuple, optional
Output shape (nlat, nlon), by default (480, 960)
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
inp_chans : int, optional
Number of input channels, by default 2
out_chans : int, optional
Number of output channels, by default 2
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
groups : int, optional
Number of groups for grouped convolution, by default 1
bias : bool, optional
Whether to use bias, by default False
"""
def __init__(
self,
in_shape=(721, 1440),
......@@ -81,6 +110,19 @@ class DiscreteContinuousEncoder(nn.Module):
)
def forward(self, x):
"""
Forward pass of the discrete-continuous encoder.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (batch, channels, nlat, nlon)
Returns
-------
torch.Tensor
Encoded tensor with reduced spatial resolution
"""
dtype = x.dtype
with amp.autocast(device_type="cuda", enabled=False):
......@@ -92,6 +134,37 @@ class DiscreteContinuousEncoder(nn.Module):
class DiscreteContinuousDecoder(nn.Module):
r"""
Discrete-continuous decoder for spherical neural operators.
This module performs upsampling using either spherical harmonic transforms or resampling,
followed by discrete-continuous convolutions to restore spatial resolution.
Parameters
-----------
in_shape : tuple, optional
Input shape (nlat, nlon), by default (480, 960)
out_shape : tuple, optional
Output shape (nlat, nlon), by default (721, 1440)
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
inp_chans : int, optional
Number of input channels, by default 2
out_chans : int, optional
Number of output channels, by default 2
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
groups : int, optional
Number of groups for grouped convolution, by default 1
bias : bool, optional
Whether to use bias, by default False
upsample_sht : bool, optional
Whether to use SHT for upsampling, by default False
"""
def __init__(
self,
in_shape=(480, 960),
......@@ -132,6 +205,19 @@ class DiscreteContinuousDecoder(nn.Module):
)
def forward(self, x):
"""
Forward pass of the discrete-continuous decoder.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (batch, channels, nlat, nlon)
Returns
-------
torch.Tensor
Decoded tensor with restored spatial resolution
"""
dtype = x.dtype
with amp.autocast(device_type="cuda", enabled=False):
......@@ -274,7 +360,7 @@ class SphericalNeuralOperatorBlock(nn.Module):
class LocalSphericalNeuralOperator(nn.Module):
"""
r"""
LocalSphericalNeuralOperator module. A spherical neural operator which uses both local and global integral
operators to accurately model both types of solution operators [1]. The architecture is based on the Spherical
Fourier Neural Operator [2] and improves upon it with local integral operators in both the Neural Operator blocks,
......@@ -282,43 +368,48 @@ class LocalSphericalNeuralOperator(nn.Module):
Parameters
-----------
img_size : tuple, optional
Input image size (nlat, nlon), by default (128, 256)
grid : str, optional
Grid type for input/output, by default "equiangular"
grid_internal : str, optional
Grid type for internal processing, by default "legendre-gauss"
scale_factor : int, optional
Scale factor for resolution changes, by default 3
in_chans : int, optional
Number of input channels, by default 3
out_chans : int, optional
Number of output channels, by default 3
embed_dim : int, optional
Embedding dimension, by default 256
num_layers : int, optional
Number of layers, by default 4
activation_function : str, optional
Activation function name, by default "gelu"
kernel_shape : tuple, optional
Kernel shape for convolutions, by default (3, 3)
encoder_kernel_shape : tuple, optional
Kernel shape for encoder, by default (3, 3)
filter_basis_type : str, optional
Filter basis type, by default "morlet"
use_mlp : bool, optional
Whether to use MLP layers, by default True
mlp_ratio : float, optional
MLP expansion ratio, by default 2.0
drop_rate : float, optional
Dropout rate, by default 0.0
drop_path_rate : float, optional
Drop path rate, by default 0.0
normalization_layer : str, optional
Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"
sfno_block_frequency : int, optional
Frequency of SFNO blocks, by default 2
hard_thresholding_fraction : float, optional
Hard thresholding fraction (frequency cutoff), by default 1.0
residual_prediction : bool, optional
Whether to use residual prediction, by default False
pos_embed : str, optional
Position embedding type, by default "none"
upsample_sht : bool, optional
Use SHT upsampling if true, else linear interpolation
bias : bool, optional
......@@ -497,6 +588,19 @@ class LocalSphericalNeuralOperator(nn.Module):
return x
def forward(self, x):
"""
Forward pass through the complete LSNO model.
Parameters
-----------
x : torch.Tensor
Input tensor of shape (batch_size, in_chans, height, width)
Returns
-------
torch.Tensor
Output tensor of shape (batch_size, out_chans, height, width)
"""
if self.residual_prediction:
residual = x
......
......@@ -54,6 +54,34 @@ def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
class OverlapPatchMerging(nn.Module):
"""
Overlap patch merging module for spherical segformer.
This module performs patch merging with overlapping patches using discrete-continuous
convolutions on the sphere, followed by layer normalization.
Parameters
-----------
in_shape : tuple, optional
Input shape (nlat, nlon), by default (721, 1440)
out_shape : tuple, optional
Output shape (nlat, nlon), by default (481, 960)
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
in_channels : int, optional
Number of input channels, by default 3
out_channels : int, optional
Number of output channels, by default 64
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
bias : bool, optional
Whether to use bias, by default False
"""
def __init__(
self,
in_shape=(721, 1440),
......@@ -89,11 +117,32 @@ class OverlapPatchMerging(nn.Module):
self.apply(self._init_weights)
def _init_weights(self, m):
"""
Initialize weights for the module.
Parameters
-----------
m : nn.Module
Module to initialize
"""
if isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x):
"""
Forward pass of the overlap patch merging module.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Merged patches with layer normalization
"""
dtype = x.dtype
with amp.autocast(device_type="cuda", enabled=False):
......@@ -109,6 +158,38 @@ class OverlapPatchMerging(nn.Module):
class MixFFN(nn.Module):
"""
Mix FFN module for spherical segformer.
This module implements a feed-forward network that combines MLP operations
with discrete-continuous convolutions on the sphere.
Parameters
-----------
shape : tuple
Shape (nlat, nlon) of the input
inout_channels : int
Number of input/output channels
hidden_channels : int
Number of hidden channels in MLP
mlp_bias : bool, optional
Whether to use bias in MLP, by default True
grid : str, optional
Grid type, by default "equiangular"
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
conv_bias : bool, optional
Whether to use bias in convolution, by default False
activation : nn.Module, optional
Activation function, by default nn.GELU
use_mlp : bool, optional
Whether to use MLP instead of linear layers, by default False
drop_path : float, optional
Drop path rate, by default 0.0
"""
def __init__(
self,
shape,
......@@ -161,6 +242,14 @@ class MixFFN(nn.Module):
self.apply(self._init_weights)
def _init_weights(self, m):
"""
Initialize weights for the module.
Parameters
-----------
m : nn.Module
Module to initialize
"""
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
......@@ -170,7 +259,19 @@ class MixFFN(nn.Module):
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the Mix FFN module.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Output tensor after Mix FFN processing
"""
residual = x
# norm
......@@ -194,6 +295,35 @@ class MixFFN(nn.Module):
class AttentionWrapper(nn.Module):
"""
Attention wrapper for spherical segformer.
This module wraps attention mechanisms (neighborhood or global) with optional
normalization and drop path regularization.
Parameters
-----------
channels : int
Number of channels
shape : tuple
Shape (nlat, nlon) of the input
grid : str
Grid type
heads : int
Number of attention heads
pre_norm : bool, optional
Whether to apply normalization before attention, by default False
attention_drop_rate : float, optional
Dropout rate for attention, by default 0.0
drop_path : float, optional
Drop path rate, by default 0.0
attention_mode : str, optional
Attention mode ("neighborhood" or "global"), by default "neighborhood"
theta_cutoff : float, optional
Cutoff radius for neighborhood attention, by default None
bias : bool, optional
Whether to use bias, by default True
"""
def __init__(
self,
channels,
......@@ -252,7 +382,19 @@ class AttentionWrapper(nn.Module):
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the attention wrapper.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Output tensor after attention processing
"""
residual = x
if self.norm is not None:
x = x.permute(0, 2, 3, 1)
......@@ -271,6 +413,49 @@ class AttentionWrapper(nn.Module):
class TransformerBlock(nn.Module):
"""
Transformer block for spherical segformer.
This block combines patch merging, attention, and Mix FFN operations
in a hierarchical structure for processing spherical data.
Parameters
-----------
in_shape : tuple
Input shape (nlat, nlon)
out_shape : tuple
Output shape (nlat, nlon)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
mlp_hidden_channels : int
Number of hidden channels in MLP
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
nrep : int, optional
Number of repetitions, by default 1
heads : int, optional
Number of attention heads, by default 1
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
activation : nn.Module, optional
Activation function, by default nn.GELU
att_drop_rate : float, optional
Dropout rate for attention, by default 0.0
drop_path_rates : float, optional
Drop path rates, by default 0.0
attention_mode : str, optional
Attention mode ("neighborhood" or "global"), by default "neighborhood"
theta_cutoff : float, optional
Cutoff radius for neighborhood attention, by default None
bias : bool, optional
Whether to use bias, by default True
"""
def __init__(
self,
in_shape,
......@@ -363,6 +548,19 @@ class TransformerBlock(nn.Module):
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the transformer block.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Output tensor after transformer block processing
"""
x = self.fwd(x)
# apply norm
......@@ -374,6 +572,43 @@ class TransformerBlock(nn.Module):
class Upsampling(nn.Module):
"""
Upsampling module for spherical segformer.
This module performs upsampling using either discrete-continuous transposed convolutions
or bilinear resampling on spherical data.
Parameters
-----------
in_shape : tuple
Input shape (nlat, nlon)
out_shape : tuple
Output shape (nlat, nlon)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
hidden_channels : int
Number of hidden channels in MLP
mlp_bias : bool, optional
Whether to use bias in MLP, by default True
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
conv_bias : bool, optional
Whether to use bias in convolution, by default False
activation : nn.Module, optional
Activation function, by default nn.GELU
use_mlp : bool, optional
Whether to use MLP instead of linear layers, by default False
upsampling_method : str, optional
Upsampling method ("conv" or "bilinear"), by default "conv"
"""
def __init__(
self,
in_shape,
......@@ -429,6 +664,19 @@ class Upsampling(nn.Module):
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the upsampling module.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Upsampled tensor
"""
x = self.upsample(self.mlp(x))
return x
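# Usage sketch (illustrative; required arguments as documented above, output
# shape assumed from out_shape/out_channels):
up = Upsampling(in_shape=(240, 480), out_shape=(480, 960), in_channels=256, out_channels=128, hidden_channels=512)
y = up(torch.randn(1, 256, 240, 480))  # expected -> (1, 128, 480, 960)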
......@@ -606,6 +854,14 @@ class SphericalSegformer(nn.Module):
self.apply(self._init_weights)
def _init_weights(self, m):
"""
Initialize weights for the module.
Parameters
-----------
m : nn.Module
Module to initialize
"""
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
......@@ -615,7 +871,19 @@ class SphericalSegformer(nn.Module):
nn.init.constant_(m.weight, 1.0)
def forward(self, x):
"""
Forward pass through the complete spherical segformer model.
Parameters
-----------
x : torch.Tensor
Input tensor of shape (batch_size, in_chans, height, width)
Returns
-------
torch.Tensor
Output tensor of shape (batch_size, out_chans, height, width)
"""
# encoder:
features = []
feat = x
......
......@@ -52,6 +52,36 @@ def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
return (kernel_shape[0] + 1) * theta_cutoff_factor[basis_type] * math.pi / float(nlat - 1)
class DiscreteContinuousEncoder(nn.Module):
"""
Discrete-continuous encoder for spherical transformers.
This module performs downsampling using discrete-continuous convolutions on the sphere,
reducing the spatial resolution while maintaining the spectral properties of the data.
Parameters
-----------
in_shape : tuple, optional
Input shape (nlat, nlon), by default (721, 1440)
out_shape : tuple, optional
Output shape (nlat, nlon), by default (480, 960)
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
in_chans : int, optional
Number of input channels, by default 2
out_chans : int, optional
Number of output channels, by default 2
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
groups : int, optional
Number of groups for grouped convolution, by default 1
bias : bool, optional
Whether to use bias, by default False
"""
def __init__(
self,
in_shape=(721, 1440),
......@@ -83,6 +113,19 @@ class DiscreteContinuousEncoder(nn.Module):
)
def forward(self, x):
"""
Forward pass of the discrete-continuous encoder.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (batch, channels, nlat, nlon)
Returns
-------
torch.Tensor
Encoded tensor with reduced spatial resolution
"""
dtype = x.dtype
with amp.autocast(device_type="cuda", enabled=False):
......@@ -94,6 +137,38 @@ class DiscreteContinuousEncoder(nn.Module):
class DiscreteContinuousDecoder(nn.Module):
"""
Discrete-continuous decoder for spherical transformers.
This module performs upsampling using either spherical harmonic transforms or resampling,
followed by discrete-continuous convolutions to restore spatial resolution.
Parameters
-----------
in_shape : tuple, optional
Input shape (nlat, nlon), by default (480, 960)
out_shape : tuple, optional
Output shape (nlat, nlon), by default (721, 1440)
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
in_chans : int, optional
Number of input channels, by default 2
out_chans : int, optional
Number of output channels, by default 2
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
groups : int, optional
Number of groups for grouped convolution, by default 1
bias : bool, optional
Whether to use bias, by default False
upsample_sht : bool, optional
Whether to use SHT for upsampling, by default False
"""
def __init__(
self,
in_shape=(480, 960),
......@@ -134,6 +209,19 @@ class DiscreteContinuousDecoder(nn.Module):
)
def forward(self, x):
"""
Forward pass of the discrete-continuous decoder.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (batch, channels, nlat, nlon)
Returns
-------
torch.Tensor
Decoded tensor with restored spatial resolution
"""
dtype = x.dtype
with amp.autocast(device_type="cuda", enabled=False):
......@@ -147,7 +235,45 @@ class DiscreteContinuousDecoder(nn.Module):
class SphericalAttentionBlock(nn.Module):
"""
Spherical attention block for transformers on the sphere.
This module implements a single attention block that can use either global attention
or neighborhood attention on spherical data, followed by an optional MLP.
Parameters
-----------
in_shape : tuple, optional
Input shape (nlat, nlon), by default (480, 960)
out_shape : tuple, optional
Output shape (nlat, nlon), by default (480, 960)
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
in_chans : int, optional
Number of input channels, by default 2
out_chans : int, optional
Number of output channels, by default 2
num_heads : int, optional
Number of attention heads, by default 1
mlp_ratio : float, optional
Ratio of MLP hidden dimension to output dimension, by default 2.0
drop_rate : float, optional
Dropout rate, by default 0.0
drop_path : float, optional
Drop path rate, by default 0.0
act_layer : nn.Module, optional
Activation layer, by default nn.GELU
norm_layer : str, optional
Normalization layer type, by default "none"
use_mlp : bool, optional
Whether to use MLP after attention, by default True
bias : bool, optional
Whether to use bias, by default False
attention_mode : str, optional
Attention mode ("neighborhood" or "global"), by default "neighborhood"
theta_cutoff : float, optional
Cutoff radius for neighborhood attention, by default None
"""
def __init__(
......@@ -467,6 +593,19 @@ class SphericalTransformer(nn.Module):
return x
def forward(self, x):
"""
Forward pass through the complete spherical transformer model.
Parameters
-----------
x : torch.Tensor
Input tensor of shape (batch_size, in_chans, height, width)
Returns
-------
torch.Tensor
Output tensor of shape (batch_size, out_chans, height, width)
"""
if self.residual_prediction:
residual = x
......