OpenDAS / torch-harmonics · Commits

Commit e4879676, authored Jun 26, 2025 by apaaris; committed Jul 21, 2025 by Boris Bonev

    Added docstrings to many methods

Parent: b5c410c0

Changes (29): showing 20 changed files with 2704 additions and 189 deletions (+2704 -189)
examples/baseline_models/segformer.py                 +153  -1
examples/baseline_models/transformer.py               +114  -2
examples/baseline_models/unet.py                       +97  -3
examples/depth/train.py                                +99  -3
examples/model_registry.py                             +70  -0
examples/segmentation/train.py                         +97  -3
examples/shallow_water_equations/train.py              +93  -3
setup.py                                               +25  -2
tests/test_cache.py                                    +11  -0
torch_harmonics/_disco_convolution.py                  +58  -0
torch_harmonics/_neighborhood_attention.py            +304  -5
torch_harmonics/cache.py                               +26  -0
torch_harmonics/distributed/distributed_resample.py    +78  -0
torch_harmonics/distributed/primitives.py              +95  -3
torch_harmonics/examples/losses.py                    +273  -25
torch_harmonics/examples/metrics.py                   +150  -2
torch_harmonics/examples/models/_layers.py            +422  -109
torch_harmonics/examples/models/lsno.py               +128  -24
torch_harmonics/examples/models/s2segformer.py        +271  -3
torch_harmonics/examples/models/s2transformer.py      +140  -1
examples/baseline_models/segformer.py  (view file @ e4879676)
@@ -41,6 +41,24 @@ from functools import partial
class OverlapPatchMerging(nn.Module):
    """
    OverlapPatchMerging layer for merging patches.

    Parameters
    ----------
    in_shape : tuple
        Input shape (height, width)
    out_shape : tuple
        Output shape (height, width)
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    kernel_shape : tuple
        Kernel shape for convolution
    bias : bool, optional
        Whether to use bias, by default False
    """

    def __init__(self, in_shape=(721, 1440),
    ...
@@ -72,11 +90,32 @@ class OverlapPatchMerging(nn.Module):
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """
        Initialize weights for the module.

        Parameters
        ----------
        m : torch.nn.Module
            Module to initialize weights for
        """
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """
        Forward pass through the OverlapPatchMerging layer.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Output tensor after patch merging
        """
        x = self.conv(x)
        # permute
        ...
@@ -88,6 +127,30 @@ class OverlapPatchMerging(nn.Module):
class MixFFN(nn.Module):
    """
    MixFFN module combining an MLP and a depthwise convolution.

    Parameters
    ----------
    shape : tuple
        Input shape (height, width)
    inout_channels : int
        Number of input/output channels
    hidden_channels : int
        Number of hidden channels in the MLP
    mlp_bias : bool, optional
        Whether to use bias in MLP layers, by default True
    kernel_shape : tuple, optional
        Kernel shape for the depthwise convolution, by default (3, 3)
    conv_bias : bool, optional
        Whether to use bias in the convolution, by default False
    activation : callable, optional
        Activation function, by default nn.GELU
    use_mlp : bool, optional
        Whether to use an MLP instead of linear layers, by default False
    drop_path : float, optional
        Drop path rate, by default 0.0
    """

    def __init__(self, shape,
    ...
@@ -124,6 +187,14 @@ class MixFFN(nn.Module):
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """
        Initialize weights for the module.

        Parameters
        ----------
        m : torch.nn.Module
            Module to initialize weights for
        """
        if isinstance(m, nn.Conv2d):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                ...
@@ -133,7 +204,19 @@ class MixFFN(nn.Module):
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the MixFFN module.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Output tensor after processing
        """
        residual = x
        # norm
        ...
@@ -162,6 +245,17 @@ class GlobalAttention(nn.Module):
    Input shape: (B, C, H, W)
    Output shape: (B, C, H, W) with residual skip.

    Parameters
    ----------
    chans : int
        Number of channels
    num_heads : int, optional
        Number of attention heads, by default 8
    dropout : float, optional
        Dropout rate, by default 0.0
    bias : bool, optional
        Whether to use bias, by default True
    """

    def __init__(self, chans, num_heads=8, dropout=0.0, bias=True):
        ...
@@ -169,6 +263,19 @@ class GlobalAttention(nn.Module):
        self.attn = nn.MultiheadAttention(embed_dim=chans, num_heads=num_heads, dropout=dropout, batch_first=True, bias=bias)

    def forward(self, x):
        """
        Forward pass through the GlobalAttention module.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (B, C, H, W)

        Returns
        -------
        torch.Tensor
            Output tensor of shape (B, C, H, W)
        """
        # x: B, C, H, W
        B, H, W, C = x.shape
        # flatten spatial dims
        ...
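For orientation, the flatten/attend/reshape pattern that this forward pass documents can be sketched as follows. This is a minimal runnable sketch assuming channels-last input of shape (B, H, W, C); the names and shapes here are illustrative, not the module's actual code.

    import torch
    import torch.nn as nn

    # minimal sketch: global self-attention over flattened spatial dims
    attn = nn.MultiheadAttention(embed_dim=32, num_heads=8, batch_first=True)

    x = torch.randn(2, 16, 32, 32)       # (B, H, W, C), channels-last as in the hunk above
    B, H, W, C = x.shape
    seq = x.reshape(B, H * W, C)         # flatten spatial dims into a sequence
    out, _ = attn(seq, seq, seq)         # self-attention: query = key = value
    out = out.reshape(B, H, W, C)        # restore the spatial layout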
@@ -181,6 +288,30 @@ class GlobalAttention(nn.Module):
class AttentionWrapper(nn.Module):
    """
    Wrapper for different attention mechanisms.

    Parameters
    ----------
    channels : int
        Number of channels
    shape : tuple
        Input shape (height, width)
    heads : int
        Number of attention heads
    pre_norm : bool, optional
        Whether to apply normalization before attention, by default False
    attention_drop_rate : float, optional
        Attention dropout rate, by default 0.0
    drop_path : float, optional
        Drop path rate, by default 0.0
    attention_mode : str, optional
        Attention mode ("neighborhood" or "global"), by default "neighborhood"
    kernel_shape : tuple, optional
        Kernel shape for neighborhood attention, by default (7, 7)
    bias : bool, optional
        Whether to use bias, by default True
    """

    def __init__(self, channels, shape, heads, pre_norm=False, attention_drop_rate=0.0, drop_path=0.0, attention_mode="neighborhood", kernel_shape=(7, 7), bias=True):
        super().__init__()
        ...
@@ -203,11 +334,32 @@ class AttentionWrapper(nn.Module):
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """
        Initialize weights for the module.

        Parameters
        ----------
        m : torch.nn.Module
            Module to initialize weights for
        """
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the AttentionWrapper.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Output tensor with residual connection
        """
        residual = x
        x = x.permute(0, 2, 3, 1)
        if self.norm is not None:
            ...
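The `_init_weights` hooks added throughout this commit all follow the same `Module.apply` idiom: register the hook in `__init__`, and let `apply` walk every submodule. A minimal self-contained sketch of that idiom (class and layer names are illustrative only):

    import torch.nn as nn

    class Block(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Conv2d(8, 8, 3, padding=1)
            self.norm = nn.LayerNorm(8)
            # apply() visits every submodule and calls the hook on each one
            self.apply(self._init_weights)

        def _init_weights(self, m):
            if isinstance(m, nn.Conv2d):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)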
examples/baseline_models/transformer.py  (view file @ e4879676)
@@ -57,11 +57,46 @@ class Encoder(nn.Module):
        self.conv = nn.Conv2d(in_chans, out_chans, kernel_size=kernel_shape, bias=bias, stride=(stride_h, stride_w), padding=(pad_h, pad_w), groups=groups)

    def forward(self, x):
        """
        Forward pass through the Encoder layer.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Output tensor after encoding
        """
        x = self.conv(x)
        return x


class Decoder(nn.Module):
    """
    Decoder module for upsampling and feature processing.

    Parameters
    ----------
    in_shape : tuple, optional
        Input shape (height, width), by default (480, 960)
    out_shape : tuple, optional
        Output shape (height, width), by default (721, 1440)
    in_chans : int, optional
        Number of input channels, by default 2
    out_chans : int, optional
        Number of output channels, by default 2
    kernel_shape : tuple, optional
        Kernel shape for convolution, by default (3, 3)
    groups : int, optional
        Number of groups for convolution, by default 1
    bias : bool, optional
        Whether to use bias, by default False
    upsampling_method : str, optional
        Upsampling method ("conv" or "pixel_shuffle"), by default "conv"
    """

    def __init__(self, in_shape=(480, 960), out_shape=(721, 1440), in_chans=2, out_chans=2, kernel_shape=(3, 3), groups=1, bias=False, upsampling_method="conv"):
        super().__init__()
        self.out_shape = out_shape
        ...
@@ -87,6 +122,19 @@ class Decoder(nn.Module):
            raise ValueError(f"Unknown upsampling method {upsampling_method}")

    def forward(self, x):
        """
        Forward pass through the Decoder layer.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Output tensor after decoding and upsampling
        """
        x = self.upsample(x)
        return x
        ...
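The two upsampling branches named in the Decoder docstring ("conv" and "pixel_shuffle") can be sketched roughly as follows. The layer choices and shapes below are assumptions; the actual construction lives in the elided `__init__` body.

    import torch.nn as nn

    def make_upsample(in_chans, out_chans, out_shape, upsampling_method="conv"):
        # rough sketch of the two documented options, not the file's actual code
        if upsampling_method == "conv":
            return nn.Sequential(
                nn.Upsample(size=out_shape, mode="bilinear"),
                nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1),
            )
        elif upsampling_method == "pixel_shuffle":
            return nn.Sequential(
                nn.Conv2d(in_chans, out_chans * 4, kernel_size=3, padding=1),
                nn.PixelShuffle(upscale_factor=2),   # (C*4, H, W) -> (C, 2H, 2W)
            )
        raise ValueError(f"Unknown upsampling method {upsampling_method}")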
@@ -97,6 +145,17 @@ class GlobalAttention(nn.Module):
    Input shape: (B, C, H, W)
    Output shape: (B, C, H, W) with residual skip.

    Parameters
    ----------
    chans : int
        Number of channels
    num_heads : int, optional
        Number of attention heads, by default 8
    dropout : float, optional
        Dropout rate, by default 0.0
    bias : bool, optional
        Whether to use bias, by default True
    """

    def __init__(self, chans, num_heads=8, dropout=0.0, bias=True):
        ...
@@ -104,6 +163,19 @@ class GlobalAttention(nn.Module):
        self.attn = nn.MultiheadAttention(embed_dim=chans, num_heads=num_heads, dropout=dropout, batch_first=True, bias=bias)

    def forward(self, x):
        """
        Forward pass through the GlobalAttention module.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (B, C, H, W)

        Returns
        -------
        torch.Tensor
            Output tensor of shape (B, C, H, W)
        """
        # x: B, C, H, W
        B, H, W, C = x.shape
        # flatten spatial dims
        ...
@@ -118,8 +190,36 @@ class GlobalAttention(nn.Module):
class AttentionBlock(nn.Module):
    """
    Neighborhood attention block based on Natten.

    Parameters
    ----------
    in_shape : tuple, optional
        Input shape (height, width), by default (480, 960)
    out_shape : tuple, optional
        Output shape (height, width), by default (480, 960)
    chans : int, optional
        Number of channels, by default 2
    num_heads : int, optional
        Number of attention heads, by default 1
    mlp_ratio : float, optional
        Ratio of MLP hidden dim to input dim, by default 2.0
    drop_rate : float, optional
        Dropout rate, by default 0.0
    drop_path : float, optional
        Drop path rate, by default 0.0
    act_layer : callable, optional
        Activation function, by default nn.GELU
    norm_layer : str, optional
        Normalization layer type, by default "none"
    use_mlp : bool, optional
        Whether to use MLP, by default True
    bias : bool, optional
        Whether to use bias, by default True
    attention_mode : str, optional
        Attention mode ("neighborhood" or "global"), by default "neighborhood"
    attn_kernel_shape : tuple, optional
        Kernel shape for neighborhood attention, by default (7, 7)
    """

    def __init__(self, in_shape=(480, 960),
    ...
@@ -186,7 +286,19 @@ class AttentionBlock(nn.Module):
        self.skip1 = nn.Identity()

    def forward(self, x):
        """
        Forward pass through the AttentionBlock.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Output tensor with residual connections
        """
        residual = x
        x = self.norm0(x)
        ...
examples/baseline_models/unet.py  (view file @ e4879676)
@@ -140,12 +140,33 @@ class DownsamplingBlock(nn.Module):
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """
        Initialize weights for the module.

        Parameters
        ----------
        m : torch.nn.Module
            Module to initialize weights for
        """
        if isinstance(m, nn.Conv2d):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the DownsamplingBlock.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Output tensor after downsampling
        """
        # skip connection
        residual = x
        if hasattr(self, "transform_skip"):
            ...
@@ -166,6 +187,36 @@ class DownsamplingBlock(nn.Module):
class UpsamplingBlock(nn.Module):
    """
    Upsampling block for UNet architectures.

    Parameters
    ----------
    in_shape : tuple
        Input shape (height, width)
    out_shape : tuple
        Output shape (height, width)
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    nrep : int, optional
        Number of repetitions of conv blocks, by default 1
    kernel_shape : tuple, optional
        Kernel shape for convolutions, by default (3, 3)
    activation : callable, optional
        Activation function, by default nn.ReLU
    transform_skip : bool, optional
        Whether to transform skip connections, by default False
    drop_conv_rate : float, optional
        Dropout rate for convolutions, by default 0.0
    drop_path_rate : float, optional
        Drop path rate, by default 0.0
    drop_dense_rate : float, optional
        Dropout rate for dense layers, by default 0.0
    upsampling_mode : str, optional
        Upsampling mode ("bilinear" or "conv"), by default "bilinear"
    """

    def __init__(self, in_shape,
    ...
@@ -274,12 +325,33 @@ class UpsamplingBlock(nn.Module):
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """
        Initialize weights for the module.

        Parameters
        ----------
        m : torch.nn.Module
            Module to initialize weights for
        """
        if isinstance(m, nn.Conv2d):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the UpsamplingBlock.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Output tensor after upsampling
        """
        # skip connection
        residual = x
        if hasattr(self, "transform_skip"):
            ...
@@ -304,6 +376,7 @@ class UNet(nn.Module):
    img_shape : tuple, optional
        Shape of the input image, by default (128, 256)
    kernel_shape : tuple or int
        Kernel shape for convolutions
    scale_factor : int, optional
        Scale factor to use, by default 2
    in_chans : int, optional
        ...
@@ -336,11 +409,12 @@ class UNet(nn.Module):
    ... scale_factor=4,
    ... in_chans=2,
    ... num_classes=2,
    ... embed_dims=[16, 32, 64, 128],
    ... depths=[2, 2, 2, 2],
    ... use_mlp=True,)
    >>> model(torch.randn(1, 2, 128, 256)).shape
    torch.Size([1, 2, 128, 256])
    """

    def __init__(self, img_shape=(128, 256),
    ...
@@ -440,6 +514,14 @@ class UNet(nn.Module):
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """
        Initialize weights for the module.

        Parameters
        ----------
        m : torch.nn.Module
            Module to initialize weights for
        """
        if isinstance(m, nn.Conv2d):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                ...
@@ -450,7 +532,19 @@ class UNet(nn.Module):
    def forward(self, x):
        """
        Forward pass through the UNet model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, in_chans, height, width)

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, num_classes, height, width)
        """
        # encoder:
        features = []
        feat = x
        ...
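The encoder loop whose first lines appear above typically accumulates per-stage features that the decoder then fuses back in as skip connections. A minimal runnable sketch of that control flow (this tiny model is illustrative and is not the file's UNet):

    import torch
    import torch.nn as nn

    class TinyUNet(nn.Module):
        # minimal sketch of the encoder/decoder skip pattern
        def __init__(self, chans=8):
            super().__init__()
            self.enc1 = nn.Conv2d(2, chans, 3, stride=2, padding=1)
            self.enc2 = nn.Conv2d(chans, chans, 3, stride=2, padding=1)
            self.dec2 = nn.Sequential(nn.Upsample(scale_factor=2, mode="bilinear"), nn.Conv2d(chans, chans, 3, padding=1))
            self.dec1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="bilinear"), nn.Conv2d(chans, 2, 3, padding=1))

        def forward(self, x):
            f1 = self.enc1(x)        # collect encoder features ...
            f2 = self.enc2(f1)
            y = self.dec2(f2) + f1   # ... and fuse them back in as skips
            return self.dec1(y)

    out = TinyUNet()(torch.randn(1, 2, 64, 128))   # -> torch.Size([1, 2, 64, 128])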
examples/depth/train.py  (view file @ e4879676)
@@ -63,13 +63,37 @@ import wandb
# helper routine for counting the number of parameters in a model
def count_parameters(model):
    """
    Count the number of trainable parameters in a model.

    Parameters
    ----------
    model : torch.nn.Module
        The model to count parameters for

    Returns
    -------
    int
        Total number of trainable parameters
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# convenience function for logging weights and gradients
def log_weights_and_grads(exp_dir, model, iters=1):
    """
    Helper routine intended for debugging purposes.
    Saves model weights and gradients to a file for analysis.

    Parameters
    ----------
    exp_dir : str
        Experiment directory to save logs in
    model : torch.nn.Module
        Model whose weights and gradients to log
    iters : int, optional
        Current iteration number, by default 1
    """
    log_path = os.path.join(exp_dir, "weights_and_grads")
    if not os.path.isdir(log_path):
        ...
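Usage of the parameter counter is a one-liner; for example:

    import torch.nn as nn

    # the helper above, in action
    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    n = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(n)   # 4*8 + 8 + 8*2 + 2 = 58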
@@ -97,7 +121,36 @@ def validate_model(
    logging=True,
    device=torch.device("cpu"),
):
    """
    Validate a model on a dataset and compute metrics.

    Parameters
    ----------
    model : torch.nn.Module
        Model to validate
    dataloader : torch.utils.data.DataLoader
        DataLoader for validation data
    loss_fn : callable
        Loss function
    metrics_fns : dict
        Dictionary of metric functions to compute
    path_root : str
        Root path for saving validation outputs
    normalization_in : callable, optional
        Normalization function to apply to inputs, by default None
    normalization_out : callable, optional
        Normalization function to apply to targets, by default None
    logging : bool, optional
        Whether to save validation plots, by default True
    device : torch.device, optional
        Device to run validation on, by default torch.device("cpu")

    Returns
    -------
    tuple
        (losses, metrics), where losses is a tensor of per-sample losses
        and metrics is a dict of per-sample metric values
    """
    model.eval()
    num_examples = len(dataloader)
    ...
@@ -190,7 +243,50 @@ def train_model(
    logging=True,
    device=torch.device("cpu"),
):
    """
    Train a model with the given parameters.

    Parameters
    ----------
    model : torch.nn.Module
        Model to train
    train_dataloader : torch.utils.data.DataLoader
        DataLoader for training data
    train_sampler : torch.utils.data.Sampler
        Sampler for training data
    test_dataloader : torch.utils.data.DataLoader
        DataLoader for test data
    test_sampler : torch.utils.data.Sampler
        Sampler for test data
    loss_fn : callable
        Loss function
    metrics_fns : dict
        Dictionary of metric functions to compute
    optimizer : torch.optim.Optimizer
        Optimizer for training
    gscaler : torch.cuda.amp.GradScaler
        Gradient scaler for mixed-precision training
    scheduler : torch.optim.lr_scheduler._LRScheduler, optional
        Learning rate scheduler, by default None
    normalization_in : callable, optional
        Normalization function to apply to inputs, by default None
    normalization_out : callable, optional
        Normalization function to apply to targets, by default None
    augmentation : bool, optional
        Whether to apply data augmentation, by default False
    nepochs : int, optional
        Number of training epochs, by default 20
    amp_mode : str, optional
        Mixed-precision mode ("none", "fp16", "bf16"), by default "none"
    log_grads : int, optional
        Frequency of gradient logging (0 for no logging), by default 0
    exp_dir : str, optional
        Experiment directory for logging, by default None
    logging : bool, optional
        Whether to enable logging, by default True
    device : torch.device, optional
        Device to train on, by default torch.device("cpu")
    """
    train_start = time.time()
    # set AMP type
    ...
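The "set AMP type" step that the hunk cuts off typically maps the documented amp_mode strings onto an autocast dtype. A minimal sketch of that mapping, assuming the standard torch.amp API (this is not the file's exact code):

    import torch

    # sketch: map the documented amp_mode strings onto an autocast dtype
    amp_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16}.get("bf16", torch.float32)
    enabled = amp_dtype != torch.float32

    with torch.autocast(device_type="cpu", dtype=amp_dtype, enabled=enabled):
        y = torch.randn(4, 4) @ torch.randn(4, 4)   # runs in bfloat16 under autocast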
examples/model_registry.py  (view file @ e4879676)
@@ -39,6 +39,76 @@ from baseline_models import Transformer, UNet, Segformer
from torch_harmonics.examples.models import SphericalFourierNeuralOperator, LocalSphericalNeuralOperator, SphericalTransformer, SphericalUNet, SphericalSegformer


def get_baseline_models(img_size=(128, 256), in_chans=3, out_chans=3, residual_prediction=False, drop_path_rate=0., grid="equiangular"):
    """
    Get a registry of baseline models for spherical and planar neural networks.

    This function returns a dictionary containing pre-configured model architectures
    for various tasks, including spherical Fourier neural operators (SFNO), local
    spherical neural operators (LSNO), spherical transformers, U-Nets, and Segformers.
    Each model is configured with specific hyperparameters optimized for different
    computational budgets and performance requirements.

    Parameters
    ----------
    img_size : tuple, optional
        Input image size as (height, width), by default (128, 256)
    in_chans : int, optional
        Number of input channels, by default 3
    out_chans : int, optional
        Number of output channels, by default 3
    residual_prediction : bool, optional
        Whether to use residual prediction (add input to output), by default False
    drop_path_rate : float, optional
        Drop path rate for regularization, by default 0.0
    grid : str, optional
        Grid type for spherical models ("equiangular", "legendre-gauss", etc.), by default "equiangular"

    Returns
    -------
    dict
        Dictionary mapping model names to partial functions that can be called
        to instantiate the corresponding model with the specified parameters.
        Available models include:

        Spherical models:

        - sfno_sc2_layers4_e32: Spherical Fourier Neural Operator (small)
        - lsno_sc2_layers4_e32: Local Spherical Neural Operator (small)
        - s2unet_sc2_layers4_e128: Spherical U-Net (medium)
        - s2transformer_sc2_layers4_e128: Spherical Transformer (global attention, medium)
        - s2transformer_sc2_layers4_e256: Spherical Transformer (global attention, large)
        - s2ntransformer_sc2_layers4_e128: Spherical Transformer (neighborhood attention, medium)
        - s2ntransformer_sc2_layers4_e256: Spherical Transformer (neighborhood attention, large)
        - s2segformer_sc2_layers4_e128: Spherical Segformer (global attention, medium)
        - s2segformer_sc2_layers4_e256: Spherical Segformer (global attention, large)
        - s2nsegformer_sc2_layers4_e128: Spherical Segformer (neighborhood attention, medium)
        - s2nsegformer_sc2_layers4_e256: Spherical Segformer (neighborhood attention, large)

        Planar models:

        - transformer_sc2_layers4_e128: Planar Transformer (global attention, medium)
        - transformer_sc2_layers4_e256: Planar Transformer (global attention, large)
        - ntransformer_sc2_layers4_e128: Planar Transformer (neighborhood attention, medium)
        - ntransformer_sc2_layers4_e256: Planar Transformer (neighborhood attention, large)
        - segformer_sc2_layers4_e128: Planar Segformer (global attention, medium)
        - segformer_sc2_layers4_e256: Planar Segformer (global attention, large)
        - nsegformer_sc2_layers4_e128: Planar Segformer (neighborhood attention, medium)
        - nsegformer_sc2_layers4_e256: Planar Segformer (neighborhood attention, large)
        - vit_sc2_layers4_e128: Vision Transformer variant (medium)

    Examples
    --------
    >>> model_registry = get_baseline_models(img_size=(64, 128), in_chans=2, out_chans=1)
    >>> model = model_registry['sfno_sc2_layers4_e32']()
    >>> print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    Notes
    -----
    - Model names follow the pattern {model_type}_{scale_factor}_{layers}_{embed_dim}
    - 'sc2' indicates a scale factor of 2 (downsampling by 2)
    - 'e32', 'e128', 'e256' indicate embedding dimensions
    - An 'n' prefix indicates neighborhood attention instead of global attention
    - All models use GELU activation and instance normalization by default
    """
    # prepare dicts containing models and corresponding metrics
    model_registry = dict(
    ...
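The registry body that follows (elided here) pairs each name with a callable; a common way to build such a registry is with functools.partial. A minimal sketch of the pattern with a stand-in model class (the stand-in is hypothetical; only the registry pattern is the point):

    from functools import partial
    import torch.nn as nn

    # stand-in for the real model classes
    class DummyModel(nn.Module):
        def __init__(self, img_size=(128, 256), embed_dim=32):
            super().__init__()
            self.proj = nn.Linear(embed_dim, embed_dim)

    model_registry = dict(
        sfno_sc2_layers4_e32=partial(DummyModel, img_size=(128, 256), embed_dim=32),
    )
    model = model_registry["sfno_sc2_layers4_e32"]()   # instantiate on demand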
examples/segmentation/train.py  (view file @ e4879676)
@@ -68,13 +68,37 @@ import wandb
# helper routine for counting the number of parameters in a model
def count_parameters(model):
    """
    Count the number of trainable parameters in a model.

    Parameters
    ----------
    model : torch.nn.Module
        The model to count parameters for

    Returns
    -------
    int
        Total number of trainable parameters
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# convenience function for logging weights and gradients
def log_weights_and_grads(exp_dir, model, iters=1):
    """
    Helper routine intended for debugging purposes.
    Saves model weights and gradients to a file for analysis.

    Parameters
    ----------
    exp_dir : str
        Experiment directory to save logs in
    model : torch.nn.Module
        Model whose weights and gradients to log
    iters : int, optional
        Current iteration number, by default 1
    """
    log_path = os.path.join(exp_dir, "weights_and_grads")
    if not os.path.isdir(log_path):
        ...
@@ -92,7 +116,34 @@ def log_weights_and_grads(exp_dir, model, iters=1):
# rolls out the FNO and compares to the classical solver
def validate_model(model, dataloader, loss_fn, metrics_fns, path_root, normalization=None, logging=True, device=torch.device("cpu")):
    """
    Validate a model on a dataset and compute metrics.

    Parameters
    ----------
    model : torch.nn.Module
        Model to validate
    dataloader : torch.utils.data.DataLoader
        DataLoader for validation data
    loss_fn : callable
        Loss function
    metrics_fns : dict
        Dictionary of metric functions to compute
    path_root : str
        Root path for saving validation outputs
    normalization : callable, optional
        Normalization function to apply to inputs, by default None
    logging : bool, optional
        Whether to save validation plots, by default True
    device : torch.device, optional
        Device to run validation on, by default torch.device("cpu")

    Returns
    -------
    tuple
        (losses, metrics), where losses is a tensor of per-sample losses
        and metrics is a dict of per-sample metric values
    """
    model.eval()
    num_examples = len(dataloader)
    ...
@@ -178,7 +229,50 @@ def train_model(
    logging=True,
    device=torch.device("cpu"),
):
    """
    Train a model with the given parameters.

    Parameters
    ----------
    model : torch.nn.Module
        Model to train
    train_dataloader : torch.utils.data.DataLoader
        DataLoader for training data
    train_sampler : torch.utils.data.Sampler
        Sampler for training data
    test_dataloader : torch.utils.data.DataLoader
        DataLoader for test data
    test_sampler : torch.utils.data.Sampler
        Sampler for test data
    loss_fn : callable
        Loss function
    metrics_fns : dict
        Dictionary of metric functions to compute
    optimizer : torch.optim.Optimizer
        Optimizer for training
    gscaler : torch.cuda.amp.GradScaler
        Gradient scaler for mixed-precision training
    scheduler : torch.optim.lr_scheduler._LRScheduler, optional
        Learning rate scheduler, by default None
    max_grad_norm : float, optional
        Maximum gradient norm for clipping, by default 0.0
    normalization : callable, optional
        Normalization function to apply to inputs, by default None
    augmentation : callable, optional
        Augmentation function to apply to inputs, by default None
    nepochs : int, optional
        Number of training epochs, by default 20
    amp_mode : str, optional
        Mixed-precision mode ("none", "fp16", "bf16"), by default "none"
    log_grads : int, optional
        Frequency of gradient logging (0 for no logging), by default 0
    exp_dir : str, optional
        Experiment directory for logging, by default None
    logging : bool, optional
        Whether to enable logging, by default True
    device : torch.device, optional
        Device to train on, by default torch.device("cpu")
    """
    train_start = time.time()
    # set AMP type
    ...
examples/shallow_water_equations/train.py  (view file @ e4879676)
@@ -63,13 +63,35 @@ except:
# helper routine for counting the number of parameters in a model
def count_parameters(model):
    """
    Count the number of trainable parameters in a model.

    Parameters
    ----------
    model : torch.nn.Module
        The model to count parameters for

    Returns
    -------
    int
        Total number of trainable parameters
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# convenience function for logging weights and gradients
def log_weights_and_grads(model, iters=1):
    """
    Helper routine intended for debugging purposes.
    Saves model weights and gradients to a file for analysis.

    Parameters
    ----------
    model : torch.nn.Module
        Model whose weights and gradients to log
    iters : int, optional
        Current iteration number, by default 1
    """
    root_path = os.path.join(os.path.dirname(__file__), "weights_and_grads")
    ...
@@ -97,7 +119,40 @@ def autoregressive_inference(
    nics=50,
    device=torch.device("cpu"),
):
    """
    Perform autoregressive inference with a trained model and compare to the classical solver.

    Parameters
    ----------
    model : torch.nn.Module
        Trained model to evaluate
    dataset : torch.utils.data.Dataset
        Dataset containing solver and normalization parameters
    loss_fn : callable
        Loss function for evaluation
    metrics_fns : dict
        Dictionary of metric functions to compute
    path_root : str
        Root path for saving inference outputs
    nsteps : int
        Number of solver steps per autoregressive step
    autoreg_steps : int, optional
        Number of autoregressive steps, by default 10
    nskip : int, optional
        Skip interval for plotting, by default 1
    plot_channel : int, optional
        Channel to plot, by default 0
    nics : int, optional
        Number of initial conditions to test, by default 50
    device : torch.device, optional
        Device to run inference on, by default torch.device("cpu")

    Returns
    -------
    tuple
        (losses, metrics, model_times, solver_times), where losses and metrics are tensors
        of per-sample values, and model_times and solver_times contain timing information
    """
    model.eval()
    # make output
    ...
@@ -238,7 +293,42 @@ def train_model(
    logging=True,
    device=torch.device("cpu"),
):
    """
    Train a model with the given parameters.

    Parameters
    ----------
    model : torch.nn.Module
        Model to train
    dataloader : torch.utils.data.DataLoader
        DataLoader for training data
    loss_fn : callable
        Loss function
    metrics_fns : dict
        Dictionary of metric functions to compute
    optimizer : torch.optim.Optimizer
        Optimizer for training
    gscaler : torch.cuda.amp.GradScaler
        Gradient scaler for mixed-precision training
    scheduler : torch.optim.lr_scheduler._LRScheduler, optional
        Learning rate scheduler, by default None
    nepochs : int, optional
        Number of training epochs, by default 20
    nfuture : int, optional
        Number of future steps to predict, by default 0
    num_examples : int, optional
        Number of examples per epoch, by default 256
    num_valid : int, optional
        Number of validation examples, by default 8
    amp_mode : str, optional
        Mixed-precision mode ("none", "fp16", "bf16"), by default "none"
    log_grads : int, optional
        Frequency of gradient logging (0 for no logging), by default 0
    logging : bool, optional
        Whether to enable logging, by default True
    device : torch.device, optional
        Device to train on, by default torch.device("cpu")
    """
    train_start = time.time()
    # set AMP type
    ...
setup.py  (view file @ e4879676)
@@ -54,7 +54,22 @@ except (ImportError, TypeError, AssertionError, AttributeError) as e:
    warnings.warn(f"building custom extensions skipped: {e}")


def get_compile_args(module_name):
    """
    Get compilation arguments based on environment variables.

    If the user runs the build with TORCH_HARMONICS_DEBUG=1 set, debugging flags are used.
    If TORCH_HARMONICS_PROFILE=1 is set, profiling flags are included.

    Parameters
    ----------
    module_name : str
        Name of the module being compiled

    Returns
    -------
    dict
        Dictionary containing compilation flags for the 'cxx' and 'nvcc' compilers
    """
    debug_mode = os.environ.get('TORCH_HARMONICS_DEBUG', '0') == '1'
    profile_mode = os.environ.get('TORCH_HARMONICS_PROFILE', '0') == '1'
    ...
@@ -77,7 +92,15 @@ def get_compile_args(module_name):
    }


def get_ext_modules():
    """
    Get the list of extension modules to compile.

    Returns
    -------
    tuple
        (ext_modules, cmdclass), where ext_modules is a list of extension modules
        and cmdclass is a dictionary of build commands
    """
    ext_modules = []
    cmdclass = {}
    ...
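A hedged sketch of how such environment-driven flag selection typically looks; the exact flags used in setup.py are elided above, so the flag choices below are assumptions:

    import os

    # illustrative only: choose compiler flags from the documented env vars
    debug = os.environ.get("TORCH_HARMONICS_DEBUG", "0") == "1"
    profile = os.environ.get("TORCH_HARMONICS_PROFILE", "0") == "1"

    cxx_flags = ["-O0", "-g"] if debug else ["-O3"]
    nvcc_flags = ["-O0", "-g", "-G"] if debug else ["-O3"]
    if profile:
        nvcc_flags.append("-lineinfo")   # hypothetical choice; the real flags are elided above
    compile_args = {"cxx": cxx_flags, "nvcc": nvcc_flags}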
tests/test_cache.py  (view file @ e4879676)
@@ -38,6 +38,17 @@ import torch
class TestCacheConsistency(unittest.TestCase):
    def test_consistency(self, verbose=False):
        """
        Test that cached values are not modified externally.

        This test verifies that the LRU cache decorator properly handles
        deep copying to prevent unintended modifications to cached objects.

        Parameters
        ----------
        verbose : bool, optional
            Whether to print verbose output, by default False
        """
        if verbose:
            print("Testing that cache values do not get modified externally")
        from torch_harmonics.legendre import _precompute_legpoly
        ...
torch_harmonics/_disco_convolution.py  (view file @ e4879676)
@@ -62,11 +62,31 @@ def _get_psi(kernel_size: int, psi_idx: torch.Tensor, psi_vals: torch.Tensor, nl
class _DiscoS2ContractionCuda(torch.autograd.Function):
    r"""
    CUDA implementation of the discrete-continuous convolution contraction on the sphere.

    This class provides the forward and backward passes for efficient GPU computation
    of the S2 convolution operation using custom CUDA kernels.
    """

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor,
                col_idx: torch.Tensor, vals: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int):
        r"""
        Forward pass for the CUDA S2 convolution contraction.

        Parameters:
            x: input tensor
            roff_idx: row offset indices for sparse computation
            ker_idx: kernel indices
            row_idx: row indices for sparse computation
            col_idx: column indices for sparse computation
            vals: values for sparse computation
            kernel_size: size of the kernel
            nlat_out: number of output latitude points
            nlon_out: number of output longitude points
        """
        ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
        ctx.kernel_size = kernel_size
        ctx.nlat_in = x.shape[-2]
        ...
@@ -81,6 +101,15 @@ class _DiscoS2ContractionCuda(torch.autograd.Function):
    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        r"""
        Backward pass for the CUDA S2 convolution contraction.

        Parameters:
            grad_output: gradient of the output

        Returns:
            gradient of the input
        """
        roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
        gtype = grad_output.dtype
        grad_output = grad_output.to(torch.float32).contiguous()
        ...
@@ -92,11 +121,31 @@ class _DiscoS2ContractionCuda(torch.autograd.Function):
class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
    r"""
    CUDA implementation of the transpose discrete-continuous convolution contraction on the sphere.

    This class provides the forward and backward passes for efficient GPU computation
    of the transpose S2 convolution operation using custom CUDA kernels.
    """

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor,
                col_idx: torch.Tensor, vals: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int):
        r"""
        Forward pass for the CUDA transpose S2 convolution contraction.

        Parameters:
            x: input tensor
            roff_idx: row offset indices for sparse computation
            ker_idx: kernel indices
            row_idx: row indices for sparse computation
            col_idx: column indices for sparse computation
            vals: values for sparse computation
            kernel_size: size of the kernel
            nlat_out: number of output latitude points
            nlon_out: number of output longitude points
        """
        ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
        ctx.kernel_size = kernel_size
        ctx.nlat_in = x.shape[-2]
        ...
@@ -111,6 +160,15 @@ class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        r"""
        Backward pass for the CUDA transpose S2 convolution contraction.

        Parameters:
            grad_output: gradient of the output

        Returns:
            gradient of the input
        """
        roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
        gtype = grad_output.dtype
        grad_output = grad_output.to(torch.float32).contiguous()
        ...
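The CUDA wrappers in this file all follow the same torch.autograd.Function skeleton: stash index tensors with save_for_backward, run the kernel, and return per-input gradients from backward. A minimal runnable sketch of that skeleton, with a toy op standing in for the CUDA kernel:

    import torch

    class ScaledSum(torch.autograd.Function):
        # toy stand-in for the CUDA contraction: y = (w * x).sum(-1)
        @staticmethod
        def forward(ctx, x, w):
            ctx.save_for_backward(w)           # stash what backward needs
            return (w * x).sum(dim=-1)

        @staticmethod
        def backward(ctx, grad_output):
            (w,) = ctx.saved_tensors
            # broadcast the incoming gradient back over the reduced dimension
            return grad_output.unsqueeze(-1) * w, None

    x = torch.randn(3, 5, requires_grad=True)
    w = torch.randn(5)
    ScaledSum.apply(x, w).sum().backward()     # x.grad now holds w per row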
torch_harmonics/_neighborhood_attention.py  (view file @ e4879676)
@@ -50,7 +50,40 @@ except ImportError as err:
def _neighborhood_attention_s2_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, quad_weights: torch.Tensor,
                                         col_idx: torch.Tensor, row_off: torch.Tensor, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
    """
    Forward pass implementation of neighborhood attention on the sphere (S2).

    This function computes the neighborhood attention operation using sparse tensor
    operations. It implements the attention mechanism with softmax normalization
    and quadrature weights for spherical integration.

    Parameters
    ----------
    kx : torch.Tensor
        Key tensor with shape (B, C, Hi, Wi), where B is batch size, C is channels,
        Hi is input height (latitude), and Wi is input width (longitude)
    vx : torch.Tensor
        Value tensor with shape (B, C, Hi, Wi)
    qy : torch.Tensor
        Query tensor with shape (B, C, Ho, Wo), where Ho is output height and Wo is output width
    quad_weights : torch.Tensor
        Quadrature weights for spherical integration with shape (Hi,)
    col_idx : torch.Tensor
        Column indices for sparse computation
    row_off : torch.Tensor
        Row offsets for sparse computation
    nlon_in : int
        Number of input longitude points
    nlat_out : int
        Number of output latitude points
    nlon_out : int
        Number of output longitude points

    Returns
    -------
    torch.Tensor
        Output tensor with shape (B, C, Ho, Wo) after the neighborhood attention computation
    """
    # prepare result tensor
    y = torch.zeros_like(qy)
    ...
@@ -102,6 +135,41 @@ def _neighborhood_attention_s2_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy:
def _neighborhood_attention_s2_bwd_dv_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
                                            quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                                            nlon_in: int, nlat_out: int, nlon_out: int):
    """
    Backward pass implementation for value gradients in neighborhood attention on S2.

    This function computes the gradient with respect to the value tensor (vx).
    It implements the backward pass for the neighborhood attention operation using
    sparse tensor operations and quadrature weights for spherical integration.

    Parameters
    ----------
    kx : torch.Tensor
        Key tensor with shape (B, C, Hi, Wi)
    vx : torch.Tensor
        Value tensor with shape (B, C, Hi, Wi)
    qy : torch.Tensor
        Query tensor with shape (B, C, Ho, Wo)
    dy : torch.Tensor
        Gradient of the output with shape (B, C, Ho, Wo)
    quad_weights : torch.Tensor
        Quadrature weights for spherical integration with shape (Hi,)
    col_idx : torch.Tensor
        Column indices for sparse computation
    row_off : torch.Tensor
        Row offsets for sparse computation
    nlon_in : int
        Number of input longitude points
    nlat_out : int
        Number of output latitude points
    nlon_out : int
        Number of output longitude points

    Returns
    -------
    torch.Tensor
        Gradient of the value tensor with shape (B, C, Hi, Wi)
    """
    # shapes:
    # input
    ...
@@ -170,6 +238,41 @@ def _neighborhood_attention_s2_bwd_dv_torch(kx: torch.Tensor, vx: torch.Tensor,
def _neighborhood_attention_s2_bwd_dk_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
                                            quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                                            nlon_in: int, nlat_out: int, nlon_out: int):
    """
    Backward pass implementation for key gradients in neighborhood attention on S2.

    This function computes the gradient with respect to the key tensor (kx).
    It implements the backward pass for the neighborhood attention operation using
    sparse tensor operations and quadrature weights for spherical integration.

    Parameters
    ----------
    kx : torch.Tensor
        Key tensor with shape (B, C, Hi, Wi)
    vx : torch.Tensor
        Value tensor with shape (B, C, Hi, Wi)
    qy : torch.Tensor
        Query tensor with shape (B, C, Ho, Wo)
    dy : torch.Tensor
        Gradient of the output with shape (B, C, Ho, Wo)
    quad_weights : torch.Tensor
        Quadrature weights for spherical integration with shape (Hi,)
    col_idx : torch.Tensor
        Column indices for sparse computation
    row_off : torch.Tensor
        Row offsets for sparse computation
    nlon_in : int
        Number of input longitude points
    nlat_out : int
        Number of output latitude points
    nlon_out : int
        Number of output longitude points

    Returns
    -------
    torch.Tensor
        Gradient of the key tensor with shape (B, C, Hi, Wi)
    """
    # shapes:
    # input
    ...
@@ -251,6 +354,41 @@ def _neighborhood_attention_s2_bwd_dk_torch(kx: torch.Tensor, vx: torch.Tensor,
def _neighborhood_attention_s2_bwd_dq_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
                                            quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                                            nlon_in: int, nlat_out: int, nlon_out: int):
    """
    Backward pass implementation for query gradients in neighborhood attention on S2.

    This function computes the gradient with respect to the query tensor (qy).
    It implements the backward pass for the neighborhood attention operation using
    sparse tensor operations and quadrature weights for spherical integration.

    Parameters
    ----------
    kx : torch.Tensor
        Key tensor with shape (B, C, Hi, Wi)
    vx : torch.Tensor
        Value tensor with shape (B, C, Hi, Wi)
    qy : torch.Tensor
        Query tensor with shape (B, C, Ho, Wo)
    dy : torch.Tensor
        Gradient of the output with shape (B, C, Ho, Wo)
    quad_weights : torch.Tensor
        Quadrature weights for spherical integration with shape (Hi,)
    col_idx : torch.Tensor
        Column indices for sparse computation
    row_off : torch.Tensor
        Row offsets for sparse computation
    nlon_in : int
        Number of input longitude points
    nlat_out : int
        Number of output latitude points
    nlon_out : int
        Number of output longitude points

    Returns
    -------
    torch.Tensor
        Gradient of the query tensor with shape (B, C, Ho, Wo)
    """
    # shapes:
    # input
    ...
@@ -321,6 +459,11 @@ def _neighborhood_attention_s2_bwd_dq_torch(kx: torch.Tensor, vx: torch.Tensor,
    return dqy


class _NeighborhoodAttentionS2(torch.autograd.Function):
    r"""
    CPU implementation of neighborhood attention on the sphere (S2).

    This class provides the forward and backward passes for efficient CPU computation
    of neighborhood attention operations using sparse tensor operations.
    """

    @staticmethod
    @custom_fwd(device_type="cpu")
    ...
@@ -329,7 +472,27 @@ class _NeighborhoodAttentionS2(torch.autograd.Function):
                bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None],
                quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                nh: int, nlon_in: int, nlat_out: int, nlon_out: int):
        r"""
        Forward pass for CPU neighborhood attention on S2.

        Parameters:
            k: key tensor
            v: value tensor
            q: query tensor
            wk: key weight tensor
            wv: value weight tensor
            wq: query weight tensor
            bk: key bias tensor (optional)
            bv: value bias tensor (optional)
            bq: query bias tensor (optional)
            quad_weights: quadrature weights for spherical integration
            col_idx: column indices for sparse computation
            row_off: row offsets for sparse computation
            nh: number of attention heads
            nlon_in: number of input longitude points
            nlat_out: number of output latitude points
            nlon_out: number of output longitude points
        """
        ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq)
        ctx.nh = nh
        ctx.nlon_in = nlon_in
        ...
@@ -364,6 +527,15 @@ class _NeighborhoodAttentionS2(torch.autograd.Function):
    @staticmethod
    @custom_bwd(device_type="cpu")
    def backward(ctx, grad_output):
        r"""
        Backward pass for CPU neighborhood attention on S2.

        Parameters:
            grad_output: gradient of the output

        Returns:
            gradients for all input tensors and parameters
        """
        col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
        nh = ctx.nh
        nlon_in = ctx.nlon_in
        ...
@@ -443,13 +615,63 @@ def _neighborhood_attention_s2_torch(k: torch.Tensor, v: torch.Tensor, q: torch.
                                     bq: Union[torch.Tensor, None], quad_weights: torch.Tensor,
                                     col_idx: torch.Tensor, row_off: torch.Tensor,
                                     nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
    """
    Torch implementation of neighborhood attention on the sphere (S2).

    This function provides a wrapper around the CPU autograd function for
    neighborhood attention operations using sparse tensor computations.

    Parameters
    ----------
    k : torch.Tensor
        Key tensor
    v : torch.Tensor
        Value tensor
    q : torch.Tensor
        Query tensor
    wk : torch.Tensor
        Key weight tensor
    wv : torch.Tensor
        Value weight tensor
    wq : torch.Tensor
        Query weight tensor
    bk : torch.Tensor or None
        Key bias tensor (optional)
    bv : torch.Tensor or None
        Value bias tensor (optional)
    bq : torch.Tensor or None
        Query bias tensor (optional)
    quad_weights : torch.Tensor
        Quadrature weights for spherical integration
    col_idx : torch.Tensor
        Column indices for sparse computation
    row_off : torch.Tensor
        Row offsets for sparse computation
    nh : int
        Number of attention heads
    nlon_in : int
        Number of input longitude points
    nlat_out : int
        Number of output latitude points
    nlon_out : int
        Number of output longitude points

    Returns
    -------
    torch.Tensor
        Output tensor after the neighborhood attention computation
    """
    return _NeighborhoodAttentionS2.apply(k, v, q, wk, wv, wq, bk, bv, bq, quad_weights, col_idx, row_off, nh, nlon_in, nlat_out, nlon_out)


class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
    r"""
    CUDA implementation of neighborhood attention on the sphere (S2).

    This class provides the forward and backward passes for efficient GPU computation
    of neighborhood attention operations using custom CUDA kernels.
    """

    @staticmethod
    @custom_fwd(device_type="cuda")
    ...
@@ -458,7 +680,28 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
                bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None],
                quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int):
        r"""
        Forward pass for CUDA neighborhood attention on S2.

        Parameters:
            k: key tensor
            v: value tensor
            q: query tensor
            wk: key weight tensor
            wv: value weight tensor
            wq: query weight tensor
            bk: key bias tensor (optional)
            bv: value bias tensor (optional)
            bq: query bias tensor (optional)
            quad_weights: quadrature weights for spherical integration
            col_idx: column indices for sparse computation
            row_off: row offsets for sparse computation
            max_psi_nnz: maximum number of non-zero elements in the sparse tensor
            nh: number of attention heads
            nlon_in: number of input longitude points
            nlat_out: number of output latitude points
            nlon_out: number of output longitude points
        """
        ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq)
        ctx.nh = nh
        ctx.max_psi_nnz = max_psi_nnz
        ...
@@ -499,6 +742,15 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        r"""
        Backward pass for CUDA neighborhood attention on S2.

        Parameters:
            grad_output: gradient of the output

        Returns:
            gradients for all input tensors and parameters
        """
        col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
        nh = ctx.nh
        max_psi_nnz = ctx.max_psi_nnz
        ...
@@ -584,7 +836,54 @@ def _neighborhood_attention_s2_cuda(k: torch.Tensor, v: torch.Tensor, q: torch.T
                                    bq: Union[torch.Tensor, None], quad_weights: torch.Tensor,
                                    col_idx: torch.Tensor, row_off: torch.Tensor, max_psi_nnz: int,
                                    nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
    """
    CUDA implementation of neighborhood attention on the sphere (S2).

    This function provides a wrapper around the CUDA autograd function for
    neighborhood attention operations, using custom CUDA kernels for efficient GPU computation.

    Parameters
    ----------
    k : torch.Tensor
        Key tensor
    v : torch.Tensor
        Value tensor
    q : torch.Tensor
        Query tensor
    wk : torch.Tensor
        Key weight tensor
    wv : torch.Tensor
        Value weight tensor
    wq : torch.Tensor
        Query weight tensor
    bk : torch.Tensor or None
        Key bias tensor (optional)
    bv : torch.Tensor or None
        Value bias tensor (optional)
    bq : torch.Tensor or None
        Query bias tensor (optional)
    quad_weights : torch.Tensor
        Quadrature weights for spherical integration
    col_idx : torch.Tensor
        Column indices for sparse computation
    row_off : torch.Tensor
        Row offsets for sparse computation
    max_psi_nnz : int
        Maximum number of non-zero elements in the sparse tensor
    nh : int
        Number of attention heads
    nlon_in : int
        Number of input longitude points
    nlat_out : int
        Number of output latitude points
    nlon_out : int
        Number of output longitude points

    Returns
    -------
    torch.Tensor
        Output tensor after the neighborhood attention computation
    """
    return _NeighborhoodAttentionS2Cuda.apply(k, v, q, wk, wv, wq, bk, bv, bq, quad_weights, col_idx, row_off, max_psi_nnz, nh, nlon_in, nlat_out, nlon_out)
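What these docstrings describe, in one formula: for each output point, attention over its spherical neighborhood is a softmax whose terms are additionally weighted by the quadrature weights, so the attention approximates an integral over the sphere. A minimal dense sketch of a single output point (the real code iterates neighborhoods sparsely via col_idx/row_off; names below are illustrative):

    import torch

    # one attention row over a neighborhood, weighted by quadrature weights
    C, n = 8, 16
    q = torch.randn(C)            # query at one output point
    K = torch.randn(n, C)         # keys of the n neighborhood points
    V = torch.randn(n, C)         # values of the n neighborhood points
    w = torch.rand(n)             # quadrature weights of those points

    logits = K @ q                # correlation of the query with each neighbor
    alpha = w * torch.exp(logits - logits.max())
    alpha = alpha / alpha.sum()   # quadrature-weighted softmax
    y = alpha @ V                 # output at this point, shape (C,)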
torch_harmonics/cache.py  (view file @ e4879676)
@@ -35,6 +35,32 @@ from copy import deepcopy
# copying LRU cache decorator a la:
# https://stackoverflow.com/questions/54909357/how-to-get-functools-lru-cache-to-return-new-instances
def lru_cache(maxsize=20, typed=False, copy=False):
    """
    Least Recently Used (LRU) cache decorator with optional deep copying.

    This is a wrapper around functools.lru_cache that adds the ability to return
    deep copies of cached results, preventing unintended modifications to cached objects.

    Parameters
    ----------
    maxsize : int, optional
        Maximum number of items to cache, by default 20
    typed : bool, optional
        Whether to cache different types separately, by default False
    copy : bool, optional
        Whether to return deep copies of cached results, by default False

    Returns
    -------
    function
        Decorated function with LRU caching

    Example
    -------
    >>> @lru_cache(maxsize=10, copy=True)
    ... def expensive_function(x):
    ...     return [x, x*2, x*3]
    """
    def decorator(f):
        cached_func = functools.lru_cache(maxsize=maxsize, typed=typed)(f)

        def wrapper(*args, **kwargs):
            ...
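The elided wrapper body presumably deep-copies the cached result when copy=True; a plausible completion under that assumption (not the file's verbatim code):

    import functools
    from copy import deepcopy

    def lru_cache(maxsize=20, typed=False, copy=False):
        # sketch of a copying LRU cache; mirrors the docstring above
        def decorator(f):
            cached_func = functools.lru_cache(maxsize=maxsize, typed=typed)(f)

            def wrapper(*args, **kwargs):
                result = cached_func(*args, **kwargs)
                # hand back a deep copy so callers cannot mutate the cached object
                return deepcopy(result) if copy else result

            return wrapper

        return decorator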
torch_harmonics/distributed/distributed_resample.py  (view file @ e4879676)
@@ -43,6 +43,32 @@ from torch_harmonics.distributed import compute_split_shapes
class DistributedResampleS2(nn.Module):
    r"""
    Distributed resampling module for spherical data on the 2-sphere.

    This module performs distributed resampling of spherical data across multiple processes,
    supporting both upscaling and downscaling operations. The data is distributed across
    the polar and azimuthal directions, and the module handles the necessary communication
    and interpolation operations.

    Parameters
    ----------
    nlat_in : int
        Number of input latitude points
    nlon_in : int
        Number of input longitude points
    nlat_out : int
        Number of output latitude points
    nlon_out : int
        Number of output longitude points
    grid_in : str, optional
        Input grid type, by default "equiangular"
    grid_out : str, optional
        Output grid type, by default "equiangular"
    mode : str, optional
        Interpolation mode ("bilinear" or "bilinear-spherical"), by default "bilinear"
    """

    def __init__(self, nlat_in: int,
    ...
@@ -133,6 +159,19 @@ class DistributedResampleS2(nn.Module):
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}"

    def _upscale_longitudes(self, x: torch.Tensor):
        """
        Upscale the longitude dimension using interpolation.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with shape (..., nlat, nlon)

        Returns
        -------
        torch.Tensor
            Tensor upscaled in the longitude dimension
        """
        # do the interpolation
        lwgt = self.lon_weights.to(x.dtype)
        if self.mode == "bilinear":
            ...
@@ -147,6 +186,19 @@ class DistributedResampleS2(nn.Module):
        return x

    def _expand_poles(self, x: torch.Tensor):
        """
        Expand the data to include pole values for interpolation.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with shape (..., nlat, nlon)

        Returns
        -------
        torch.Tensor
            Tensor with expanded pole values
        """
        x_north = x[..., 0, :].sum(dim=-1, keepdims=True)
        x_south = x[..., -1, :].sum(dim=-1, keepdims=True)
        x_count = torch.tensor([x.shape[-1]], dtype=torch.long, device=x.device, requires_grad=False)
        ...
@@ -169,6 +221,19 @@ class DistributedResampleS2(nn.Module):
        return x

    def _upscale_latitudes(self, x: torch.Tensor):
        """
        Upscale the latitude dimension using interpolation.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with shape (..., nlat, nlon)

        Returns
        -------
        torch.Tensor
            Tensor upscaled in the latitude dimension
        """
        # do the interpolation
        lwgt = self.lat_weights.to(x.dtype)
        if self.mode == "bilinear":
            ...
@@ -183,6 +248,19 @@ class DistributedResampleS2(nn.Module):
        return x

    def forward(self, x: torch.Tensor):
        """
        Forward pass for distributed resampling.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with shape (batch, channels, nlat, nlon)

        Returns
        -------
        torch.Tensor
            Resampled tensor with shape (batch, channels, nlat_out, nlon_out)
        """
        if self.skip_resampling:
            return x
        ...
torch_harmonics/distributed/primitives.py  (view file @ e4879676)
@@ -95,10 +95,23 @@ def _transpose(tensor, dim0, dim1, dim1_split_sizes, group=None, async_op=False)
class distributed_transpose_azimuth(torch.autograd.Function):
    r"""
    Distributed transpose operation for the azimuthal dimension.

    This class provides the forward and backward passes for distributed
    tensor transposition along the azimuthal dimension.
    """

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x, dims, dim1_split_sizes):
        r"""
        Forward pass for the distributed azimuthal transpose.

        Parameters:
            x: input tensor
            dims: dimensions to transpose
            dim1_split_sizes: split sizes for dimension 1
        """
        # WAR for a potential contig check torch bug for channels last contig tensors
        xlist, dim0_split_sizes, _ = _transpose(x, dims[0], dims[1], dim1_split_sizes, group=azimuth_group())
        x = torch.cat(xlist, dim=dims[1])
        ...
@@ -110,6 +123,15 @@ class distributed_transpose_azimuth(torch.autograd.Function):
    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, go):
        r"""
        Backward pass for the distributed azimuthal transpose.

        Parameters:
            go: gradient of the output

        Returns:
            gradient of the input
        """
        dims = ctx.dims
        dim0_split_sizes = ctx.dim0_split_sizes
        # WAR for a potential contig check torch bug for channels last contig tensors
        ...
@@ -120,10 +142,23 @@ class distributed_transpose_azimuth(torch.autograd.Function):
class distributed_transpose_polar(torch.autograd.Function):
    r"""
    Distributed transpose operation for the polar dimension.

    This class provides the forward and backward passes for distributed
    tensor transposition along the polar dimension.
    """

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x, dim, dim1_split_sizes):
        r"""
        Forward pass for the distributed polar transpose.

        Parameters:
            x: input tensor
            dim: dimensions to transpose
            dim1_split_sizes: split sizes for dimension 1
        """
        # WAR for a potential contig check torch bug for channels last contig tensors
        xlist, dim0_split_sizes, _ = _transpose(x, dim[0], dim[1], dim1_split_sizes, group=polar_group())
        x = torch.cat(xlist, dim=dim[1])
        ...
@@ -134,6 +169,15 @@ class distributed_transpose_polar(torch.autograd.Function):
    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, go):
        r"""
        Backward pass for the distributed polar transpose.

        Parameters:
            go: gradient of the output

        Returns:
            gradient of the input
        """
        dim = ctx.dim
        dim0_split_sizes = ctx.dim0_split_sizes
        # WAR for a potential contig check torch bug for channels last contig tensors
        ...
@@ -244,7 +288,11 @@ def _reduce_scatter(input_, dim_, use_fp32=True, group=None):
class _CopyToPolarRegion(torch.autograd.Function):
    r"""
    Copy tensor to the polar region for distributed computation.

    This class provides the forward and backward passes for copying
    tensors to the polar region in distributed settings.
    """

    @staticmethod
    def symbolic(graph, input_):
        ...
@@ -253,11 +301,29 @@ class _CopyToPolarRegion(torch.autograd.Function):
    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_):
        r"""
        Forward pass for copying to the polar region.

        Parameters:
            input_: input tensor

        Returns:
            the input tensor (no-op in the forward pass)
        """
        return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        r"""
        Backward pass for copying to the polar region.

        Parameters:
            grad_output: gradient of the output

        Returns:
            gradient of the input
        """
        if is_distributed_polar():
            return _reduce(grad_output, group=polar_group())
        else:
            ...
@@ -265,7 +331,11 @@ class _CopyToPolarRegion(torch.autograd.Function):
class _CopyToAzimuthRegion(torch.autograd.Function):
    r"""
    Copy tensor to the azimuth region for distributed computation.

    This class provides the forward and backward passes for copying
    tensors to the azimuth region in distributed settings.
    """

    @staticmethod
    def symbolic(graph, input_):
        ...
@@ -274,11 +344,29 @@ class _CopyToAzimuthRegion(torch.autograd.Function):
    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input_):
        r"""
        Forward pass for copying to azimuth region.

        Parameters:
        input_: input tensor

        Returns:
        input tensor (no-op in forward pass)
        """
        return input_

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        r"""
        Backward pass for copying to azimuth region.

        Parameters:
        grad_output: gradient of the output

        Returns:
        gradient of the input
        """
        if is_distributed_azimuth():
            return _reduce(grad_output, group=azimuth_group())
        else:
...
...
@@ -286,7 +374,11 @@ class _CopyToAzimuthRegion(torch.autograd.Function):
class _ScatterToPolarRegion(torch.autograd.Function):
-     """Split the input and keep only the corresponding chunk to the rank."""
+     r"""
+     Scatter tensor to polar region for distributed computation.
+
+     This class provides the forward and backward passes for scattering
+     tensors to the polar region in distributed settings.
+     """

    @staticmethod
    def symbolic(graph, input_, dim_):
...
...
torch_harmonics/examples/losses.py
View file @
e4879676
...
...
@@ -40,6 +40,27 @@ from torch_harmonics.quadrature import _precompute_latitudes
def get_quadrature_weights(nlat: int, nlon: int, grid: str, tile: bool = False, normalized: bool = True) -> torch.Tensor:
    """
    Get quadrature weights for spherical integration.

    Parameters
    -----------
    nlat : int
        Number of latitude points
    nlon : int
        Number of longitude points
    grid : str
        Grid type ("equiangular", "legendre-gauss", "lobatto")
    tile : bool, optional
        Whether to tile weights across longitude dimension, by default False
    normalized : bool, optional
        Whether to normalize weights to sum to 1, by default True

    Returns
    -------
    torch.Tensor
        Quadrature weights tensor
    """
    # area weights
    _, q = _precompute_latitudes(nlat=nlat, grid=grid)
    q = q.reshape(-1, 1) * 2 * torch.pi / nlon
...
...
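As a sanity check on what these weights represent: on the sphere the area element is sin(θ) dθ dφ, so the weight attached to each grid point scales with sin of its colatitude. A self-contained sketch using a crude sin(θ) rule (the library's `_precompute_latitudes` uses more accurate per-grid quadrature rules):

import torch

nlat, nlon = 91, 180
theta = torch.linspace(0, torch.pi, nlat)                     # colatitude
w = torch.sin(theta).reshape(-1, 1) * (2 * torch.pi / nlon)   # crude area weights
w = w.expand(nlat, nlon).clone()                              # tile across longitude
w = w / w.sum()                                               # normalized=True behaviour

f = torch.ones(nlat, nlon)                                    # integrate f == 1
print((w * f).sum())                                          # tensor(1.) up to rounding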
@@ -55,6 +76,27 @@ def get_quadrature_weights(nlat: int, nlon: int, grid: str, tile: bool = False,
class DiceLossS2(nn.Module):
    """
    Dice loss for spherical segmentation tasks.

    Parameters
    -----------
    nlat : int
        Number of latitude points
    nlon : int
        Number of longitude points
    grid : str, optional
        Grid type, by default "equiangular"
    weight : torch.Tensor, optional
        Class weights, by default None
    smooth : float, optional
        Smoothing factor, by default 0
    ignore_index : int, optional
        Index to ignore in loss computation, by default -100
    mode : str, optional
        Aggregation mode ("micro" or "macro"), by default "micro"
    """

    def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, smooth: float = 0, ignore_index: int = -100, mode: str = "micro"):
        super().__init__()
...
...
@@ -73,6 +115,21 @@ class DiceLossS2(nn.Module):
        self.register_buffer("weight", weight.unsqueeze(0))

    def forward(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the Dice loss.

        Parameters
        -----------
        prd : torch.Tensor
            Prediction tensor with shape (batch, classes, nlat, nlon)
        tar : torch.Tensor
            Target tensor with shape (batch, nlat, nlon)

        Returns
        -------
        torch.Tensor
            Dice loss value
        """
        prd = nn.functional.softmax(prd, dim=1)

        # mask values
...
...
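The quadrature-weighted Dice score behind this loss is compact; a minimal micro-averaged sketch assuming one-hot targets and precomputed weights `q` that sum to one (masking of `ignore_index` and class weights omitted):

import torch

def dice_loss_s2(prob, tar_onehot, q, smooth=0.0):
    # prob, tar_onehot: (batch, classes, nlat, nlon); q: (nlat, nlon)
    inter = torch.sum(q * prob * tar_onehot, dim=(-2, -1))    # weighted intersection
    cards = torch.sum(q * (prob + tar_onehot), dim=(-2, -1))  # weighted cardinalities
    dice = (2 * inter.sum(dim=1) + smooth) / (cards.sum(dim=1) + smooth)
    return (1 - dice).mean()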
@@ -113,6 +170,24 @@ class DiceLossS2(nn.Module):
class CrossEntropyLossS2(nn.Module):
    """
    Cross-entropy loss for spherical classification tasks.

    Parameters
    -----------
    nlat : int
        Number of latitude points
    nlon : int
        Number of longitude points
    grid : str, optional
        Grid type, by default "equiangular"
    weight : torch.Tensor, optional
        Class weights, by default None
    smooth : float, optional
        Label smoothing factor, by default 0
    ignore_index : int, optional
        Index to ignore in loss computation, by default -100
    """

    def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, smooth: float = 0, ignore_index: int = -100):
...
...
@@ -130,6 +205,21 @@ class CrossEntropyLossS2(nn.Module):
        self.register_buffer("quad_weights", q)

    def forward(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the cross-entropy loss.

        Parameters
        -----------
        prd : torch.Tensor
            Prediction tensor with shape (batch, classes, nlat, nlon)
        tar : torch.Tensor
            Target tensor with shape (batch, nlat, nlon)

        Returns
        -------
        torch.Tensor
            Cross-entropy loss value
        """
        # compute log softmax
        logits = nn.functional.log_softmax(prd, dim=1)
...
...
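The spherical cross-entropy replaces the usual uniform spatial mean with a quadrature-weighted one. A minimal sketch of that reduction (ignoring `ignore_index`, label smoothing, and class weights):

import torch

def cross_entropy_s2(prd, tar, q):
    # prd: (batch, classes, nlat, nlon) logits; tar: (batch, nlat, nlon) integer labels
    logp = torch.nn.functional.log_softmax(prd, dim=1)
    nll = -logp.gather(1, tar.unsqueeze(1)).squeeze(1)  # (batch, nlat, nlon)
    return torch.sum(q * nll, dim=(-2, -1)).mean()      # quadrature-weighted mean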
@@ -141,6 +231,24 @@ class CrossEntropyLossS2(nn.Module):
class FocalLossS2(nn.Module):
    """
    Focal loss for spherical classification tasks.

    Parameters
    -----------
    nlat : int
        Number of latitude points
    nlon : int
        Number of longitude points
    grid : str, optional
        Grid type, by default "equiangular"
    weight : torch.Tensor, optional
        Class weights, by default None
    smooth : float, optional
        Label smoothing factor, by default 0
    ignore_index : int, optional
        Index to ignore in loss computation, by default -100
    """

    def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, smooth: float = 0, ignore_index: int = -100):
...
...
@@ -158,6 +266,25 @@ class FocalLossS2(nn.Module):
        self.register_buffer("quad_weights", q)

    def forward(self, prd: torch.Tensor, tar: torch.Tensor, alpha: float = 0.25, gamma: float = 2):
        """
        Forward pass of the focal loss.

        Parameters
        -----------
        prd : torch.Tensor
            Prediction tensor with shape (batch, classes, nlat, nlon)
        tar : torch.Tensor
            Target tensor with shape (batch, nlat, nlon)
        alpha : float, optional
            Alpha parameter for focal loss, by default 0.25
        gamma : float, optional
            Gamma parameter for focal loss, by default 2

        Returns
        -------
        torch.Tensor
            Focal loss value
        """
        # compute logits
        logits = nn.functional.log_softmax(prd, dim=1)
...
...
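Focal loss down-weights well-classified points by the modulating factor (1 − p)^γ before the same quadrature-weighted reduction. A minimal sketch with the same simplifications as above:

import torch

def focal_loss_s2(prd, tar, q, alpha=0.25, gamma=2.0):
    logp = torch.nn.functional.log_softmax(prd, dim=1)
    logp_t = logp.gather(1, tar.unsqueeze(1)).squeeze(1)  # log p of the true class
    p_t = logp_t.exp()
    focal = -alpha * (1 - p_t) ** gamma * logp_t          # modulated NLL
    return torch.sum(q * focal, dim=(-2, -1)).mean()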
@@ -232,22 +359,101 @@ class SphericalLossBase(nn.Module, ABC):
class SquaredL2LossS2(SphericalLossBase):
    """
    Squared L2 loss for spherical regression tasks.

    Computes the squared difference between prediction and target tensors.
    """

    def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
        """
        Compute squared L2 loss term.

        Parameters
        -----------
        prd : torch.Tensor
            Prediction tensor
        tar : torch.Tensor
            Target tensor

        Returns
        -------
        torch.Tensor
            Squared difference between prediction and target
        """
        return torch.square(prd - tar)


class L1LossS2(SphericalLossBase):
    """
    L1 loss for spherical regression tasks.

    Computes the absolute difference between prediction and target tensors.
    """

    def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
        """
        Compute L1 loss term.

        Parameters
        -----------
        prd : torch.Tensor
            Prediction tensor
        tar : torch.Tensor
            Target tensor

        Returns
        -------
        torch.Tensor
            Absolute difference between prediction and target
        """
        return torch.abs(prd - tar)


class L2LossS2(SquaredL2LossS2):
    """
    L2 loss for spherical regression tasks.

    Computes the square root of the squared L2 loss.
    """

    def _post_integration_hook(self, loss: torch.Tensor) -> torch.Tensor:
        """
        Apply square root to get L2 norm.

        Parameters
        -----------
        loss : torch.Tensor
            Integrated squared loss

        Returns
        -------
        torch.Tensor
            Square root of the loss (L2 norm)
        """
        return torch.sqrt(loss)


class W11LossS2(SphericalLossBase):
    """
    W11 loss for spherical regression tasks.

    Computes the L1 norm of the gradient differences between prediction and target.
    """

    def __init__(self, nlat: int, nlon: int, grid: str = "equiangular"):
        """
        Initialize W11 loss.

        Parameters
        -----------
        nlat : int
            Number of latitude points
        nlon : int
            Number of longitude points
        grid : str, optional
            Grid type, by default "equiangular"
        """
        super().__init__(nlat=nlat, nlon=nlon, grid=grid)

        # Set up grid and domain for FFT
        l_phi = 2 * torch.pi  # domain size
...
...
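The relationship between these classes is simple: each subclass defines a pointwise term, the base class integrates it with quadrature weights, and `L2LossS2` applies a square root after integration. A condensed sketch of the pipeline:

import torch

def spherical_l2(prd, tar, q):
    pointwise = torch.square(prd - tar)                # SquaredL2LossS2._compute_loss_term
    integral = torch.sum(q * pointwise, dim=(-2, -1))  # quadrature integration
    return torch.sqrt(integral).mean()                 # L2LossS2._post_integration_hook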
@@ -305,31 +511,70 @@ class NormalLossS2(SphericalLossBase):
        self.register_buffer("k_theta_mesh", k_theta_mesh)

    def compute_gradients(self, x):
        """
        Compute gradients of the input tensor using FFT.

        Parameters
        -----------
        x : torch.Tensor
            Input tensor with shape (batch, nlat, nlon) or (nlat, nlon)

        Returns
        -------
        tuple
            Tuple of (grad_phi, grad_theta) gradients
        """
        # Make sure x is reshaped to have a batch dimension if it's missing
        if x.dim() == 2:
            x = x.unsqueeze(0)  # Add batch dimension

-         x_prime_fft2_phi_h = torch.fft.ifft2(1j * self.k_phi_mesh * torch.fft.fft2(x)).real
-         x_prime_fft2_theta_h = torch.fft.ifft2(1j * self.k_theta_mesh * torch.fft.fft2(x)).real
-         return x_prime_fft2_theta_h, x_prime_fft2_phi_h
+         # Compute gradients using FFT
+         grad_phi = torch.fft.ifft2(1j * self.k_phi_mesh * torch.fft.fft2(x)).real
+         grad_theta = torch.fft.ifft2(1j * self.k_theta_mesh * torch.fft.fft2(x)).real
+         return grad_phi, grad_theta
    def compute_normals(self, x):
-         x = x.to(torch.float32)
-         # Ensure x has a batch dimension
-         if x.dim() == 2:
-             x = x.unsqueeze(0)
+         """
+         Compute surface normals from the input tensor.
+
+         Parameters
+         -----------
+         x : torch.Tensor
+             Input tensor with shape (batch, nlat, nlon) or (nlat, nlon)
+
+         Returns
+         -------
+         torch.Tensor
+             Normal vectors with shape (batch, 3, nlat, nlon)
+         """
-         grad_lat, grad_lon = self.compute_gradients(x)
-         # Create 3D normal vectors
-         ones = torch.ones_like(x)
-         normals = torch.stack([-grad_lon, -grad_lat, ones], dim=1)
-         # Normalize
-         norm = torch.norm(normals, dim=1, keepdim=True)
-         normals = normals / (norm + 1e-8)
+         grad_phi, grad_theta = self.compute_gradients(x)
+         # Construct normal vectors: (-grad_theta, -grad_phi, 1)
+         normals = torch.stack([-grad_theta, -grad_phi, torch.ones_like(x)], dim=1)
+         # Normalize along component dimension
+         normals = F.normalize(normals, p=2, dim=1)
        return normals
    def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
        """
        Compute combined L1 and normal consistency loss.

        Parameters
        -----------
        prd : torch.Tensor
            Prediction tensor
        tar : torch.Tensor
            Target tensor

        Returns
        -------
        torch.Tensor
            Combined loss term
        """
        # Handle dimensions for both prediction and target
        # Ensure we have at least a batch dimension
        if prd.dim() == 2:
...
...
@@ -337,15 +582,18 @@ class NormalLossS2(SphericalLossBase):
        if tar.dim() == 2:
            tar = tar.unsqueeze(0)
        # For 4D tensors (batch, channel, height, width), remove channel if it's 1
        if prd.dim() == 4 and prd.size(1) == 1:
            prd = prd.squeeze(1)
        if tar.dim() == 4 and tar.size(1) == 1:
            tar = tar.squeeze(1)

        # L1 loss term
        l1_loss = torch.abs(prd - tar)

-         pred_normals = self.compute_normals(prd)
+         # Normal consistency loss
+         prd_normals = self.compute_normals(prd)
        tar_normals = self.compute_normals(tar)
-         # Compute cosine similarity
-         normal_loss = 1 - torch.sum(pred_normals * tar_normals, dim=1, keepdim=True)
-         return normal_loss
+         # Cosine similarity between normals
+         cos_sim = torch.sum(prd_normals * tar_normals, dim=1)
+         normal_loss = 1 - cos_sim
+         # Combine losses (equal weighting)
+         combined_loss = l1_loss + normal_loss.unsqueeze(1)
+         return combined_loss
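The normal-consistency term treats the field as a height map: spectral differentiation yields the two tangential gradients, the normal is (−∂θ, −∂φ, 1) normalized, and the penalty is one minus the cosine similarity between predicted and target normals. A minimal sketch of that chain, assuming precomputed wavenumber meshes as in the class above (FFT derivatives on a periodic domain):

import torch
import torch.nn.functional as F

def normals_from_heightmap(x, k_theta, k_phi):
    # x: (batch, nlat, nlon); k_*: broadcastable wavenumber meshes
    g_phi = torch.fft.ifft2(1j * k_phi * torch.fft.fft2(x)).real
    g_theta = torch.fft.ifft2(1j * k_theta * torch.fft.fft2(x)).real
    n = torch.stack([-g_theta, -g_phi, torch.ones_like(x)], dim=1)
    return F.normalize(n, p=2, dim=1)

def normal_consistency(prd, tar, k_theta, k_phi):
    cos = torch.sum(normals_from_heightmap(prd, k_theta, k_phi)
                    * normals_from_heightmap(tar, k_theta, k_phi), dim=1)
    return 1 - cos  # pointwise term, integrated by the base class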
torch_harmonics/examples/metrics.py
View file @
e4879676
...
...
@@ -49,6 +49,31 @@ def _get_stats_multiclass(
    quad_weights: torch.Tensor,
    ignore_index: Optional[int],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute multiclass statistics (TP, FP, FN, TN) on the sphere using quadrature weights.

    This function computes true positives, false positives, false negatives, and true negatives
    for multiclass classification on spherical data, properly weighted by quadrature weights
    to account for the spherical geometry.

    Parameters
    -----------
    output : torch.LongTensor
        Predicted class labels
    target : torch.LongTensor
        Ground truth class labels
    num_classes : int
        Number of classes in the classification task
    quad_weights : torch.Tensor
        Quadrature weights for spherical integration
    ignore_index : Optional[int]
        Index to ignore in the computation (e.g., for padding or invalid regions)

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
        Tuple containing (tp_count, fp_count, fn_count, tn_count) for each class
    """
    batch_size, *dims = output.shape
    num_elements = torch.prod(torch.tensor(dims)).long()
...
...
@@ -88,10 +113,46 @@ def _get_stats_multiclass(
def _predict_classes(logits: torch.Tensor) -> torch.Tensor:
    """
    Convert logits to class predictions using softmax and argmax.

    Parameters
    -----------
    logits : torch.Tensor
        Input logits tensor

    Returns
    -------
    torch.Tensor
        Predicted class labels
    """
    return torch.argmax(torch.softmax(logits, dim=1), dim=1, keepdim=False)


class BaseMetricS2(nn.Module):
    """
    Base class for spherical metrics that properly handle spherical geometry.

    This class provides the foundation for computing metrics on spherical data
    by using quadrature weights to account for the non-uniform area distribution
    on the sphere.

    Parameters
    -----------
    nlat : int
        Number of latitude points
    nlon : int
        Number of longitude points
    grid : str, optional
        Grid type ("equiangular", "legendre-gauss", etc.), by default "equiangular"
    weight : torch.Tensor, optional
        Class weights for weighted averaging, by default None
    ignore_index : int, optional
        Index to ignore in computations, by default -100
    mode : str, optional
        Averaging mode ("micro" or "macro"), by default "micro"
    """

    def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, ignore_index: int = -100, mode: str = "micro"):
        super().__init__()
...
...
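One small observation on `_predict_classes`: softmax preserves the ordering of the logits along the class dimension, so the softmax before the argmax is mathematically redundant and presumably kept for clarity. The equivalence:

import torch

logits = torch.randn(2, 5, 8, 16)  # (batch, classes, nlat, nlon)
a = torch.argmax(torch.softmax(logits, dim=1), dim=1)
b = torch.argmax(logits, dim=1)
assert torch.equal(a, b)  # identical predictions either way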
@@ -108,6 +169,21 @@ class BaseMetricS2(nn.Module):
        self.register_buffer("weight", weight.unsqueeze(0))

    def _forward(self, pred: torch.Tensor, truth: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute base statistics (TP, FP, FN, TN) for the given predictions and ground truth.

        Parameters
        -----------
        pred : torch.Tensor
            Predicted logits
        truth : torch.Tensor
            Ground truth labels

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
            Tuple containing (tp, fp, fn, tn) statistics
        """
        # convert logits to class predictions
        pred_class = _predict_classes(pred)
...
...
@@ -138,11 +214,47 @@ class BaseMetricS2(nn.Module):
class IntersectionOverUnionS2(BaseMetricS2):
    """
    Intersection over Union (IoU) metric for spherical data.

    Computes the IoU score for multiclass classification on the sphere,
    properly weighted by quadrature weights to account for spherical geometry.

    Parameters
    -----------
    nlat : int
        Number of latitude points
    nlon : int
        Number of longitude points
    grid : str, optional
        Grid type ("equiangular", "legendre-gauss", etc.), by default "equiangular"
    weight : torch.Tensor, optional
        Class weights for weighted averaging, by default None
    ignore_index : int, optional
        Index to ignore in computations, by default -100
    mode : str, optional
        Averaging mode ("micro" or "macro"), by default "micro"
    """

    def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, ignore_index: int = -100, mode: str = "micro"):
        super().__init__(nlat, nlon, grid, weight, ignore_index, mode)

    def forward(self, pred: torch.Tensor, truth: torch.Tensor) -> torch.Tensor:
        """
        Compute IoU score for the given predictions and ground truth.

        Parameters
        -----------
        pred : torch.Tensor
            Predicted logits
        truth : torch.Tensor
            Ground truth labels

        Returns
        -------
        torch.Tensor
            IoU score
        """
        tp, fp, fn, tn = self._forward(pred, truth)

        # compute score
...
...
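Given the weighted statistics, the score itself is one line. A sketch of the micro-averaged case, where the statistics are summed over classes before the ratio (an assumption about the library's "micro" behaviour, consistent with standard usage):

import torch

def iou_micro(tp, fp, fn, eps=1e-8):
    # tp, fp, fn: per-class weighted counts, shape (batch, num_classes)
    tp, fp, fn = tp.sum(dim=1), fp.sum(dim=1), fn.sum(dim=1)
    return (tp / (tp + fp + fn + eps)).mean()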
@@ -162,11 +274,47 @@ class IntersectionOverUnionS2(BaseMetricS2):
class AccuracyS2(BaseMetricS2):
    """
    Accuracy metric for spherical data.

    Computes the accuracy score for multiclass classification on the sphere,
    properly weighted by quadrature weights to account for spherical geometry.

    Parameters
    -----------
    nlat : int
        Number of latitude points
    nlon : int
        Number of longitude points
    grid : str, optional
        Grid type ("equiangular", "legendre-gauss", etc.), by default "equiangular"
    weight : torch.Tensor, optional
        Class weights for weighted averaging, by default None
    ignore_index : int, optional
        Index to ignore in computations, by default -100
    mode : str, optional
        Averaging mode ("micro" or "macro"), by default "micro"
    """

    def __init__(self, nlat: int, nlon: int, grid: str = "equiangular", weight: torch.Tensor = None, ignore_index: int = -100, mode: str = "micro"):
        super().__init__(nlat, nlon, grid, weight, ignore_index, mode)

    def forward(self, pred: torch.Tensor, truth: torch.Tensor) -> torch.Tensor:
        """
        Compute accuracy score for the given predictions and ground truth.

        Parameters
        -----------
        pred : torch.Tensor
            Predicted logits
        truth : torch.Tensor
            Ground truth labels

        Returns
        -------
        torch.Tensor
            Accuracy score
        """
        tp, fp, fn, tn = self._forward(pred, truth)

        # compute score
...
...
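Accuracy follows the same template with the complementary counts included; micro and macro averaging differ only in whether the reduction over classes happens before or after the ratio. A sketch under the same assumptions as above:

import torch

def accuracy_s2(tp, fp, fn, tn, mode="micro", eps=1e-8):
    if mode == "micro":  # sum statistics over classes first
        tp, fp, fn, tn = (t.sum(dim=1) for t in (tp, fp, fn, tn))
    score = (tp + tn) / (tp + fp + fn + tn + eps)
    return score.mean()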
torch_harmonics/examples/models/_layers.py
View file @
e4879676
This diff is collapsed.
torch_harmonics/examples/models/lsno.py
View file @
e4879676
...
...
@@ -50,6 +50,35 @@ def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
    return (kernel_shape[0] + 1) * theta_cutoff_factor[basis_type] * math.pi / float(nlat - 1)


class DiscreteContinuousEncoder(nn.Module):
    r"""
    Discrete-continuous encoder for spherical neural operators.

    This module performs downsampling using discrete-continuous convolutions on the sphere,
    reducing the spatial resolution while maintaining the spectral properties of the data.

    Parameters
    -----------
    in_shape : tuple, optional
        Input shape (nlat, nlon), by default (721, 1440)
    out_shape : tuple, optional
        Output shape (nlat, nlon), by default (480, 960)
    grid_in : str, optional
        Input grid type, by default "equiangular"
    grid_out : str, optional
        Output grid type, by default "equiangular"
    inp_chans : int, optional
        Number of input channels, by default 2
    out_chans : int, optional
        Number of output channels, by default 2
    kernel_shape : tuple, optional
        Kernel shape for convolution, by default (3, 3)
    basis_type : str, optional
        Filter basis type, by default "morlet"
    groups : int, optional
        Number of groups for grouped convolution, by default 1
    bias : bool, optional
        Whether to use bias, by default False
    """

    def __init__(self, in_shape=(721, 1440),
...
...
@@ -81,6 +110,19 @@ class DiscreteContinuousEncoder(nn.Module):
        )

    def forward(self, x):
        """
        Forward pass of the discrete-continuous encoder.

        Parameters
        -----------
        x : torch.Tensor
            Input tensor with shape (batch, channels, nlat, nlon)

        Returns
        -------
        torch.Tensor
            Encoded tensor with reduced spatial resolution
        """
        dtype = x.dtype
        with amp.autocast(device_type="cuda", enabled=False):
...
...
@@ -92,6 +134,37 @@ class DiscreteContinuousEncoder(nn.Module):
class DiscreteContinuousDecoder(nn.Module):
    r"""
    Discrete-continuous decoder for spherical neural operators.

    This module performs upsampling using either spherical harmonic transforms or resampling,
    followed by discrete-continuous convolutions to restore spatial resolution.

    Parameters
    -----------
    in_shape : tuple, optional
        Input shape (nlat, nlon), by default (480, 960)
    out_shape : tuple, optional
        Output shape (nlat, nlon), by default (721, 1440)
    grid_in : str, optional
        Input grid type, by default "equiangular"
    grid_out : str, optional
        Output grid type, by default "equiangular"
    inp_chans : int, optional
        Number of input channels, by default 2
    out_chans : int, optional
        Number of output channels, by default 2
    kernel_shape : tuple, optional
        Kernel shape for convolution, by default (3, 3)
    basis_type : str, optional
        Filter basis type, by default "morlet"
    groups : int, optional
        Number of groups for grouped convolution, by default 1
    bias : bool, optional
        Whether to use bias, by default False
    upsample_sht : bool, optional
        Whether to use SHT for upsampling, by default False
    """

    def __init__(self, in_shape=(480, 960),
...
...
@@ -132,6 +205,19 @@ class DiscreteContinuousDecoder(nn.Module):
        )

    def forward(self, x):
        """
        Forward pass of the discrete-continuous decoder.

        Parameters
        -----------
        x : torch.Tensor
            Input tensor with shape (batch, channels, nlat, nlon)

        Returns
        -------
        torch.Tensor
            Decoded tensor with restored spatial resolution
        """
        dtype = x.dtype
        with amp.autocast(device_type="cuda", enabled=False):
...
...
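Both encoder and decoder start their forward pass by saving the incoming dtype and entering an autocast-disabled region, which suggests the quadrature-based convolutions run in full precision and the result is cast back afterwards (the restore is outside the visible hunk). The general pattern, sketched:

import torch
from torch import amp

def run_in_fp32(module, x):
    dtype = x.dtype                  # remember the autocast dtype
    with amp.autocast(device_type="cuda", enabled=False):
        x = module(x.float())        # precision-sensitive op in fp32
    return x.to(dtype)               # hand the original dtype back to the caller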
@@ -274,7 +360,7 @@ class SphericalNeuralOperatorBlock(nn.Module):
class LocalSphericalNeuralOperator(nn.Module):
-     """
+     r"""
    LocalSphericalNeuralOperator module. A spherical neural operator which uses both local and global integral
    operators to accurately model both types of solution operators [1]. The architecture is based on the Spherical
    Fourier Neural Operator [2] and improves upon it with local integral operators in both the Neural Operator blocks,
...
...
@@ -282,43 +368,48 @@ class LocalSphericalNeuralOperator(nn.Module):
    Parameters
    -----------
-     img_shape : tuple, optional
-         Shape of the input channels, by default (128, 256)
-     kernel_shape: tuple, int
+     img_size : tuple, optional
+         Input image size (nlat, nlon), by default (128, 256)
+     grid : str, optional
+         Grid type for input/output, by default "equiangular"
+     grid_internal : str, optional
+         Grid type for internal processing, by default "legendre-gauss"
    scale_factor : int, optional
-         Scale factor to use, by default 3
+         Scale factor for resolution changes, by default 3
    in_chans : int, optional
        Number of input channels, by default 3
    out_chans : int, optional
        Number of output channels, by default 3
    embed_dim : int, optional
-         Dimension of the embeddings, by default 256
+         Embedding dimension, by default 256
    num_layers : int, optional
-         Number of layers in the network, by default 4
+         Number of layers, by default 4
    activation_function : str, optional
-         Activation function to use, by default "gelu"
-     encoder_kernel_shape : int, optional
-         size of the encoder kernel
-     filter_basis_type: Optional[str]: str, optional
-         filter basis type
-     use_mlp : int, optional
-         Whether to use MLPs in the SFNO blocks, by default True
-     mlp_ratio : int, optional
-         Ratio of MLP to use, by default 2.0
+         Activation function name, by default "gelu"
+     kernel_shape : tuple, optional
+         Kernel shape for convolutions, by default (3, 3)
+     encoder_kernel_shape : tuple, optional
+         Kernel shape for encoder, by default (3, 3)
+     filter_basis_type : str, optional
+         Filter basis type, by default "morlet"
+     use_mlp : bool, optional
+         Whether to use MLP layers, by default True
+     mlp_ratio : float, optional
+         MLP expansion ratio, by default 2.0
    drop_rate : float, optional
        Dropout rate, by default 0.0
    drop_path_rate : float, optional
-         Dropout path rate, by default 0.0
+         Drop path rate, by default 0.0
    normalization_layer : str, optional
        Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"
    sfno_block_frequency : int, optional
-         How often a (global) SFNO block is used, by default 2
+         Frequency of SFNO blocks, by default 2
    hard_thresholding_fraction : float, optional
-         Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0
-     big_skip : bool, optional
-         Whether to add a single large skip connection, by default True
-     pos_embed : bool, optional
-         Whether to use positional embedding, by default True
+         Hard thresholding fraction, by default 1.0
+     residual_prediction : bool, optional
+         Whether to use residual prediction, by default False
+     pos_embed : str, optional
+         Position embedding type, by default "none"
    upsample_sht : bool, optional
        Use SHT upsampling if true, else linear interpolation
    bias : bool, optional
...
...
@@ -497,6 +588,19 @@ class LocalSphericalNeuralOperator(nn.Module):
        return x

    def forward(self, x):
        """
        Forward pass through the complete LSNO model.

        Parameters
        -----------
        x : torch.Tensor
            Input tensor of shape (batch_size, in_chans, height, width)

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, out_chans, height, width)
        """
        if self.residual_prediction:
            residual = x
...
...
torch_harmonics/examples/models/s2segformer.py
View file @
e4879676
...
...
@@ -54,6 +54,34 @@ def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
class OverlapPatchMerging(nn.Module):
    """
    Overlap patch merging module for spherical segformer.

    This module performs patch merging with overlapping patches using discrete-continuous
    convolutions on the sphere, followed by layer normalization.

    Parameters
    -----------
    in_shape : tuple, optional
        Input shape (nlat, nlon), by default (721, 1440)
    out_shape : tuple, optional
        Output shape (nlat, nlon), by default (481, 960)
    grid_in : str, optional
        Input grid type, by default "equiangular"
    grid_out : str, optional
        Output grid type, by default "equiangular"
    in_channels : int, optional
        Number of input channels, by default 3
    out_channels : int, optional
        Number of output channels, by default 64
    kernel_shape : tuple, optional
        Kernel shape for convolution, by default (3, 3)
    basis_type : str, optional
        Filter basis type, by default "morlet"
    bias : bool, optional
        Whether to use bias, by default False
    """

    def __init__(self, in_shape=(721, 1440),
...
...
@@ -89,11 +117,32 @@ class OverlapPatchMerging(nn.Module):
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """
        Initialize weights for the module.

        Parameters
        -----------
        m : nn.Module
            Module to initialize
        """
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """
        Forward pass of the overlap patch merging module.

        Parameters
        -----------
        x : torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Merged patches with layer normalization
        """
        dtype = x.dtype
        with amp.autocast(device_type="cuda", enabled=False):
...
...
@@ -109,6 +158,38 @@ class OverlapPatchMerging(nn.Module):
class MixFFN(nn.Module):
    """
    Mix FFN module for spherical segformer.

    This module implements a feed-forward network that combines MLP operations
    with discrete-continuous convolutions on the sphere.

    Parameters
    -----------
    shape : tuple
        Shape (nlat, nlon) of the input
    inout_channels : int
        Number of input/output channels
    hidden_channels : int
        Number of hidden channels in MLP
    mlp_bias : bool, optional
        Whether to use bias in MLP, by default True
    grid : str, optional
        Grid type, by default "equiangular"
    kernel_shape : tuple, optional
        Kernel shape for convolution, by default (3, 3)
    basis_type : str, optional
        Filter basis type, by default "morlet"
    conv_bias : bool, optional
        Whether to use bias in convolution, by default False
    activation : nn.Module, optional
        Activation function, by default nn.GELU
    use_mlp : bool, optional
        Whether to use MLP instead of linear layers, by default False
    drop_path : float, optional
        Drop path rate, by default 0.0
    """

    def __init__(self, shape,
...
...
@@ -161,6 +242,14 @@ class MixFFN(nn.Module):
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """
        Initialize weights for the module.

        Parameters
        -----------
        m : nn.Module
            Module to initialize
        """
        if isinstance(m, nn.Conv2d):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
...
...
@@ -170,7 +259,19 @@ class MixFFN(nn.Module):
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the Mix FFN module.

        Parameters
        -----------
        x : torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Output tensor after Mix FFN processing
        """
        residual = x

        # norm
...
...
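Structurally, MixFFN is a residual feed-forward block: the input is saved, normalized, passed through the hidden path, and added back. A skeleton of that control flow under assumed submodule names (the spherical convolution details inside the hidden path are elided):

import torch
import torch.nn as nn

class PreNormFFN(nn.Module):
    def __init__(self, norm, ffn, drop_path=None):
        super().__init__()
        self.norm, self.ffn = norm, ffn
        self.drop_path = drop_path if drop_path is not None else nn.Identity()

    def forward(self, x):
        residual = x
        x = self.norm(x)                      # normalize first
        x = self.ffn(x)                       # MLP + conv mixing
        return residual + self.drop_path(x)   # residual connection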
@@ -194,6 +295,35 @@ class MixFFN(nn.Module):
class AttentionWrapper(nn.Module):
    """
    Attention wrapper for spherical segformer.

    This module wraps attention mechanisms (neighborhood or global) with optional
    normalization and drop path regularization.

    Parameters
    -----------
    channels : int
        Number of channels
    shape : tuple
        Shape (nlat, nlon) of the input
    grid : str
        Grid type
    heads : int
        Number of attention heads
    pre_norm : bool, optional
        Whether to apply normalization before attention, by default False
    attention_drop_rate : float, optional
        Dropout rate for attention, by default 0.0
    drop_path : float, optional
        Drop path rate, by default 0.0
    attention_mode : str, optional
        Attention mode ("neighborhood" or "global"), by default "neighborhood"
    theta_cutoff : float, optional
        Cutoff radius for neighborhood attention, by default None
    bias : bool, optional
        Whether to use bias, by default True
    """

    def __init__(self, channels,
...
...
@@ -252,7 +382,19 @@ class AttentionWrapper(nn.Module):
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the attention wrapper.

        Parameters
        -----------
        x : torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Output tensor after attention processing
        """
        residual = x

        if self.norm is not None:
            x = x.permute(0, 2, 3, 1)
...
...
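The permute around the norm exists because `nn.LayerNorm` normalizes over the trailing dimensions, while the feature maps are stored channels-first (NCHW); the channel axis has to be moved last and then back. In isolation:

import torch
import torch.nn as nn

x = torch.randn(2, 64, 60, 120)         # (batch, channels, nlat, nlon)
norm = nn.LayerNorm(64)
y = norm(x.permute(0, 2, 3, 1))         # channels last for LayerNorm
y = y.permute(0, 3, 1, 2).contiguous()  # back to channels first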
@@ -271,6 +413,49 @@ class AttentionWrapper(nn.Module):
class TransformerBlock(nn.Module):
    """
    Transformer block for spherical segformer.

    This block combines patch merging, attention, and Mix FFN operations
    in a hierarchical structure for processing spherical data.

    Parameters
    -----------
    in_shape : tuple
        Input shape (nlat, nlon)
    out_shape : tuple
        Output shape (nlat, nlon)
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    mlp_hidden_channels : int
        Number of hidden channels in MLP
    grid_in : str, optional
        Input grid type, by default "equiangular"
    grid_out : str, optional
        Output grid type, by default "equiangular"
    nrep : int, optional
        Number of repetitions, by default 1
    heads : int, optional
        Number of attention heads, by default 1
    kernel_shape : tuple, optional
        Kernel shape for convolution, by default (3, 3)
    basis_type : str, optional
        Filter basis type, by default "morlet"
    activation : nn.Module, optional
        Activation function, by default nn.GELU
    att_drop_rate : float, optional
        Dropout rate for attention, by default 0.0
    drop_path_rates : float, optional
        Drop path rates, by default 0.0
    attention_mode : str, optional
        Attention mode ("neighborhood" or "global"), by default "neighborhood"
    theta_cutoff : float, optional
        Cutoff radius for neighborhood attention, by default None
    bias : bool, optional
        Whether to use bias, by default True
    """

    def __init__(self, in_shape,
...
...
@@ -363,6 +548,19 @@ class TransformerBlock(nn.Module):
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the transformer block.

        Parameters
        -----------
        x : torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Output tensor after transformer block processing
        """
        x = self.fwd(x)

        # apply norm
...
...
@@ -374,6 +572,43 @@ class TransformerBlock(nn.Module):
class Upsampling(nn.Module):
    """
    Upsampling module for spherical segformer.

    This module performs upsampling using either discrete-continuous transposed convolutions
    or bilinear resampling on spherical data.

    Parameters
    -----------
    in_shape : tuple
        Input shape (nlat, nlon)
    out_shape : tuple
        Output shape (nlat, nlon)
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    hidden_channels : int
        Number of hidden channels in MLP
    mlp_bias : bool, optional
        Whether to use bias in MLP, by default True
    grid_in : str, optional
        Input grid type, by default "equiangular"
    grid_out : str, optional
        Output grid type, by default "equiangular"
    kernel_shape : tuple, optional
        Kernel shape for convolution, by default (3, 3)
    basis_type : str, optional
        Filter basis type, by default "morlet"
    conv_bias : bool, optional
        Whether to use bias in convolution, by default False
    activation : nn.Module, optional
        Activation function, by default nn.GELU
    use_mlp : bool, optional
        Whether to use MLP instead of linear layers, by default False
    upsampling_method : str, optional
        Upsampling method ("conv" or "bilinear"), by default "conv"
    """

    def __init__(self, in_shape,
...
...
@@ -429,6 +664,19 @@ class Upsampling(nn.Module):
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the upsampling module.

        Parameters
        -----------
        x : torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Upsampled tensor
        """
        x = self.upsample(self.mlp(x))
        return x
...
...
@@ -606,6 +854,14 @@ class SphericalSegformer(nn.Module):
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """
        Initialize weights for the module.

        Parameters
        -----------
        m : nn.Module
            Module to initialize
        """
        if isinstance(m, nn.Conv2d):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
...
...
@@ -615,7 +871,19 @@ class SphericalSegformer(nn.Module):
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """
        Forward pass through the complete spherical segformer model.

        Parameters
        -----------
        x : torch.Tensor
            Input tensor of shape (batch_size, in_chans, height, width)

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, out_chans, height, width)
        """
        # encoder:
        features = []
        feat = x
...
...
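The forward pass above follows the usual Segformer recipe: each transformer stage produces a progressively coarser feature map, every stage output is collected, and the collected features are upsampled to a common resolution and fused by the head. A schematic of that loop with hypothetical module names (the actual stage and decoder wiring is outside the visible hunk):

import torch

def segformer_forward(blocks, upsamplers, head, x):
    features = []
    feat = x
    for block in blocks:               # hierarchical encoder stages
        feat = block(feat)
        features.append(feat)
    # decoder: bring all stages to a common shape, then fuse
    features = [up(f) for up, f in zip(upsamplers, features)]
    return head(torch.cat(features, dim=1))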
torch_harmonics/examples/models/s2transformer.py
View file @
e4879676
...
...
@@ -52,6 +52,36 @@ def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
    return (kernel_shape[0] + 1) * theta_cutoff_factor[basis_type] * math.pi / float(nlat - 1)


class DiscreteContinuousEncoder(nn.Module):
    """
    Discrete-continuous encoder for spherical transformers.

    This module performs downsampling using discrete-continuous convolutions on the sphere,
    reducing the spatial resolution while maintaining the spectral properties of the data.

    Parameters
    -----------
    in_shape : tuple, optional
        Input shape (nlat, nlon), by default (721, 1440)
    out_shape : tuple, optional
        Output shape (nlat, nlon), by default (480, 960)
    grid_in : str, optional
        Input grid type, by default "equiangular"
    grid_out : str, optional
        Output grid type, by default "equiangular"
    in_chans : int, optional
        Number of input channels, by default 2
    out_chans : int, optional
        Number of output channels, by default 2
    kernel_shape : tuple, optional
        Kernel shape for convolution, by default (3, 3)
    basis_type : str, optional
        Filter basis type, by default "morlet"
    groups : int, optional
        Number of groups for grouped convolution, by default 1
    bias : bool, optional
        Whether to use bias, by default False
    """

    def __init__(self, in_shape=(721, 1440),
...
...
@@ -83,6 +113,19 @@ class DiscreteContinuousEncoder(nn.Module):
        )

    def forward(self, x):
        """
        Forward pass of the discrete-continuous encoder.

        Parameters
        -----------
        x : torch.Tensor
            Input tensor with shape (batch, channels, nlat, nlon)

        Returns
        -------
        torch.Tensor
            Encoded tensor with reduced spatial resolution
        """
        dtype = x.dtype
        with amp.autocast(device_type="cuda", enabled=False):
...
...
@@ -94,6 +137,38 @@ class DiscreteContinuousEncoder(nn.Module):
class DiscreteContinuousDecoder(nn.Module):
    """
    Discrete-continuous decoder for spherical transformers.

    This module performs upsampling using either spherical harmonic transforms or resampling,
    followed by discrete-continuous convolutions to restore spatial resolution.

    Parameters
    -----------
    in_shape : tuple, optional
        Input shape (nlat, nlon), by default (480, 960)
    out_shape : tuple, optional
        Output shape (nlat, nlon), by default (721, 1440)
    grid_in : str, optional
        Input grid type, by default "equiangular"
    grid_out : str, optional
        Output grid type, by default "equiangular"
    in_chans : int, optional
        Number of input channels, by default 2
    out_chans : int, optional
        Number of output channels, by default 2
    kernel_shape : tuple, optional
        Kernel shape for convolution, by default (3, 3)
    basis_type : str, optional
        Filter basis type, by default "morlet"
    groups : int, optional
        Number of groups for grouped convolution, by default 1
    bias : bool, optional
        Whether to use bias, by default False
    upsample_sht : bool, optional
        Whether to use SHT for upsampling, by default False
    """

    def __init__(self, in_shape=(480, 960),
...
...
@@ -134,6 +209,19 @@ class DiscreteContinuousDecoder(nn.Module):
        )

    def forward(self, x):
        """
        Forward pass of the discrete-continuous decoder.

        Parameters
        -----------
        x : torch.Tensor
            Input tensor with shape (batch, channels, nlat, nlon)

        Returns
        -------
        torch.Tensor
            Decoded tensor with restored spatial resolution
        """
        dtype = x.dtype
        with amp.autocast(device_type="cuda", enabled=False):
...
...
@@ -147,7 +235,45 @@ class DiscreteContinuousDecoder(nn.Module):
class SphericalAttentionBlock(nn.Module):
    """
-     Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.
+     Spherical attention block for transformers on the sphere.
+
+     This module implements a single attention block that can use either global attention
+     or neighborhood attention on spherical data, followed by an optional MLP.
+
+     Parameters
+     -----------
+     in_shape : tuple, optional
+         Input shape (nlat, nlon), by default (480, 960)
+     out_shape : tuple, optional
+         Output shape (nlat, nlon), by default (480, 960)
+     grid_in : str, optional
+         Input grid type, by default "equiangular"
+     grid_out : str, optional
+         Output grid type, by default "equiangular"
+     in_chans : int, optional
+         Number of input channels, by default 2
+     out_chans : int, optional
+         Number of output channels, by default 2
+     num_heads : int, optional
+         Number of attention heads, by default 1
+     mlp_ratio : float, optional
+         Ratio of MLP hidden dimension to output dimension, by default 2.0
+     drop_rate : float, optional
+         Dropout rate, by default 0.0
+     drop_path : float, optional
+         Drop path rate, by default 0.0
+     act_layer : nn.Module, optional
+         Activation layer, by default nn.GELU
+     norm_layer : str, optional
+         Normalization layer type, by default "none"
+     use_mlp : bool, optional
+         Whether to use MLP after attention, by default True
+     bias : bool, optional
+         Whether to use bias, by default False
+     attention_mode : str, optional
+         Attention mode ("neighborhood" or "global"), by default "neighborhood"
+     theta_cutoff : float, optional
+         Cutoff radius for neighborhood attention, by default None
    """

    def __init__(
...
...
@@ -467,6 +593,19 @@ class SphericalTransformer(nn.Module):
        return x

    def forward(self, x):
        """
        Forward pass through the complete spherical transformer model.

        Parameters
        -----------
        x : torch.Tensor
            Input tensor of shape (batch_size, in_chans, height, width)

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, out_chans, height, width)
        """
        if self.residual_prediction:
            residual = x
...
...