Commit 901e8635 authored by Andrea Paris's avatar Andrea Paris Committed by Boris Bonev
Browse files

removed docstrings from forward passes

parent 6373534a
......@@ -103,19 +103,6 @@ class OverlapPatchMerging(nn.Module):
nn.init.constant_(m.weight, 1.0)
def forward(self, x):
"""
Forward pass through the OverlapPatchMerging layer.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Output tensor after patch merging
"""
x = self.conv(x)
# permute
......@@ -204,19 +191,7 @@ class MixFFN(nn.Module):
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the MixFFN module.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Output tensor after processing
"""
residual = x
# norm
......
......@@ -693,19 +693,7 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
return psi
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass
Parameters
-----------
x: torch.Tensor
Input tensor
Returns
-------
out: torch.Tensor
Output tensor
"""
# extract shape
B, C, H, W = x.shape
x = x.reshape(B, self.groups, self.groupsize, H, W)
......
......@@ -248,19 +248,6 @@ class DistributedResampleS2(nn.Module):
return x
def forward(self, x: torch.Tensor):
"""
Forward pass for distributed resampling.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (batch, channels, nlat, nlon)
Returns
-------
torch.Tensor
Resampled tensor with shape (batch, channels, nlat_out, nlon_out)
"""
if self.skip_resampling:
return x
......
......@@ -150,23 +150,7 @@ class distributed_transpose_azimuth(torch.autograd.Function):
@staticmethod
@custom_fwd(device_type="cuda")
def forward(ctx, x, dims, dim1_split_sizes):
r"""
Forward pass for distributed azimuthal transpose.
Parameters
----------
x: torch.Tensor
The tensor to transpose
dims: List[int]
The dimensions to transpose
dim1_split_sizes: List[int]
The split sizes for the second dimension
Returns
-------
x: torch.Tensor
The transposed tensor
"""
# WAR for a potential contig check torch bug for channels last contig tensors
xlist, dim0_split_sizes, _ = _transpose(x, dims[0], dims[1], dim1_split_sizes, group=azimuth_group())
x = torch.cat(xlist, dim=dims[1])
......@@ -205,23 +189,7 @@ class distributed_transpose_polar(torch.autograd.Function):
@staticmethod
@custom_fwd(device_type="cuda")
def forward(ctx, x, dim, dim1_split_sizes):
r"""
Forward pass for distributed polar transpose.
Parameters
----------
x: torch.Tensor
The tensor to transpose
dim: List[int]
The dimensions to transpose
dim1_split_sizes: List[int]
The split sizes for the second dimension
Returns
-------
x: torch.Tensor
The transposed tensor
"""
# WAR for a potential contig check torch bug for channels last contig tensors
xlist, dim0_split_sizes, _ = _transpose(x, dim[0], dim[1], dim1_split_sizes, group=polar_group())
x = torch.cat(xlist, dim=dim[1])
......@@ -363,19 +331,7 @@ class _CopyToPolarRegion(torch.autograd.Function):
@staticmethod
@custom_fwd(device_type="cuda")
def forward(ctx, input_):
    r"""Identity forward: return the input tensor unchanged.

    This ``autograd.Function`` exists for its backward behavior
    (not shown here); the forward pass records nothing on ``ctx``
    and simply passes the tensor through.

    Parameters
    ----------
    input_ : torch.Tensor
        Tensor to pass through to the polar region.

    Returns
    -------
    torch.Tensor
        The same tensor object that was passed in.
    """
    # no state saved on ctx; forward is a pure passthrough
    return input_
@staticmethod
......@@ -409,19 +365,7 @@ class _CopyToAzimuthRegion(torch.autograd.Function):
@staticmethod
@custom_fwd(device_type="cuda")
def forward(ctx, input_):
    r"""Identity forward: return the input tensor unchanged.

    The copy-to-azimuth-region semantics live in the backward pass
    (not shown here); forward stores nothing on ``ctx`` and hands the
    tensor straight back.

    Parameters
    ----------
    input_ : torch.Tensor
        Tensor to pass through to the azimuth region.

    Returns
    -------
    torch.Tensor
        The same tensor object that was passed in.
    """
    # pure passthrough; nothing recorded for backward here
    return input_
