Commit 290da8e0 authored by apaaris, committed by Boris Bonev

Improved docstrings in examples

parent 6f3250cb
@@ -371,6 +371,41 @@ class AttentionWrapper(nn.Module):
class TransformerBlock(nn.Module):
"""
Transformer block with attention and MLP.

Parameters
----------
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
mlp_hidden_channels : int
Number of hidden channels in MLP
nrep : int, optional
Number of repetitions of attention and MLP blocks, by default 1
heads : int, optional
Number of attention heads, by default 1
kernel_shape : tuple, optional
Kernel shape for the block's convolutions, by default (3, 3)
activation : torch.nn.Module, optional
Activation function to use, by default nn.GELU
att_drop_rate : float, optional
Attention dropout rate, by default 0.0
drop_path_rates : float or list, optional
Drop path rates for each block, by default 0.0
attention_mode : str, optional
Attention mode ("neighborhood", "global"), by default "neighborhood"
attn_kernel_shape : tuple, optional
Kernel shape for neighborhood attention, by default (7, 7)
bias : bool, optional
Whether to use bias, by default True
"""
def __init__(
self,
in_shape,
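The constructor body is collapsed in this hunk, so the following is a minimal usage sketch assembled only from the docstring above; the import path and all argument values are assumptions, not part of the commit:

import torch
import torch.nn as nn
# Hypothetical import path; adjust to wherever TransformerBlock is defined.

block = TransformerBlock(
    in_shape=(64, 128),            # input (height, width)
    out_shape=(64, 128),           # output (height, width)
    in_channels=32,
    out_channels=32,
    mlp_hidden_channels=128,
    nrep=2,                        # two attention + MLP repetitions
    heads=4,
    activation=nn.GELU,
    drop_path_rates=[0.0, 0.1],    # one rate per repetition
    attention_mode="neighborhood",
    attn_kernel_shape=(7, 7),
)
x = torch.randn(1, 32, 64, 128)    # (batch, channels, height, width)
y = block(x)                       # expected shape: (1, 32, 64, 128)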
@@ -493,6 +528,33 @@ class TransformerBlock(nn.Module):
class Upsampling(nn.Module):
"""
Upsampling block for the Segformer model.

Parameters
----------
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
hidden_channels : int
Number of hidden channels in MLP
mlp_bias : bool, optional
Whether to use bias in MLP, by default True
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
conv_bias : bool, optional
Whether to use bias in convolution, by default False
activation : torch.nn.Module, optional
Activation function to use, by default nn.GELU
use_mlp : bool, optional
Whether to use MLP, by default False
"""
def __init__(
self,
in_shape,
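A similar hedged sketch for Upsampling; the values are illustrative and the forward signature is assumed to map (batch, channels, height, width) tensors:

import torch
import torch.nn as nn

up = Upsampling(
    in_shape=(64, 128),          # coarse grid (height, width)
    out_shape=(128, 256),        # target grid (height, width)
    in_channels=64,
    out_channels=32,
    hidden_channels=128,
    use_mlp=True,                # enable the optional MLP
    activation=nn.GELU,
)
x = torch.randn(1, 64, 64, 128)
y = up(x)                        # expected shape: (1, 32, 128, 256)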
@@ -534,7 +596,7 @@ class Segformer(nn.Module):
Spherical Segformer model designed to approximate mappings from spherical signals to spherical segmentation masks.

Parameters
----------
img_shape : tuple, optional
Shape of the input image (height, width), by default (128, 256)
kernel_shape : tuple or int
@@ -566,7 +628,7 @@ class Segformer(nn.Module):
Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"

Example
----------
>>> model = Segformer(
... img_size=(128, 256),
... in_chans=3,
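The doctest above is cut off by the fold; a hedged end-to-end sketch, assuming the keyword names shown there, an out_chans argument, and a (batch, channels, height, width) input layout:

import torch

model = Segformer(img_size=(128, 256), in_chans=3, out_chans=3)  # out_chans assumed
x = torch.randn(2, 3, 128, 256)   # (batch, channels, height, width)
masks = model(x)                  # per-pixel segmentation logits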
@@ -43,6 +43,37 @@ from functools import partial
class DownsamplingBlock(nn.Module):
"""
Downsampling block for the UNet model.

Parameters
----------
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
nrep : int, optional
Number of repetitions of conv blocks, by default 1
kernel_shape : tuple, optional
Kernel shape for convolutions, by default (3, 3)
activation : callable, optional
Activation function, by default nn.ReLU
transform_skip : bool, optional
Whether to transform skip connections, by default False
drop_conv_rate : float, optional
Dropout rate for convolutions, by default 0.0
drop_path_rate : float, optional
Drop path rate, by default 0.0
drop_dense_rate : float, optional
Dropout rate for dense layers, by default 0.0
downsampling_mode : str, optional
Downsampling mode ("bilinear", "conv"), by default "bilinear"
"""
def __init__(
self,
in_shape,
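As above, a minimal instantiation sketch assembled from the docstring; all values are illustrative:

import torch
import torch.nn as nn

down = DownsamplingBlock(
    in_shape=(128, 256),
    out_shape=(64, 128),            # halved resolution
    in_channels=32,
    out_channels=64,
    nrep=2,
    kernel_shape=(3, 3),
    activation=nn.ReLU,
    downsampling_mode="bilinear",
)
x = torch.randn(1, 32, 128, 256)
y = down(x)   # downsampled features; a skip tensor may also be returned (see transform_skip)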
@@ -67,7 +67,7 @@ def count_parameters(model):
Count the number of trainable parameters in a model.

Parameters
----------
model : torch.nn.Module
The model to count parameters for
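The function body is collapsed here; for reference, a trainable-parameter count is conventionally computed like this:

def count_parameters(model):
    # Sum element counts over parameters that require gradients.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)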
@@ -87,7 +87,7 @@ def log_weights_and_grads(exp_dir, model, iters=1):
Saves model weights and gradients to a file for analysis.

Parameters
----------
exp_dir : str
Experiment directory to save logs in
model : torch.nn.Module
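The body of log_weights_and_grads is likewise collapsed; a hedged sketch of what such a helper might do, with the snapshot format and file naming assumed:

import os
import torch

def log_weights_and_grads_sketch(exp_dir, model, iters=1):
    # Snapshot each parameter and its gradient (if present) as CPU tensors.
    state = {
        name: {
            "weight": p.detach().cpu(),
            "grad": p.grad.detach().cpu() if p.grad is not None else None,
        }
        for name, p in model.named_parameters()
    }
    # Tag the file with the iteration count; the naming scheme is an assumption.
    torch.save(state, os.path.join(exp_dir, f"weights_and_grads_{iters:06d}.tar"))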
@@ -49,7 +49,7 @@ def get_baseline_models(img_size=(128, 256), in_chans=3, out_chans=3, residual_p
computational budgets and performance requirements.

Parameters
----------
img_size : tuple, optional
Input image size as (height, width), by default (128, 256)
in_chans : int, optional
@@ -64,7 +64,7 @@ def get_baseline_models(img_size=(128, 256), in_chans=3, out_chans=3, residual_p
Grid type for spherical models ("equiangular", "legendre-gauss", etc.), by default "equiangular"

Returns
----------
dict
Dictionary mapping model names to partial functions that can be called
to instantiate the corresponding model with the specified parameters.
@@ -96,13 +96,13 @@ def get_baseline_models(img_size=(128, 256), in_chans=3, out_chans=3, residual_p
- vit_sc2_layers4_e128: Vision Transformer variant (medium)

Examples
----------
>>> model_registry = get_baseline_models(img_size=(64, 128), in_chans=2, out_chans=1)
>>> model = model_registry['sfno_sc2_layers4_e32']()
>>> print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

Notes
----------
- Model names follow the pattern: {model_type}_{scale_factor}_{layers}_{embed_dim}
- 'sc2' indicates scale factor of 2 (downsampling by 2)
- 'e32', 'e128', 'e256' indicate embedding dimensions
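To make the registry pattern and naming convention concrete, a sketch of what the returned dictionary plausibly looks like; the SFNO stand-in class and its keyword names are assumptions, not the library's API:

from functools import partial
import torch.nn as nn

class SFNO(nn.Module):
    """Stand-in for the real spherical model class used by the registry."""
    def __init__(self, img_size, in_chans, out_chans, scale_factor, num_layers, embed_dim):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, out_chans, kernel_size=1)  # placeholder layer
    def forward(self, x):
        return self.proj(x)

def get_baseline_models_sketch(img_size=(128, 256), in_chans=3, out_chans=3):
    # Keys follow {model_type}_{scale_factor}_{layers}_{embed_dim}.
    return {
        "sfno_sc2_layers4_e32": partial(
            SFNO, img_size=img_size, in_chans=in_chans, out_chans=out_chans,
            scale_factor=2, num_layers=4, embed_dim=32,
        ),
    }

model = get_baseline_models_sketch()["sfno_sc2_layers4_e32"]()  # instantiate on demand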