OpenDAS / torch-harmonics · Commits · 290da8e0
Commit 290da8e0, authored Jun 30, 2025 by apaaris; committed by Boris Bonev, Jul 21, 2025
Improved docstrings in examples
parent 6f3250cb
Showing 4 changed files with 101 additions and 8 deletions (+101 -8)
examples/baseline_models/segformer.py  +64 -2
examples/baseline_models/unet.py       +31 -0
examples/depth/train.py                +2 -2
examples/model_registry.py             +4 -4
examples/baseline_models/segformer.py
...
...
@@ -371,6 +371,41 @@ class AttentionWrapper(nn.Module):
class TransformerBlock(nn.Module):
"""
Transformer block with attention and MLP.
Parameters
----------
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
mlp_hidden_channels : int
Number of hidden channels in MLP
nrep : int, optional
Number of repetitions of attention and MLP blocks, by default 1
heads : int, optional
Number of attention heads, by default 1
kernel_shape : tuple, optional
Kernel shape for neighborhood attention, by default (3, 3)
activation : torch.nn.Module, optional
Activation function to use, by default nn.GELU
att_drop_rate : float, optional
Attention dropout rate, by default 0.0
drop_path_rates : float or list, optional
Drop path rates for each block, by default 0.0
attention_mode : str, optional
Attention mode ("neighborhood", "global"), by default "neighborhood"
attn_kernel_shape : tuple, optional
Kernel shape for neighborhood attention, by default (7, 7)
bias : bool, optional
Whether to use bias, by default True
"""
def __init__(self, in_shape,
...
...
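For readers skimming the new TransformerBlock docstring, here is a hypothetical usage sketch. The keyword names are taken from the documented parameters above; the import path, the (batch, channels, height, width) tensor layout, and the concrete values are assumptions, since the actual constructor is collapsed in this diff.

import torch
import torch.nn as nn
from examples.baseline_models.segformer import TransformerBlock  # assumed import path (repo root on sys.path)

# Hypothetical instantiation; keyword names follow the docstring above.
block = TransformerBlock(
    in_shape=(64, 128),             # (height, width) of the input grid
    out_shape=(64, 128),
    in_channels=96,
    out_channels=96,
    mlp_hidden_channels=384,        # commonly ~4x the channel width
    nrep=2,                         # repeat attention + MLP twice
    heads=4,
    activation=nn.GELU,
    att_drop_rate=0.0,
    drop_path_rates=0.1,
    attention_mode="neighborhood",  # or "global"
    attn_kernel_shape=(7, 7),
)
x = torch.randn(1, 96, 64, 128)     # assumed (batch, channels, height, width) layout
y = block(x)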
@@ -493,6 +528,33 @@ class TransformerBlock(nn.Module):
class Upsampling(nn.Module):
"""
Upsampling block for the Segformer model.
Parameters
----------
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
hidden_channels : int
Number of hidden channels in MLP
mlp_bias : bool, optional
Whether to use bias in MLP, by default True
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
conv_bias : bool, optional
Whether to use bias in convolution, by default False
activation : torch.nn.Module, optional
Activation function to use, by default nn.GELU
use_mlp : bool, optional
Whether to use MLP, by default False
"""
def __init__(self, in_shape,
...
...
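A similarly hedged sketch for the new Upsampling docstring; the keyword names come from the parameter list above, while the import path, shapes, and the effect of use_mlp are assumptions to be checked against the code.

import torch
from examples.baseline_models.segformer import Upsampling  # assumed import path

up = Upsampling(
    in_shape=(32, 64),
    out_shape=(64, 128),        # doubles the spatial resolution
    in_channels=128,
    out_channels=64,
    hidden_channels=256,
    use_mlp=True,               # enable the optional MLP (off by default per the docstring)
)
x = torch.randn(1, 128, 32, 64)
y = up(x)                       # expected to land on the (64, 128) output grid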
@@ -534,7 +596,7 @@ class Segformer(nn.Module):
Spherical segformer model designed to approximate mappings from spherical signals to spherical segmentation masks
Parameters
----------
- ----------
img_shape : tuple, optional
Shape of the input channels, by default (128, 256)
kernel_shape: tuple, int
...
...
@@ -566,7 +628,7 @@ class Segformer(nn.Module):
Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"
Example
----------
- ----------
>>> model = Segformer(
... img_size=(128, 256),
... in_chans=3,
...
...
examples/baseline_models/unet.py
...
...
@@ -43,6 +43,37 @@ from functools import partial
class DownsamplingBlock(nn.Module):
"""
Downsampling block for the UNet model.
Parameters
----------
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
nrep : int, optional
Number of repetitions of conv blocks, by default 1
kernel_shape : tuple, optional
Kernel shape for convolutions, by default (3, 3)
activation : callable, optional
Activation function, by default nn.ReLU
transform_skip : bool, optional
Whether to transform skip connections, by default False
drop_conv_rate : float, optional
Dropout rate for convolutions, by default 0.
drop_path_rate : float, optional
Drop path rate, by default 0.
drop_dense_rate : float, optional
Dropout rate for dense layers, by default 0.
downsampling_mode : str, optional
Downsampling mode ("bilinear", "conv"), by default "bilinear"
"""
def __init__(self, in_shape,
...
...
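As with the Segformer blocks, a hypothetical DownsamplingBlock call based only on the documented parameters; the import path, tensor layout, and return value (whether a skip connection is also returned) are assumptions, not verified against the code.

import torch
import torch.nn as nn
from examples.baseline_models.unet import DownsamplingBlock  # assumed import path

down = DownsamplingBlock(
    in_shape=(64, 128),
    out_shape=(32, 64),            # halve the spatial resolution
    in_channels=32,
    out_channels=64,
    nrep=2,                        # two conv blocks per stage
    kernel_shape=(3, 3),
    activation=nn.ReLU,
    drop_conv_rate=0.1,
    downsampling_mode="bilinear",  # or "conv"
)
x = torch.randn(1, 32, 64, 128)
out = down(x)                      # may also return a skip tensor; check forward()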
examples/depth/train.py
...
...
@@ -67,7 +67,7 @@ def count_parameters(model):
Count the number of trainable parameters in a model.
Parameters
----------
- ----------
model : torch.nn.Module
The model to count parameters for
...
...
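The count_parameters docstring describes a standard helper; a minimal sketch of such a function is shown below (the actual implementation in examples/depth/train.py may differ in detail).

import torch

def count_parameters(model: torch.nn.Module) -> int:
    # Sum the element counts of all trainable parameters.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Example: a 16 -> 4 linear layer has 16*4 weights + 4 biases = 68 parameters.
print(count_parameters(torch.nn.Linear(16, 4)))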
@@ -87,7 +87,7 @@ def log_weights_and_grads(exp_dir, model, iters=1):
Saves model weights and gradients to a file for analysis.
Parameters
----------
- ----------
exp_dir : str
Experiment directory to save logs in
model : torch.nn.Module
...
...
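For log_weights_and_grads, a hedged sketch of what such a helper typically does with the documented signature (exp_dir, model, iters); the file naming and on-disk format are assumptions, not necessarily what the script uses.

import os
import torch

def log_weights_and_grads(exp_dir, model, iters=1):
    # Collect current weights and any available gradients on the CPU.
    os.makedirs(exp_dir, exist_ok=True)
    state = {
        "weights": {n: p.detach().cpu() for n, p in model.named_parameters()},
        "grads": {n: p.grad.detach().cpu() for n, p in model.named_parameters() if p.grad is not None},
    }
    # Assumed naming scheme: one checkpoint file per logged iteration.
    torch.save(state, os.path.join(exp_dir, f"weights_and_grads_step{iters:06d}.pt"))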
examples/model_registry.py
...
...
@@ -49,7 +49,7 @@ def get_baseline_models(img_size=(128, 256), in_chans=3, out_chans=3, residual_p
computational budgets and performance requirements.
Parameters
----------
- ----------
img_size : tuple, optional
Input image size as (height, width), by default (128, 256)
in_chans : int, optional
...
...
@@ -64,7 +64,7 @@ def get_baseline_models(img_size=(128, 256), in_chans=3, out_chans=3, residual_p
Grid type for spherical models ("equiangular", "legendre-gauss", etc.), by default "equiangular"
Returns
-------
- ----------
dict
Dictionary mapping model names to partial functions that can be called
to instantiate the corresponding model with the specified parameters.
...
...
@@ -96,13 +96,13 @@ def get_baseline_models(img_size=(128, 256), in_chans=3, out_chans=3, residual_p
- vit_sc2_layers4_e128: Vision Transformer variant (medium)
Examples
--------
- ----------
>>> model_registry = get_baseline_models(img_size=(64, 128), in_chans=2, out_chans=1)
>>> model = model_registry['sfno_sc2_layers4_e32']()
>>> print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
Notes
-----
- ----------
- Model names follow the pattern: {model_type}_{scale_factor}_{layers}_{embed_dim}
- 'sc2' indicates scale factor of 2 (downsampling by 2)
- 'e32', 'e128', 'e256' indicate embedding dimensions
...
...
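The get_baseline_models docstring describes a registry of partial constructors. The sketch below illustrates that pattern with a stand-in model; only the naming convention and the functools.partial idea are taken from the diff, everything else is illustrative.

from functools import partial
import torch.nn as nn

class TinyModel(nn.Module):
    # Stand-in for the Segformer/UNet/SFNO/ViT variants bound by the real registry.
    def __init__(self, in_chans, out_chans, embed_dim=32):
        super().__init__()
        self.net = nn.Conv2d(in_chans, embed_dim, 3, padding=1)
        self.head = nn.Conv2d(embed_dim, out_chans, 1)

    def forward(self, x):
        return self.head(self.net(x))

def get_demo_registry(in_chans=2, out_chans=1):
    # Names follow the documented pattern {model_type}_{scale_factor}_{layers}_{embed_dim}.
    return {
        "tiny_sc2_layers4_e32": partial(TinyModel, in_chans=in_chans, out_chans=out_chans, embed_dim=32),
    }

registry = get_demo_registry()
model = registry["tiny_sc2_layers4_e32"]()   # the model is only built when the partial is called
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")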