@staticmethod
......
......@@ -115,21 +115,7 @@ class DiceLossS2(nn.Module):
self.register_buffer("weight", weight.unsqueeze(0))
def forward(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the Dice loss.
Parameters
-----------
prd : torch.Tensor
Prediction tensor with shape (batch, classes, nlat, nlon)
tar : torch.Tensor
Target tensor with shape (batch, nlat, nlon)
Returns
-------
torch.Tensor
Dice loss value
"""
prd = nn.functional.softmax(prd, dim=1)
# mask values
......@@ -205,21 +191,6 @@ class CrossEntropyLossS2(nn.Module):
self.register_buffer("quad_weights", q)
def forward(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the cross-entropy loss.
Parameters
-----------
prd : torch.Tensor
Prediction tensor with shape (batch, classes, nlat, nlon)
tar : torch.Tensor
Target tensor with shape (batch, nlat, nlon)
Returns
-------
torch.Tensor
Cross-entropy loss value
"""
# compute log softmax
logits = nn.functional.log_softmax(prd, dim=1)
......@@ -266,25 +237,6 @@ class FocalLossS2(nn.Module):
self.register_buffer("quad_weights", q)
def forward(self, prd: torch.Tensor, tar: torch.Tensor, alpha: float = 0.25, gamma: float = 2):
"""
Forward pass of the focal loss.
Parameters
-----------
prd : torch.Tensor
Prediction tensor with shape (batch, classes, nlat, nlon)
tar : torch.Tensor
Target tensor with shape (batch, nlat, nlon)
alpha : float, optional
Alpha parameter for focal loss, by default 0.25
gamma : float, optional
Gamma parameter for focal loss, by default 2
Returns
-------
torch.Tensor
Focal loss value
"""
# compute logits
logits = nn.functional.log_softmax(prd, dim=1)
......@@ -339,16 +291,7 @@ class SphericalLossBase(nn.Module, ABC):
return loss
def forward(self, prd: torch.Tensor, tar: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Common forward pass that handles masking and reduction.
Args:
prd (torch.Tensor): Prediction tensor
tar (torch.Tensor): Target tensor
mask (Optional[torch.Tensor], optional): Mask tensor. Defaults to None.
Returns:
torch.Tensor: Final loss value
"""
loss_term = self._compute_loss_term(prd, tar)
# Integrate over the sphere for each item in the batch
loss = self._integrate_sphere(loss_term, mask)
......
......@@ -169,21 +169,7 @@ class BaseMetricS2(nn.Module):
self.register_buffer("weight", weight.unsqueeze(0))
def _forward(self, pred: torch.Tensor, truth: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute base statistics (TP, FP, FN, TN) for the given predictions and ground truth.
Parameters
-----------
pred : torch.Tensor
Predicted logits
truth : torch.Tensor
Ground truth labels
Returns
-------
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
Tuple containing (tp, fp, fn, tn) statistics
"""
# convert logits to class predictions
pred_class = _predict_classes(pred)
......@@ -240,21 +226,7 @@ class IntersectionOverUnionS2(BaseMetricS2):
super().__init__(nlat, nlon, grid, weight, ignore_index, mode)
def forward(self, pred: torch.Tensor, truth: torch.Tensor) -> torch.Tensor:
"""
Compute IoU score for the given predictions and ground truth.
Parameters
-----------
pred : torch.Tensor
Predicted logits
truth : torch.Tensor
Ground truth labels
Returns
-------
torch.Tensor
IoU score
"""
tp, fp, fn, tn = self._forward(pred, truth)
# compute score
......@@ -300,21 +272,7 @@ class AccuracyS2(BaseMetricS2):
super().__init__(nlat, nlon, grid, weight, ignore_index, mode)
def forward(self, pred: torch.Tensor, truth: torch.Tensor) -> torch.Tensor:
"""
Compute accuracy score for the given predictions and ground truth.
Parameters
-----------
pred : torch.Tensor
Predicted logits
truth : torch.Tensor
Ground truth labels
Returns
-------
torch.Tensor
Accuracy score
"""
tp, fp, fn, tn = self._forward(pred, truth)
# compute score
......
......@@ -191,19 +191,7 @@ class DropPath(nn.Module):
self.drop_prob = drop_prob
def forward(self, x):
    """Apply the drop-path (stochastic depth) helper to the input.

    Delegates to the module-level ``drop_path`` function with this
    module's configured drop probability and the current training flag;
    presumably a no-op in eval mode — confirm against ``drop_path``.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.

    Returns
    -------
    torch.Tensor
        Result of ``drop_path(x, self.drop_prob, self.training)``.
    """
    # thin wrapper: all logic lives in the functional drop_path helper
    out = drop_path(x, self.drop_prob, self.training)
    return out
......@@ -238,19 +226,7 @@ class PatchEmbed(nn.Module):
self.proj.bias.is_shared_mp = ["spatial"]
def forward(self, x):
"""
Forward pass of patch embedding.
Parameters
----------
x : torch.Tensor
Input tensor of shape (batch_size, channels, height, width)
Returns
-------
torch.Tensor
Patch embeddings of shape (batch_size, embed_dim, num_patches)
"""
# gather input
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
......@@ -319,35 +295,11 @@ class MLP(nn.Module):
@torch.jit.ignore
def checkpoint_forward(self, x):
    """Run the wrapped forward ``self.fwd`` under gradient checkpointing.

    Intermediate activations are recomputed during backward instead of
    being stored, trading extra compute for reduced memory.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.

    Returns
    -------
    torch.Tensor
        The result of ``self.fwd(x)``.
    """
    # checkpoint re-executes self.fwd in backward to save activation memory
    result = checkpoint(self.fwd, x)
    return result
def forward(self, x):
"""
Forward pass of the MLP.
Parameters
----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Output tensor
"""
if self.checkpointing:
return self.checkpoint_forward(x)
else:
......@@ -382,19 +334,7 @@ class RealFFT2(nn.Module):
self.mmax = mmax or self.nlon // 2 + 1
def forward(self, x):
    """Compute a truncated 2D real FFT of the input.

    Applies an orthonormal ``rfft2`` over the last two dimensions, then
    keeps only ``self.lmax`` frequency rows — split between the leading
    (positive) and trailing (negative) frequencies — and the first
    ``self.mmax`` columns.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor; the FFT is taken over its last two dimensions.

    Returns
    -------
    torch.Tensor
        Complex tensor of truncated FFT coefficients with shape
        ``(..., lmax, mmax)``.
    """
    coeffs = torch.fft.rfft2(x, dim=(-2, -1), norm="ortho")
    # positive-frequency rows come from the front, negative from the back
    pos = coeffs[..., : math.ceil(self.lmax / 2), : self.mmax]
    neg = coeffs[..., -math.floor(self.lmax / 2) :, : self.mmax]
    return torch.cat((pos, neg), dim=-2)
......@@ -428,19 +368,7 @@ class InverseRealFFT2(nn.Module):
self.mmax = mmax or self.nlon // 2 + 1
def forward(self, x):
    """Invert a truncated real FFT back to the spatial grid.

    Uses an orthonormal ``irfft2`` over the last two dimensions,
    explicitly requesting an output of size ``(self.nlat, self.nlon)``
    so truncated coefficient tensors are zero-padded as needed.

    Parameters
    ----------
    x : torch.Tensor
        Complex FFT coefficients.

    Returns
    -------
    torch.Tensor
        Real-valued spatial signal of shape ``(..., nlat, nlon)``.
    """
    # s=(nlat, nlon) fixes the output grid regardless of input truncation
    spatial = torch.fft.irfft2(x, dim=(-2, -1), s=(self.nlat, self.nlon), norm="ortho")
    return spatial
......@@ -476,19 +404,7 @@ class LayerNorm(nn.Module):
self.norm = nn.LayerNorm(normalized_shape=in_channels, eps=1e-6, elementwise_affine=elementwise_affine, bias=bias, device=device, dtype=dtype)
def forward(self, x):
    """Normalize over the channel dimension via a transpose round-trip.

    ``nn.LayerNorm`` normalizes the last dimension, so the channel axis
    (``self.channel_dim``) is swapped to the end, normalized, and
    swapped back.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor with channels at ``self.channel_dim``.

    Returns
    -------
    torch.Tensor
        Layer-normalized tensor with the same shape as the input.
    """
    # move channels last, normalize, then restore the original layout
    moved = x.transpose(self.channel_dim, -1)
    normed = self.norm(moved)
    return normed.transpose(-1, self.channel_dim)
......@@ -556,19 +472,7 @@ class SpectralConvS2(nn.Module):
self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
def forward(self, x):
"""
Forward pass of spectral convolution.
Parameters
----------
x : torch.Tensor
Input tensor
Returns
-------
tuple
Tuple containing (output, residual) tensors
"""
dtype = x.dtype
x = x.float()
residual = x
......@@ -614,19 +518,7 @@ class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta):
self.num_chans = num_chans
def forward(self, x: torch.Tensor):
    """Add the stored position embeddings to the input.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor; must broadcast against ``self.position_embeddings``.

    Returns
    -------
    torch.Tensor
        ``x`` with the position embeddings added element-wise.
    """
    # broadcasting add; embedding shape must be compatible with x
    shifted = x + self.position_embeddings
    return shifted
......
......@@ -110,19 +110,7 @@ class DiscreteContinuousEncoder(nn.Module):
)
def forward(self, x):
"""
Forward pass of the discrete-continuous encoder.
Parameters
----------
x : torch.Tensor
Input tensor with shape (batch, channels, nlat, nlon)
Returns
-------
torch.Tensor
Encoded tensor with reduced spatial resolution
"""
dtype = x.dtype
with amp.autocast(device_type="cuda", enabled=False):
......@@ -205,19 +193,7 @@ class DiscreteContinuousDecoder(nn.Module):
)
def forward(self, x):
"""
Forward pass of the discrete-continuous decoder.
Parameters
----------
x : torch.Tensor
Input tensor with shape (batch, channels, nlat, nlon)
Returns
-------
torch.Tensor
Decoded tensor with restored spatial resolution
"""
dtype = x.dtype
with amp.autocast(device_type="cuda", enabled=False):
......@@ -628,19 +604,7 @@ class LocalSphericalNeuralOperator(nn.Module):
return x
def forward(self, x):
"""
Forward pass through the complete LSNO model.
Parameters
----------
x : torch.Tensor
Input tensor of shape (batch_size, in_chans, height, width)
Returns
-------
torch.Tensor
Output tensor of shape (batch_size, out_chans, height, width)
"""
if self.residual_prediction:
residual = x
......
......@@ -130,19 +130,7 @@ class OverlapPatchMerging(nn.Module):
nn.init.constant_(m.weight, 1.0)
def forward(self, x):
"""
Forward pass of the overlap patch merging module.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Merged patches with layer normalization
"""
dtype = x.dtype
with amp.autocast(device_type="cuda", enabled=False):
......@@ -259,19 +247,6 @@ class MixFFN(nn.Module):
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the Mix FFN module.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Output tensor after Mix FFN processing
"""
residual = x
# norm
......@@ -382,19 +357,7 @@ class AttentionWrapper(nn.Module):
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the attention wrapper.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Output tensor after attention processing
"""
residual = x
if self.norm is not None:
x = x.permute(0, 2, 3, 1)
......@@ -548,19 +511,6 @@ class TransformerBlock(nn.Module):
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the transformer block.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Output tensor after transformer block processing
"""
x = self.fwd(x)
# apply norm
......@@ -664,19 +614,7 @@ class Upsampling(nn.Module):
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the upsampling module.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Upsampled tensor
"""
x = self.upsample(self.mlp(x))
return x
......@@ -871,19 +809,7 @@ class SphericalSegformer(nn.Module):
nn.init.constant_(m.weight, 1.0)
def forward(self, x):
"""
Forward pass through the complete spherical segformer model.
Parameters
-----------
x : torch.Tensor
Input tensor of shape (batch_size, in_chans, height, width)
Returns
-------
torch.Tensor
Output tensor of shape (batch_size, out_chans, height, width)
"""
# encoder:
features = []
feat = x
......
......@@ -113,19 +113,7 @@ class DiscreteContinuousEncoder(nn.Module):
)
def forward(self, x):
"""
Forward pass of the discrete-continuous encoder.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (batch, channels, nlat, nlon)
Returns
-------
torch.Tensor
Encoded tensor with reduced spatial resolution
"""
dtype = x.dtype
with amp.autocast(device_type="cuda", enabled=False):
......@@ -209,19 +197,7 @@ class DiscreteContinuousDecoder(nn.Module):
)
def forward(self, x):
"""
Forward pass of the discrete-continuous decoder.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (batch, channels, nlat, nlon)
Returns
-------
torch.Tensor
Decoded tensor with restored spatial resolution
"""
dtype = x.dtype
with amp.autocast(device_type="cuda", enabled=False):
......@@ -593,19 +569,7 @@ class SphericalTransformer(nn.Module):
return x
def forward(self, x):
"""
Forward pass through the complete spherical transformer model.
Parameters
-----------
x : torch.Tensor
Input tensor of shape (batch_size, in_chans, height, width)
Returns
-------
torch.Tensor
Output tensor of shape (batch_size, out_chans, height, width)
"""
if self.residual_prediction:
residual = x
......
......@@ -208,19 +208,7 @@ class DownsamplingBlock(nn.Module):
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the downsampling block.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Downsampled tensor
"""
# skip connection
residual = x
if hasattr(self, "transform_skip"):
......@@ -614,19 +602,7 @@ class SphericalUNet(nn.Module):
nn.init.constant_(m.weight, 1.0)
def forward(self, x):
"""
Forward pass through the complete spherical U-Net model.
Parameters
-----------
x : torch.Tensor
Input tensor of shape (batch_size, in_chans, height, width)
Returns
-------
torch.Tensor
Output tensor of shape (batch_size, out_chans, height, width)
"""
# encoder:
features = []
feat = x
......
......@@ -153,19 +153,7 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
raise ValueError(f"Unknown skip connection type {outer_skip}")
def forward(self, x):
"""
Forward pass through the SFNO block.
Parameters
----------
x : torch.Tensor
Input tensor
Returns
----------
torch.Tensor
Output tensor after processing through the block
"""
x, residual = self.global_conv(x)
x = self.norm(x)
......@@ -415,19 +403,7 @@ class SphericalFourierNeuralOperator(nn.Module):
return {"pos_embed.pos_embed"}
def forward_features(self, x):
"""
Forward pass through the feature extraction layers.
Parameters
----------
x : torch.Tensor
Input tensor
Returns
----------
torch.Tensor
Features after processing through the network
"""
x = self.pos_drop(x)
for blk in self.blocks:
......@@ -436,19 +412,7 @@ class SphericalFourierNeuralOperator(nn.Module):
return x
def forward(self, x):
"""
Forward pass through the complete SFNO model.
Parameters
----------
x : torch.Tensor
Input tensor of shape (batch_size, in_chans, height, width)
Returns
----------
torch.Tensor
Output tensor of shape (batch_size, out_chans, height, width)
"""
if self.residual_prediction:
residual = x
......
......@@ -94,24 +94,7 @@ class GaussianRandomFieldS2(torch.nn.Module):
self.gaussian_noise = torch.distributions.normal.Normal(self.mean, self.var)
def forward(self, N, xi=None):
r"""
Sample random functions from a spherical GRF.
Parameters
----------
N : int
Number of functions to sample.
xi : torch.Tensor, default is None
Noise is a complex tensor of size (N, nlat, nlat+1).
If None, new Gaussian noise is sampled.
If xi is provided, N is ignored.
Output
-------
u : torch.Tensor
N random samples from the GRF returned as a
tensor of size (N, nlat, 2*nlat) on a equiangular grid.
"""
#Sample Gaussian noise.
if xi is None:
xi = self.gaussian_noise.sample(torch.Size((N, self.nlat, self.nlat + 1, 2))).squeeze()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment