OpenDAS / torch-harmonics / Commits

Commit b17bfdc4
Authored Jul 16, 2025 by Andrea Paris
Committed by Boris Bonev, Jul 21, 2025

removed docstrings from examples

parent 901e8635
Showing 8 changed files with 18 additions and 526 deletions.
examples/baseline_models/segformer.py        +3   -50
examples/baseline_models/transformer.py      +4   -52
examples/baseline_models/unet.py             +3   -47
examples/depth/train.py                      +2   -101
examples/model_registry.py                   +0   -70
examples/segmentation/train.py               +3   -98
examples/shallow_water_equations/train.py    +2   -95
torch_harmonics/resample.py                  +1   -13
examples/baseline_models/segformer.py
@@ -90,14 +90,6 @@ class OverlapPatchMerging(nn.Module):
         self.apply(self._init_weights)

     def _init_weights(self, m):
-        """
-        Initialize weights for the module.
-
-        Parameters
-        -----------
-        m : torch.nn.Module
-            Module to initialize weights for
-        """
         if isinstance(m, nn.LayerNorm):
             nn.init.constant_(m.bias, 0)
             nn.init.constant_(m.weight, 1.0)
@@ -174,14 +166,6 @@ class MixFFN(nn.Module):
         self.apply(self._init_weights)

     def _init_weights(self, m):
-        """
-        Initialize weights for the module.
-
-        Parameters
-        -----------
-        m : torch.nn.Module
-            Module to initialize weights for
-        """
         if isinstance(m, nn.Conv2d):
             nn.init.trunc_normal_(m.weight, std=0.02)
             if m.bias is not None:
@@ -238,19 +222,7 @@ class GlobalAttention(nn.Module):
         self.attn = nn.MultiheadAttention(embed_dim=chans, num_heads=num_heads, dropout=dropout, batch_first=True, bias=bias)

     def forward(self, x):
-        """
-        Forward pass through the GlobalAttention module.
-
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor of shape (B, C, H, W)
-
-        Returns
-        -------
-        torch.Tensor
-            Output tensor of shape (B, C, H, W)
-        """
         # x: B, C, H, W
         B, H, W, C = x.shape

         # flatten spatial dims
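The removed docstring and the surviving comments describe the GlobalAttention pattern: a channels-last feature map is flattened over its spatial dimensions, passed through batch-first nn.MultiheadAttention, and reshaped back to the grid. A minimal standalone sketch of that pattern follows; dimensions and variable names are illustrative and not taken from the repository.

    import torch
    import torch.nn as nn

    # Standalone sketch: flatten the spatial grid into a token sequence, apply
    # batch-first multi-head self-attention, and restore the grid shape.
    chans, num_heads = 64, 4
    attn = nn.MultiheadAttention(embed_dim=chans, num_heads=num_heads, batch_first=True)

    x = torch.randn(2, 32, 64, chans)     # (B, H, W, C), channels-last as in the hunk
    B, H, W, C = x.shape
    seq = x.reshape(B, H * W, C)          # flatten spatial dims -> (B, H*W, C)
    out, _ = attn(seq, seq, seq)          # self-attention over the H*W tokens
    out = out.reshape(B, H, W, C)         # back to (B, H, W, C)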
@@ -309,32 +281,13 @@ class AttentionWrapper(nn.Module):
         self.apply(self._init_weights)

     def _init_weights(self, m):
-        """
-        Initialize weights for the module.
-
-        Parameters
-        -----------
-        m : torch.nn.Module
-            Module to initialize weights for
-        """
         if isinstance(m, nn.LayerNorm):
             nn.init.constant_(m.bias, 0)
             nn.init.constant_(m.weight, 1.0)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass through the AttentionWrapper.
-
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor
-
-        Returns
-        -------
-        torch.Tensor
-            Output tensor with residual connection
-        """
         residual = x

         x = x.permute(0, 2, 3, 1)

         if self.norm is not None:
examples/baseline_models/transformer.py
@@ -57,19 +57,7 @@ class Encoder(nn.Module):
         self.conv = nn.Conv2d(in_chans, out_chans, kernel_size=kernel_shape, bias=bias, stride=(stride_h, stride_w), padding=(pad_h, pad_w), groups=groups)

     def forward(self, x):
-        """
-        Forward pass through the Encoder layer.
-
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor
-
-        Returns
-        -------
-        torch.Tensor
-            Output tensor after encoding
-        """
         x = self.conv(x)

         return x
@@ -122,19 +110,7 @@ class Decoder(nn.Module):
             raise ValueError(f"Unknown upsampling method {upsampling_method}")

     def forward(self, x):
-        """
-        Forward pass through the Decoder layer.
-
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor
-
-        Returns
-        -------
-        torch.Tensor
-            Output tensor after decoding and upsampling
-        """
         x = self.upsample(x)

         return x
@@ -163,19 +139,7 @@ class GlobalAttention(nn.Module):
         self.attn = nn.MultiheadAttention(embed_dim=chans, num_heads=num_heads, dropout=dropout, batch_first=True, bias=bias)

     def forward(self, x):
-        """
-        Forward pass through the GlobalAttention module.
-
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor of shape (B, C, H, W)
-
-        Returns
-        -------
-        torch.Tensor
-            Output tensor of shape (B, C, H, W)
-        """
         # x: B, C, H, W
         B, H, W, C = x.shape

         # flatten spatial dims
@@ -286,19 +250,7 @@ class AttentionBlock(nn.Module):
             self.skip1 = nn.Identity()

     def forward(self, x):
-        """
-        Forward pass through the AttentionBlock.
-
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor
-
-        Returns
-        -------
-        torch.Tensor
-            Output tensor with residual connections
-        """
         residual = x

         x = self.norm0(x)
examples/baseline_models/unet.py
@@ -185,19 +185,7 @@ class DownsamplingBlock(nn.Module):
                 nn.init.constant_(m.bias, 0)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass through the DownsamplingBlock.
-
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor
-
-        Returns
-        -------
-        torch.Tensor
-            Output tensor after downsampling
-        """
         # skip connection
         residual = x
         if hasattr(self, "transform_skip"):
@@ -370,19 +358,7 @@ class UpsamplingBlock(nn.Module):
                 nn.init.constant_(m.bias, 0)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass through the UpsamplingBlock.
-
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor
-
-        Returns
-        -------
-        torch.Tensor
-            Output tensor after upsampling
-        """
         # skip connection
         residual = x
         if hasattr(self, "transform_skip"):
@@ -545,14 +521,6 @@ class UNet(nn.Module):
         self.apply(self._init_weights)

     def _init_weights(self, m):
-        """
-        Initialize weights for the module.
-
-        Parameters
-        -----------
-        m : torch.nn.Module
-            Module to initialize weights for
-        """
         if isinstance(m, nn.Conv2d):
             nn.init.trunc_normal_(m.weight, std=.02)
             if m.bias is not None:
@@ -563,19 +531,7 @@ class UNet(nn.Module):
     def forward(self, x):
-        """
-        Forward pass through the UNet model.
-
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor of shape (batch_size, in_chans, height, width)
-
-        Returns
-        -------
-        torch.Tensor
-            Output tensor of shape (batch_size, num_classes, height, width)
-        """
         # encoder:
         features = []
         feat = x
examples/depth/train.py
@@ -63,38 +63,11 @@ import wandb
 # helper routine for counting number of paramerters in model
 def count_parameters(model):
-    """
-    Count the number of trainable parameters in a model.
-
-    Parameters
-    ----------
-    model : torch.nn.Module
-        The model to count parameters for
-
-    Returns
-    -------
-    int
-        Total number of trainable parameters
-    """
     return sum(p.numel() for p in model.parameters() if p.requires_grad)

 # convenience function for logging weights and gradients
 def log_weights_and_grads(exp_dir, model, iters=1):
-    """
-    Helper routine intended for debugging purposes.
-    Saves model weights and gradients to a file for analysis.
-
-    Parameters
-    ----------
-    exp_dir : str
-        Experiment directory to save logs in
-    model : torch.nn.Module
-        Model whose weights and gradients to log
-    iters : int, optional
-        Current iteration number, by default 1
-    """
     log_path = os.path.join(exp_dir, "weights_and_grads")
     if not os.path.isdir(log_path):
         os.makedirs(log_path, exist_ok=True)
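As the removed docstring states, count_parameters simply sums p.numel() over trainable parameters. A minimal usage sketch (the toy model below is illustrative only):

    import torch.nn as nn

    def count_parameters(model):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    # Toy model for illustration only.
    model = nn.Linear(16, 8)
    print(f"trainable parameters: {count_parameters(model):,}")   # 16*8 + 8 = 136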
@@ -121,36 +94,7 @@ def validate_model(
     logging=True,
     device=torch.device("cpu"),
 ):
-    """
-    Validate a model on a dataset and compute metrics.
-
-    Parameters
-    -----------
-    model : torch.nn.Module
-        Model to validate
-    dataloader : torch.utils.data.DataLoader
-        DataLoader for validation data
-    loss_fn : callable
-        Loss function
-    metrics_fns : dict
-        Dictionary of metric functions to compute
-    path_root : str
-        Root path for saving validation outputs
-    normalization_in : callable, optional
-        Normalization function to apply to inputs, by default None
-    normalization_out : callable, optional
-        Normalization function to apply to targets, by default None
-    logging : bool, optional
-        Whether to save validation plots, by default True
-    device : torch.device, optional
-        Device to run validation on, by default torch.device("cpu")
-
-    Returns
-    -------
-    tuple
-        (losses, metrics) where losses is a tensor of per-sample losses
-        and metrics is a dict of per-sample metric values
-    """
     model.eval()

     num_examples = len(dataloader)
@@ -243,50 +187,7 @@ def train_model(
     logging=True,
     device=torch.device("cpu"),
 ):
-    """
-    Train a model with the given parameters.
-
-    Parameters
-    -----------
-    model : torch.nn.Module
-        Model to train
-    train_dataloader : torch.utils.data.DataLoader
-        DataLoader for training data
-    train_sampler : torch.utils.data.Sampler
-        Sampler for training data
-    test_dataloader : torch.utils.data.DataLoader
-        DataLoader for test data
-    test_sampler : torch.utils.data.Sampler
-        Sampler for test data
-    loss_fn : callable
-        Loss function
-    metrics_fns : dict
-        Dictionary of metric functions to compute
-    optimizer : torch.optim.Optimizer
-        Optimizer for training
-    gscaler : torch.cuda.amp.GradScaler
-        Gradient scaler for mixed precision training
-    scheduler : torch.optim.lr_scheduler._LRScheduler, optional
-        Learning rate scheduler, by default None
-    normalization_in : callable, optional
-        Normalization function to apply to inputs, by default None
-    normalization_out : callable, optional
-        Normalization function to apply to targets, by default None
-    augmentation : bool, optional
-        Whether to apply data augmentation, by default False
-    nepochs : int, optional
-        Number of training epochs, by default 20
-    amp_mode : str, optional
-        Mixed precision mode ("none", "fp16", "bf16"), by default "none"
-    log_grads : int, optional
-        Frequency of gradient logging (0 for no logging), by default 0
-    exp_dir : str, optional
-        Experiment directory for logging, by default None
-    logging : bool, optional
-        Whether to enable logging, by default True
-    device : torch.device, optional
-        Device to train on, by default torch.device("cpu")
-    """
     train_start = time.time()

     # set AMP type
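The removed train_model docstring documents amp_mode as one of "none", "fp16", or "bf16", and the body continues with "# set AMP type". A hedged sketch of how such a string is commonly mapped onto torch.autocast; this mirrors the docstring's description, not necessarily the exact code in train.py:

    import torch

    # Sketch: map an amp_mode string ("none", "fp16", "bf16") onto torch.autocast.
    def make_autocast(amp_mode: str, device_type: str = "cuda"):
        if amp_mode == "fp16":
            return torch.autocast(device_type=device_type, dtype=torch.float16)
        if amp_mode == "bf16":
            return torch.autocast(device_type=device_type, dtype=torch.bfloat16)
        return torch.autocast(device_type=device_type, enabled=False)

    # typical use inside a training step, with gscaler handling fp16 loss scaling:
    # with make_autocast(amp_mode):
    #     loss = loss_fn(model(inp), tar)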
examples/model_registry.py
@@ -39,77 +39,7 @@ from baseline_models import Transformer, UNet, Segformer
 from torch_harmonics.examples.models import SphericalFourierNeuralOperator, LocalSphericalNeuralOperator, SphericalTransformer, SphericalUNet, SphericalSegformer

 def get_baseline_models(img_size=(128, 256), in_chans=3, out_chans=3, residual_prediction=False, drop_path_rate=0., grid="equiangular"):
-    """
-    Get a registry of baseline models for spherical and planar neural networks.
-
-    This function returns a dictionary containing pre-configured model architectures
-    for various tasks including spherical Fourier neural operators (SFNO), local
-    spherical neural operators (LSNO), spherical transformers, U-Nets, and Segformers.
-    Each model is configured with specific hyperparameters optimized for different
-    computational budgets and performance requirements.
-
-    Parameters
-    ----------
-    img_size : tuple, optional
-        Input image size as (height, width), by default (128, 256)
-    in_chans : int, optional
-        Number of input channels, by default 3
-    out_chans : int, optional
-        Number of output channels, by default 3
-    residual_prediction : bool, optional
-        Whether to use residual prediction (add input to output), by default False
-    drop_path_rate : float, optional
-        Dropout path rate for regularization, by default 0.0
-    grid : str, optional
-        Grid type for spherical models ("equiangular", "legendre-gauss", etc.), by default "equiangular"
-
-    Returns
-    ----------
-    dict
-        Dictionary mapping model names to partial functions that can be called
-        to instantiate the corresponding model with the specified parameters.
-        Available models include:
-
-        **Spherical Models:**
-        - sfno_sc2_layers4_e32: Spherical Fourier Neural Operator (small)
-        - lsno_sc2_layers4_e32: Local Spherical Neural Operator (small)
-        - s2unet_sc2_layers4_e128: Spherical U-Net (medium)
-        - s2transformer_sc2_layers4_e128: Spherical Transformer (global attention, medium)
-        - s2transformer_sc2_layers4_e256: Spherical Transformer (global attention, large)
-        - s2ntransformer_sc2_layers4_e128: Spherical Transformer (neighborhood attention, medium)
-        - s2ntransformer_sc2_layers4_e256: Spherical Transformer (neighborhood attention, large)
-        - s2segformer_sc2_layers4_e128: Spherical Segformer (global attention, medium)
-        - s2segformer_sc2_layers4_e256: Spherical Segformer (global attention, large)
-        - s2nsegformer_sc2_layers4_e128: Spherical Segformer (neighborhood attention, medium)
-        - s2nsegformer_sc2_layers4_e256: Spherical Segformer (neighborhood attention, large)
-
-        **Planar Models:**
-        - transformer_sc2_layers4_e128: Planar Transformer (global attention, medium)
-        - transformer_sc2_layers4_e256: Planar Transformer (global attention, large)
-        - ntransformer_sc2_layers4_e128: Planar Transformer (neighborhood attention, medium)
-        - ntransformer_sc2_layers4_e256: Planar Transformer (neighborhood attention, large)
-        - segformer_sc2_layers4_e128: Planar Segformer (global attention, medium)
-        - segformer_sc2_layers4_e256: Planar Segformer (global attention, large)
-        - nsegformer_sc2_layers4_e128: Planar Segformer (neighborhood attention, medium)
-        - nsegformer_sc2_layers4_e256: Planar Segformer (neighborhood attention, large)
-        - vit_sc2_layers4_e128: Vision Transformer variant (medium)
-
-    Examples
-    ----------
-    >>> model_registry = get_baseline_models(img_size=(64, 128), in_chans=2, out_chans=1)
-    >>> model = model_registry['sfno_sc2_layers4_e32']()
-    >>> print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
-
-    Notes
-    ----------
-    - Model names follow the pattern: {model_type}_{scale_factor}_{layers}_{embed_dim}
-    - 'sc2' indicates scale factor of 2 (downsampling by 2)
-    - 'e32', 'e128', 'e256' indicate embedding dimensions
-    - 'n' prefix indicates neighborhood attention instead of global attention
-    - All models use GELU activation and instance normalization by default
-    """
     # prepare dicts containing models and corresponding metrics
     model_registry = dict(
         sfno_sc2_layers4_e32=partial(
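The removed Examples section showed how the registry is meant to be used. A sketch reproducing that example; the registry key 'sfno_sc2_layers4_e32' and the channel counts come straight from the removed docstring, and the import path assumes the script is run from the examples/ directory where model_registry.py lives:

    # Reproduced from the removed Examples section of the docstring.
    from model_registry import get_baseline_models

    model_registry = get_baseline_models(img_size=(64, 128), in_chans=2, out_chans=1)
    model = model_registry['sfno_sc2_layers4_e32']()
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")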
examples/segmentation/train.py
@@ -68,38 +68,13 @@ import wandb
 # helper routine for counting number of paramerters in model
 def count_parameters(model):
-    """
-    Count the number of trainable parameters in a model.
-
-    Parameters
-    -----------
-    model : torch.nn.Module
-        The model to count parameters for
-
-    Returns
-    -------
-    int
-        Total number of trainable parameters
-    """
     return sum(p.numel() for p in model.parameters() if p.requires_grad)

 # convenience function for logging weights and gradients
 def log_weights_and_grads(exp_dir, model, iters=1):
-    """
-    Helper routine intended for debugging purposes.
-    Saves model weights and gradients to a file for analysis.
-
-    Parameters
-    -----------
-    exp_dir : str
-        Experiment directory to save logs in
-    model : torch.nn.Module
-        Model whose weights and gradients to log
-    iters : int, optional
-        Current iteration number, by default 1
-    """
     log_path = os.path.join(exp_dir, "weights_and_grads")
     if not os.path.isdir(log_path):
         os.makedirs(log_path, exist_ok=True)
@@ -116,34 +91,7 @@ def log_weights_and_grads(exp_dir, model, iters=1):
 # rolls out the FNO and compares to the classical solver
 def validate_model(model, dataloader, loss_fn, metrics_fns, path_root, normalization=None, logging=True, device=torch.device("cpu")):
-    """
-    Validate a model on a dataset and compute metrics.
-
-    Parameters
-    -----------
-    model : torch.nn.Module
-        Model to validate
-    dataloader : torch.utils.data.DataLoader
-        DataLoader for validation data
-    loss_fn : callable
-        Loss function
-    metrics_fns : dict
-        Dictionary of metric functions to compute
-    path_root : str
-        Root path for saving validation outputs
-    normalization : callable, optional
-        Normalization function to apply to inputs, by default None
-    logging : bool, optional
-        Whether to save validation plots, by default True
-    device : torch.device, optional
-        Device to run validation on, by default torch.device("cpu")
-
-    Returns
-    -------
-    tuple
-        (losses, metrics) where losses is a tensor of per-sample losses
-        and metrics is a dict of per-sample metric values
-    """
     model.eval()

     num_examples = len(dataloader)
@@ -229,50 +177,7 @@ def train_model(
     logging=True,
     device=torch.device("cpu"),
 ):
-    """
-    Train a model with the given parameters.
-
-    Parameters
-    -----------
-    model : torch.nn.Module
-        Model to train
-    train_dataloader : torch.utils.data.DataLoader
-        DataLoader for training data
-    train_sampler : torch.utils.data.Sampler
-        Sampler for training data
-    test_dataloader : torch.utils.data.DataLoader
-        DataLoader for test data
-    test_sampler : torch.utils.data.Sampler
-        Sampler for test data
-    loss_fn : callable
-        Loss function
-    metrics_fns : dict
-        Dictionary of metric functions to compute
-    optimizer : torch.optim.Optimizer
-        Optimizer for training
-    gscaler : torch.cuda.amp.GradScaler
-        Gradient scaler for mixed precision training
-    scheduler : torch.optim.lr_scheduler._LRScheduler, optional
-        Learning rate scheduler, by default None
-    max_grad_norm : float, optional
-        Maximum gradient norm for clipping, by default 0.0
-    normalization : callable, optional
-        Normalization function to apply to inputs, by default None
-    augmentation : callable, optional
-        Augmentation function to apply to inputs, by default None
-    nepochs : int, optional
-        Number of training epochs, by default 20
-    amp_mode : str, optional
-        Mixed precision mode ("none", "fp16", "bf16"), by default "none"
-    log_grads : int, optional
-        Frequency of gradient logging (0 for no logging), by default 0
-    exp_dir : str, optional
-        Experiment directory for logging, by default None
-    logging : bool, optional
-        Whether to enable logging, by default True
-    device : torch.device, optional
-        Device to train on, by default torch.device("cpu")
-    """
     train_start = time.time()

     # set AMP type
examples/shallow_water_equations/train.py
@@ -63,36 +63,11 @@ except:
 # helper routine for counting number of paramerters in model
 def count_parameters(model):
-    """
-    Count the number of trainable parameters in a model.
-
-    Parameters
-    -----------
-    model : torch.nn.Module
-        The model to count parameters for
-
-    Returns
-    -------
-    int
-        Total number of trainable parameters
-    """
     return sum(p.numel() for p in model.parameters() if p.requires_grad)

 # convenience function for logging weights and gradients
 def log_weights_and_grads(model, iters=1):
-    """
-    Helper routine intended for debugging purposes.
-    Saves model weights and gradients to a file for analysis.
-
-    Parameters
-    -----------
-    model : torch.nn.Module
-        Model whose weights and gradients to log
-    iters : int, optional
-        Current iteration number, by default 1
-    """
     root_path = os.path.join(os.path.dirname(__file__), "weights_and_grads")
     weights_and_grads_fname = os.path.join(root_path, f"weights_and_grads_step{iters:03d}.tar")
@@ -119,40 +94,7 @@ def autoregressive_inference(
     nics=50,
     device=torch.device("cpu"),
 ):
-    """
-    Perform autoregressive inference with a trained model and compare to classical solver.
-
-    Parameters
-    -----------
-    model : torch.nn.Module
-        Trained model to evaluate
-    dataset : torch.utils.data.Dataset
-        Dataset containing solver and normalization parameters
-    loss_fn : callable
-        Loss function for evaluation
-    metrics_fns : dict
-        Dictionary of metric functions to compute
-    path_root : str
-        Root path for saving inference outputs
-    nsteps : int
-        Number of solver steps per autoregressive step
-    autoreg_steps : int, optional
-        Number of autoregressive steps, by default 10
-    nskip : int, optional
-        Skip interval for plotting, by default 1
-    plot_channel : int, optional
-        Channel to plot, by default 0
-    nics : int, optional
-        Number of initial conditions to test, by default 50
-    device : torch.device, optional
-        Device to run inference on, by default torch.device("cpu")
-
-    Returns
-    -------
-    tuple
-        (losses, metrics, model_times, solver_times) where losses and metrics are tensors
-        of per-sample values, and model_times and solver_times are timing information
-    """
     model.eval()

     # make output
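The removed docstring describes autoregressive inference: the model is rolled out for autoreg_steps steps, feeding each prediction back as the next input, and the trajectory is compared against the classical solver. A generic rollout sketch of that idea (names such as `model`, `initial_condition`, and `autoreg_steps` are illustrative; this is not the repository's implementation):

    import torch

    # Generic autoregressive rollout: feed each prediction back as the next input.
    @torch.no_grad()
    def rollout(model, initial_condition, autoreg_steps=10):
        model.eval()
        x = initial_condition
        states = [x]
        for _ in range(autoreg_steps):
            x = model(x)              # one learned step forward in time
            states.append(x)
        return torch.stack(states, dim=0)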
@@ -293,42 +235,7 @@ def train_model(
     logging=True,
     device=torch.device("cpu"),
 ):
-    """
-    Train a model with the given parameters.
-
-    Parameters
-    -----------
-    model : torch.nn.Module
-        Model to train
-    dataloader : torch.utils.data.DataLoader
-        DataLoader for training data
-    loss_fn : callable
-        Loss function
-    metrics_fns : dict
-        Dictionary of metric functions to compute
-    optimizer : torch.optim.Optimizer
-        Optimizer for training
-    gscaler : torch.cuda.amp.GradScaler
-        Gradient scaler for mixed precision training
-    scheduler : torch.optim.lr_scheduler._LRScheduler, optional
-        Learning rate scheduler, by default None
-    nepochs : int, optional
-        Number of training epochs, by default 20
-    nfuture : int, optional
-        Number of future steps to predict, by default 0
-    num_examples : int, optional
-        Number of examples per epoch, by default 256
-    num_valid : int, optional
-        Number of validation examples, by default 8
-    amp_mode : str, optional
-        Mixed precision mode ("none", "fp16", "bf16"), by default "none"
-    log_grads : int, optional
-        Frequency of gradient logging (0 for no logging), by default 0
-    logging : bool, optional
-        Whether to enable logging, by default True
-    device : torch.device, optional
-        Device to train on, by default torch.device("cpu")
-    """
     train_start = time.time()

     # set AMP type
torch_harmonics/resample.py
@@ -219,19 +219,7 @@ class ResampleS2(nn.Module):
         return x

     def forward(self, x: torch.Tensor):
-        """
-        Forward pass of the resampling module.
-
-        Parameters
-        -----------
-        x : torch.Tensor
-            Input tensor with shape (..., nlat_in, nlon_in)
-
-        Returns
-        -------
-        torch.Tensor
-            Resampled tensor with shape (..., nlat_out, nlon_out)
-        """
         if self.skip_resampling:
             return x
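The removed docstring pins down only the shape contract of ResampleS2.forward: input (..., nlat_in, nlon_in), output (..., nlat_out, nlon_out). A hedged usage sketch under that contract; the import path and constructor arguments shown here are assumptions, so check the class signature in resample.py before relying on them:

    import torch
    import torch_harmonics as th

    # Usage sketch based only on the shape contract in the removed docstring.
    # Constructor argument names below are assumptions, not the verified signature.
    resample = th.ResampleS2(nlat_in=64, nlon_in=128, nlat_out=128, nlon_out=256)

    x = torch.randn(4, 3, 64, 128)   # (batch, channels, nlat_in, nlon_in)
    y = resample(x)                  # expected shape: (4, 3, 128, 256)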