Commit b17bfdc4 authored by Andrea Paris, committed by Boris Bonev

removed docstrings from examples

parent 901e8635
@@ -90,14 +90,6 @@ class OverlapPatchMerging(nn.Module):
self.apply(self._init_weights)
def _init_weights(self, m):
"""
Initialize weights for the module.
Parameters
-----------
m : torch.nn.Module
Module to initialize weights for
"""
if isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@@ -174,14 +166,6 @@ class MixFFN(nn.Module):
self.apply(self._init_weights)
def _init_weights(self, m):
"""
Initialize weights for the module.
Parameters
-----------
m : torch.nn.Module
Module to initialize weights for
"""
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
@@ -238,19 +222,7 @@ class GlobalAttention(nn.Module):
self.attn = nn.MultiheadAttention(embed_dim=chans, num_heads=num_heads, dropout=dropout, batch_first=True, bias=bias)
def forward(self, x):
"""
Forward pass through the GlobalAttention module.
Parameters
-----------
x : torch.Tensor
Input tensor of shape (B, C, H, W)
Returns
-------
torch.Tensor
Output tensor of shape (B, C, H, W)
"""
# x: B, C, H, W
B, H, W, C = x.shape
# flatten spatial dims
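
For reference, a minimal sketch of the flatten-attend-reshape pattern this forward pass uses, with made-up shapes and a standalone nn.MultiheadAttention rather than the module's own layers:

import torch
import torch.nn as nn

# hypothetical shapes; the input is channels-last (B, H, W, C) as in the code above
attn = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)
x = torch.randn(2, 16, 32, 32)
B, H, W, C = x.shape
tokens = x.reshape(B, H * W, C)        # flatten spatial dims into a token sequence
out, _ = attn(tokens, tokens, tokens)  # global self-attention over all positions
out = out.reshape(B, H, W, C)          # restore the spatial layout
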
@@ -309,32 +281,13 @@ class AttentionWrapper(nn.Module):
self.apply(self._init_weights)
def _init_weights(self, m):
"""
Initialize weights for the module.
Parameters
-----------
m : torch.nn.Module
Module to initialize weights for
"""
if isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the AttentionWrapper.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Output tensor with residual connection
"""
residual = x
x = x.permute(0, 2, 3, 1)
if self.norm is not None:
......
@@ -57,19 +57,7 @@ class Encoder(nn.Module):
self.conv = nn.Conv2d(in_chans, out_chans, kernel_size=kernel_shape, bias=bias, stride=(stride_h, stride_w), padding=(pad_h, pad_w), groups=groups)
def forward(self, x):
"""
Forward pass through the Encoder layer.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Output tensor after encoding
"""
x = self.conv(x)
return x
@@ -122,19 +110,7 @@ class Decoder(nn.Module):
raise ValueError(f"Unknown upsampling method {upsampling_method}")
def forward(self, x):
"""
Forward pass through the Decoder layer.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Output tensor after decoding and upsampling
"""
x = self.upsample(x)
return x
@@ -163,19 +139,7 @@ class GlobalAttention(nn.Module):
self.attn = nn.MultiheadAttention(embed_dim=chans, num_heads=num_heads, dropout=dropout, batch_first=True, bias=bias)
def forward(self, x):
"""
Forward pass through the GlobalAttention module.
Parameters
-----------
x : torch.Tensor
Input tensor of shape (B, C, H, W)
Returns
-------
torch.Tensor
Output tensor of shape (B, C, H, W)
"""
# x: B, C, H, W
B, H, W, C = x.shape
# flatten spatial dims
@@ -286,19 +250,7 @@ class AttentionBlock(nn.Module):
self.skip1 = nn.Identity()
def forward(self, x):
"""
Forward pass through the AttentionBlock.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Output tensor with residual connections
"""
residual = x
x = self.norm0(x)
......
@@ -185,19 +185,7 @@ class DownsamplingBlock(nn.Module):
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the DownsamplingBlock.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Output tensor after downsampling
"""
# skip connection
residual = x
if hasattr(self, "transform_skip"):
@@ -370,19 +358,7 @@ class UpsamplingBlock(nn.Module):
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the UpsamplingBlock.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Output tensor after upsampling
"""
# skip connection
residual = x
if hasattr(self, "transform_skip"):
@@ -545,14 +521,6 @@ class UNet(nn.Module):
self.apply(self._init_weights)
def _init_weights(self, m):
"""
Initialize weights for the module.
Parameters
-----------
m : torch.nn.Module
Module to initialize weights for
"""
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=.02)
if m.bias is not None:
@@ -563,19 +531,7 @@ class UNet(nn.Module):
def forward(self, x):
"""
Forward pass through the UNet model.
Parameters
-----------
x : torch.Tensor
Input tensor of shape (batch_size, in_chans, height, width)
Returns
-------
torch.Tensor
Output tensor of shape (batch_size, num_classes, height, width)
"""
# encoder:
features = []
feat = x
......
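
The encoder loop collects intermediate features that the decoder later consumes as skip connections. A toy sketch of that pattern, assuming hypothetical two-stage encoder/decoder stacks (not this repository's UNet blocks):

import torch
import torch.nn as nn

encoders = nn.ModuleList([nn.Conv2d(3, 8, 3, stride=2, padding=1),
                          nn.Conv2d(8, 16, 3, stride=2, padding=1)])
decoders = nn.ModuleList([nn.ConvTranspose2d(16, 8, 2, stride=2),
                          nn.ConvTranspose2d(8, 3, 2, stride=2)])

x = torch.randn(1, 3, 64, 64)
features = []
feat = x
for enc in encoders:
    features.append(feat)              # save the input of each stage as a skip
    feat = enc(feat)
for dec in decoders:
    feat = dec(feat) + features.pop()  # upsample and add the matching skip
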
@@ -63,38 +63,11 @@ import wandb
# helper routine for counting number of parameters in model
def count_parameters(model):
"""
Count the number of trainable parameters in a model.
Parameters
----------
model : torch.nn.Module
The model to count parameters for
Returns
-------
int
Total number of trainable parameters
"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
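
A quick usage sketch for this helper, restated so it runs standalone (the Linear layer is just a placeholder model):

import torch.nn as nn

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(count_parameters(nn.Linear(10, 5)))  # 10*5 weights + 5 biases -> 55
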
# convenience function for logging weights and gradients
def log_weights_and_grads(exp_dir, model, iters=1):
"""
Helper routine intended for debugging purposes.
Saves model weights and gradients to a file for analysis.
Parameters
----------
exp_dir : str
Experiment directory to save logs in
model : torch.nn.Module
Model whose weights and gradients to log
iters : int, optional
Current iteration number, by default 1
"""
log_path = os.path.join(exp_dir, "weights_and_grads")
if not os.path.isdir(log_path):
os.makedirs(log_path, exist_ok=True)
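
The helper dumps parameters and gradients to disk for offline inspection; a self-contained sketch of that idea, with assumed names and file layout rather than the actual implementation:

import os
import torch

def save_weights_and_grads(log_path, model, iters=1):
    # collect parameters and their gradients (None if not yet populated)
    weights = {name: p.detach().cpu() for name, p in model.named_parameters()}
    grads = {name: None if p.grad is None else p.grad.detach().cpu()
             for name, p in model.named_parameters()}
    os.makedirs(log_path, exist_ok=True)
    torch.save({"weights": weights, "grads": grads},
               os.path.join(log_path, f"weights_and_grads_step{iters:03d}.tar"))
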
@@ -121,36 +94,7 @@ def validate_model(
logging=True,
device=torch.device("cpu"),
):
"""
Validate a model on a dataset and compute metrics.
Parameters
-----------
model : torch.nn.Module
Model to validate
dataloader : torch.utils.data.DataLoader
DataLoader for validation data
loss_fn : callable
Loss function
metrics_fns : dict
Dictionary of metric functions to compute
path_root : str
Root path for saving validation outputs
normalization_in : callable, optional
Normalization function to apply to inputs, by default None
normalization_out : callable, optional
Normalization function to apply to targets, by default None
logging : bool, optional
Whether to save validation plots, by default True
device : torch.device, optional
Device to run validation on, by default torch.device("cpu")
Returns
-------
tuple
(losses, metrics) where losses is a tensor of per-sample losses
and metrics is a dict of per-sample metric values
"""
model.eval()
num_examples = len(dataloader)
@@ -243,50 +187,7 @@ def train_model(
logging=True,
device=torch.device("cpu"),
):
"""
Train a model with the given parameters.
Parameters
-----------
model : torch.nn.Module
Model to train
train_dataloader : torch.utils.data.DataLoader
DataLoader for training data
train_sampler : torch.utils.data.Sampler
Sampler for training data
test_dataloader : torch.utils.data.DataLoader
DataLoader for test data
test_sampler : torch.utils.data.Sampler
Sampler for test data
loss_fn : callable
Loss function
metrics_fns : dict
Dictionary of metric functions to compute
optimizer : torch.optim.Optimizer
Optimizer for training
gscaler : torch.cuda.amp.GradScaler
Gradient scaler for mixed precision training
scheduler : torch.optim.lr_scheduler._LRScheduler, optional
Learning rate scheduler, by default None
normalization_in : callable, optional
Normalization function to apply to inputs, by default None
normalization_out : callable, optional
Normalization function to apply to targets, by default None
augmentation : bool, optional
Whether to apply data augmentation, by default False
nepochs : int, optional
Number of training epochs, by default 20
amp_mode : str, optional
Mixed precision mode ("none", "fp16", "bf16"), by default "none"
log_grads : int, optional
Frequency of gradient logging (0 for no logging), by default 0
exp_dir : str, optional
Experiment directory for logging, by default None
logging : bool, optional
Whether to enable logging, by default True
device : torch.device, optional
Device to train on, by default torch.device("cpu")
"""
train_start = time.time()
# set AMP type
......
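
The docstring removed above listed amp_mode ("none", "fp16", "bf16") together with a GradScaler; a minimal, hypothetical single training step using that setup (placeholder model and data, not the script's actual loop) could look like:

import torch
import torch.nn.functional as F

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
x, y = torch.randn(4, 8), torch.randn(4, 8)

amp_mode = "bf16"  # one of "none", "fp16", "bf16"
amp_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16}.get(amp_mode)
gscaler = torch.cuda.amp.GradScaler(enabled=(amp_mode == "fp16"))

optimizer.zero_grad()
with torch.autocast(device_type="cpu", dtype=amp_dtype, enabled=(amp_mode != "none")):
    loss = F.mse_loss(model(x), y)
gscaler.scale(loss).backward()  # scaling is a no-op unless fp16 is enabled
gscaler.step(optimizer)
gscaler.update()
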
@@ -39,77 +39,7 @@ from baseline_models import Transformer, UNet, Segformer
from torch_harmonics.examples.models import SphericalFourierNeuralOperator, LocalSphericalNeuralOperator, SphericalTransformer, SphericalUNet, SphericalSegformer
def get_baseline_models(img_size=(128, 256), in_chans=3, out_chans=3, residual_prediction=False, drop_path_rate=0., grid="equiangular"):
"""
Get a registry of baseline models for spherical and planar neural networks.
This function returns a dictionary containing pre-configured model architectures
for various tasks including spherical Fourier neural operators (SFNO), local
spherical neural operators (LSNO), spherical transformers, U-Nets, and Segformers.
Each model is configured with specific hyperparameters optimized for different
computational budgets and performance requirements.
Parameters
----------
img_size : tuple, optional
Input image size as (height, width), by default (128, 256)
in_chans : int, optional
Number of input channels, by default 3
out_chans : int, optional
Number of output channels, by default 3
residual_prediction : bool, optional
Whether to use residual prediction (add input to output), by default False
drop_path_rate : float, optional
Dropout path rate for regularization, by default 0.0
grid : str, optional
Grid type for spherical models ("equiangular", "legendre-gauss", etc.), by default "equiangular"
Returns
----------
dict
Dictionary mapping model names to partial functions that can be called
to instantiate the corresponding model with the specified parameters.
Available models include:
**Spherical Models:**
- sfno_sc2_layers4_e32: Spherical Fourier Neural Operator (small)
- lsno_sc2_layers4_e32: Local Spherical Neural Operator (small)
- s2unet_sc2_layers4_e128: Spherical U-Net (medium)
- s2transformer_sc2_layers4_e128: Spherical Transformer (global attention, medium)
- s2transformer_sc2_layers4_e256: Spherical Transformer (global attention, large)
- s2ntransformer_sc2_layers4_e128: Spherical Transformer (neighborhood attention, medium)
- s2ntransformer_sc2_layers4_e256: Spherical Transformer (neighborhood attention, large)
- s2segformer_sc2_layers4_e128: Spherical Segformer (global attention, medium)
- s2segformer_sc2_layers4_e256: Spherical Segformer (global attention, large)
- s2nsegformer_sc2_layers4_e128: Spherical Segformer (neighborhood attention, medium)
- s2nsegformer_sc2_layers4_e256: Spherical Segformer (neighborhood attention, large)
**Planar Models:**
- transformer_sc2_layers4_e128: Planar Transformer (global attention, medium)
- transformer_sc2_layers4_e256: Planar Transformer (global attention, large)
- ntransformer_sc2_layers4_e128: Planar Transformer (neighborhood attention, medium)
- ntransformer_sc2_layers4_e256: Planar Transformer (neighborhood attention, large)
- segformer_sc2_layers4_e128: Planar Segformer (global attention, medium)
- segformer_sc2_layers4_e256: Planar Segformer (global attention, large)
- nsegformer_sc2_layers4_e128: Planar Segformer (neighborhood attention, medium)
- nsegformer_sc2_layers4_e256: Planar Segformer (neighborhood attention, large)
- vit_sc2_layers4_e128: Vision Transformer variant (medium)
Examples
----------
>>> model_registry = get_baseline_models(img_size=(64, 128), in_chans=2, out_chans=1)
>>> model = model_registry['sfno_sc2_layers4_e32']()
>>> print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
Notes
----------
- Model names follow the pattern: {model_type}_{scale_factor}_{layers}_{embed_dim}
- 'sc2' indicates scale factor of 2 (downsampling by 2)
- 'e32', 'e128', 'e256' indicate embedding dimensions
- 'n' prefix indicates neighborhood attention instead of global attention
- All models use GELU activation and instance normalization by default
"""
# prepare dicts containing models and corresponding metrics
model_registry = dict(
sfno_sc2_layers4_e32 = partial(
......
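
The docstring removed above documents a registry that maps names such as sfno_sc2_layers4_e32 to partial constructors; a tiny sketch of that registry-of-partials pattern, using a stand-in nn.Conv2d rather than the real models:

from functools import partial
import torch.nn as nn

def get_toy_registry(in_chans=3, out_chans=3):
    # each entry is a zero-argument callable that builds a fully configured model
    return dict(
        conv_small=partial(nn.Conv2d, in_chans, out_chans, kernel_size=1),
        conv_wide=partial(nn.Conv2d, in_chans, 4 * out_chans, kernel_size=1),
    )

registry = get_toy_registry(in_chans=2, out_chans=1)
model = registry["conv_small"]()
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
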
@@ -68,38 +68,13 @@ import wandb
# helper routine for counting number of parameters in model
def count_parameters(model):
"""
Count the number of trainable parameters in a model.
Parameters
-----------
model : torch.nn.Module
The model to count parameters for
Returns
-------
int
Total number of trainable parameters
"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
# convenience function for logging weights and gradients
def log_weights_and_grads(exp_dir, model, iters=1):
"""
Helper routine intended for debugging purposes.
Saves model weights and gradients to a file for analysis.
Parameters
-----------
exp_dir : str
Experiment directory to save logs in
model : torch.nn.Module
Model whose weights and gradients to log
iters : int, optional
Current iteration number, by default 1
"""
log_path = os.path.join(exp_dir, "weights_and_grads")
if not os.path.isdir(log_path):
os.makedirs(log_path, exist_ok=True)
@@ -116,34 +91,7 @@ def log_weights_and_grads(exp_dir, model, iters=1):
# rolls out the FNO and compares to the classical solver
def validate_model(model, dataloader, loss_fn, metrics_fns, path_root, normalization=None, logging=True, device=torch.device("cpu")):
"""
Validate a model on a dataset and compute metrics.
Parameters
-----------
model : torch.nn.Module
Model to validate
dataloader : torch.utils.data.DataLoader
DataLoader for validation data
loss_fn : callable
Loss function
metrics_fns : dict
Dictionary of metric functions to compute
path_root : str
Root path for saving validation outputs
normalization : callable, optional
Normalization function to apply to inputs, by default None
logging : bool, optional
Whether to save validation plots, by default True
device : torch.device, optional
Device to run validation on, by default torch.device("cpu")
Returns
-------
tuple
(losses, metrics) where losses is a tensor of per-sample losses
and metrics is a dict of per-sample metric values
"""
model.eval()
num_examples = len(dataloader)
@@ -229,50 +177,7 @@ def train_model(
logging=True,
device=torch.device("cpu"),
):
"""
Train a model with the given parameters.
Parameters
-----------
model : torch.nn.Module
Model to train
train_dataloader : torch.utils.data.DataLoader
DataLoader for training data
train_sampler : torch.utils.data.Sampler
Sampler for training data
test_dataloader : torch.utils.data.DataLoader
DataLoader for test data
test_sampler : torch.utils.data.Sampler
Sampler for test data
loss_fn : callable
Loss function
metrics_fns : dict
Dictionary of metric functions to compute
optimizer : torch.optim.Optimizer
Optimizer for training
gscaler : torch.cuda.amp.GradScaler
Gradient scaler for mixed precision training
scheduler : torch.optim.lr_scheduler._LRScheduler, optional
Learning rate scheduler, by default None
max_grad_norm : float, optional
Maximum gradient norm for clipping, by default 0.0
normalization : callable, optional
Normalization function to apply to inputs, by default None
augmentation : callable, optional
Augmentation function to apply to inputs, by default None
nepochs : int, optional
Number of training epochs, by default 20
amp_mode : str, optional
Mixed precision mode ("none", "fp16", "bf16"), by default "none"
log_grads : int, optional
Frequency of gradient logging (0 for no logging), by default 0
exp_dir : str, optional
Experiment directory for logging, by default None
logging : bool, optional
Whether to enable logging, by default True
device : torch.device, optional
Device to train on, by default torch.device("cpu")
"""
train_start = time.time()
# set AMP type
......
@@ -63,36 +63,11 @@ except:
# helper routine for counting number of parameters in model
def count_parameters(model):
"""
Count the number of trainable parameters in a model.
Parameters
-----------
model : torch.nn.Module
The model to count parameters for
Returns
-------
int
Total number of trainable parameters
"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
# convenience function for logging weights and gradients
def log_weights_and_grads(model, iters=1):
"""
Helper routine intended for debugging purposes.
Saves model weights and gradients to a file for analysis.
Parameters
-----------
model : torch.nn.Module
Model whose weights and gradients to log
iters : int, optional
Current iteration number, by default 1
"""
root_path = os.path.join(os.path.dirname(__file__), "weights_and_grads")
weights_and_grads_fname = os.path.join(root_path, f"weights_and_grads_step{iters:03d}.tar")
@@ -119,40 +94,7 @@ def autoregressive_inference(
nics=50,
device=torch.device("cpu"),
):
"""
Perform autoregressive inference with a trained model and compare to classical solver.
Parameters
-----------
model : torch.nn.Module
Trained model to evaluate
dataset : torch.utils.data.Dataset
Dataset containing solver and normalization parameters
loss_fn : callable
Loss function for evaluation
metrics_fns : dict
Dictionary of metric functions to compute
path_root : str
Root path for saving inference outputs
nsteps : int
Number of solver steps per autoregressive step
autoreg_steps : int, optional
Number of autoregressive steps, by default 10
nskip : int, optional
Skip interval for plotting, by default 1
plot_channel : int, optional
Channel to plot, by default 0
nics : int, optional
Number of initial conditions to test, by default 50
device : torch.device, optional
Device to run inference on, by default torch.device("cpu")
Returns
-------
tuple
(losses, metrics, model_times, solver_times) where losses and metrics are tensors
of per-sample values, and model_times and solver_times are timing information
"""
model.eval()
# make output
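
The docstring removed above describes rolling the model out autoregressively (autoreg_steps model calls, each standing in for nsteps classical-solver steps); a schematic rollout loop with a placeholder model and initial condition, not the script's actual inference code:

import torch

model = torch.nn.Identity()        # stand-in for the trained model
prd = torch.randn(1, 3, 64, 128)   # stand-in initial condition
autoreg_steps = 10

trajectory = []
with torch.no_grad():
    for _ in range(autoreg_steps):
        prd = model(prd)           # each call advances the predicted state
        trajectory.append(prd)     # keep the rollout for losses, metrics, plots
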
@@ -293,42 +235,7 @@ def train_model(
logging=True,
device=torch.device("cpu"),
):
"""
Train a model with the given parameters.
Parameters
-----------
model : torch.nn.Module
Model to train
dataloader : torch.utils.data.DataLoader
DataLoader for training data
loss_fn : callable
Loss function
metrics_fns : dict
Dictionary of metric functions to compute
optimizer : torch.optim.Optimizer
Optimizer for training
gscaler : torch.cuda.amp.GradScaler
Gradient scaler for mixed precision training
scheduler : torch.optim.lr_scheduler._LRScheduler, optional
Learning rate scheduler, by default None
nepochs : int, optional
Number of training epochs, by default 20
nfuture : int, optional
Number of future steps to predict, by default 0
num_examples : int, optional
Number of examples per epoch, by default 256
num_valid : int, optional
Number of validation examples, by default 8
amp_mode : str, optional
Mixed precision mode ("none", "fp16", "bf16"), by default "none"
log_grads : int, optional
Frequency of gradient logging (0 for no logging), by default 0
logging : bool, optional
Whether to enable logging, by default True
device : torch.device, optional
Device to train on, by default torch.device("cpu")
"""
train_start = time.time()
# set AMP type
......
@@ -219,19 +219,7 @@ class ResampleS2(nn.Module):
return x
def forward(self, x: torch.Tensor):
"""
Forward pass of the resampling module.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (..., nlat_in, nlon_in)
Returns
-------
torch.Tensor
Resampled tensor with shape (..., nlat_out, nlon_out)
"""
if self.skip_resampling:
return x
......
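
The docstring removed above pins down the shape contract (..., nlat_in, nlon_in) -> (..., nlat_out, nlon_out); purely as an illustration of that contract, a planar bilinear resize of the last two dimensions (not the actual spherical resampling) would be:

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 64, 128)               # (..., nlat_in, nlon_in)
y = F.interpolate(x, size=(128, 256), mode="bilinear", align_corners=False)
print(y.shape)                               # torch.Size([2, 3, 128, 256])
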