Commit ca46b9d2 authored by Andrea Paris, committed by Boris Bonev

cleanup of unnecessary docstrings

parent 9c26a6d8
@@ -153,9 +153,6 @@ class DistributedResampleS2(nn.Module):
self.skip_resampling = (nlon_in == nlon_out) and (nlat_in == nlat_out) and (grid_in == grid_out)
def extra_repr(self):
r"""
Pretty print module
"""
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}"
def _upscale_longitudes(self, x: torch.Tensor):
......
@@ -77,27 +77,7 @@ class DistributedRealSHT(nn.Module):
"""
def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
"""
Distributed SHT layer. Expects the last 3 dimensions of the input tensor to be channels, latitude, longitude.
Parameters
----------
nlat: int
Number of latitude points
nlon: int
Number of longitude points
lmax: int
Maximum spherical harmonic degree
mmax: int
Maximum spherical harmonic order
grid: str
Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
norm: str
Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
csphase: bool
Whether to apply the Condon-Shortley phase factor, by default True
"""
super().__init__()
self.nlat = nlat
@@ -369,22 +349,6 @@ class DistributedRealVectorSHT(nn.Module):
"""
def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
"""
Initializes the vector SHT Layer, precomputing the necessary quadrature weights
Parameters
----------
nlat: int
Number of latitude points
nlon: int
Number of longitude points
grid: str
Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
norm: str
Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
csphase: bool
Whether to apply the Condon-Shortley phase factor, by default True
"""
super().__init__()
......
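The distributed constructors above share their signature with the non-distributed RealSHT. Below is a minimal usage sketch; it assumes torch.distributed and the library's polar/azimuthal process groups have already been set up (the setup call is not part of this diff), and the import path and tensor shapes are placeholders rather than confirmed by the source.

import torch
# Import path assumed from the class name; this diff only shows the class definition.
from torch_harmonics.distributed import DistributedRealSHT

# Hypothetical global grid; under model parallelism each rank holds only its local shard.
nlat, nlon = 361, 720
dsht = DistributedRealSHT(nlat, nlon, grid="equiangular", norm="ortho", csphase=True)

# The last three dimensions are expected to be (channels, latitude, longitude).
x_local = torch.randn(1, 3, nlat, nlon)   # local shard shape is a placeholder
y_local = dsht(x_local)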
@@ -102,25 +102,6 @@ def split_tensor_along_dim(tensor, dim, num_chunks):
def _transpose(tensor, dim0, dim1, dim1_split_sizes, group=None, async_op=False):
"""
Transpose a tensor along two dimensions.
Parameters
----------
tensor: torch.Tensor
The tensor to transpose
dim0: int
The first dimension to transpose
dim1: int
The second dimension to transpose
dim1_split_sizes: List[int]
The split sizes for the second dimension
Returns
-------
tensor_list: List[torch.Tensor]
The split tensors
"""
# get comm params
comm_size = dist.get_world_size(group=group)
@@ -198,7 +179,6 @@ class distributed_transpose_polar(torch.autograd.Function):
# we need those additional primitives for distributed matrix multiplications
def _reduce(input_, use_fp32=True, group=None):
"""All-reduce the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
if dist.get_world_size(group=group) == 1:
@@ -219,7 +199,6 @@ def _reduce(input_, use_fp32=True, group=None):
def _split(input_, dim_, group=None):
"""Split the tensor along its last dimension and keep the corresponding slice."""
# Bypass the function if we are using only 1 GPU.
comm_size = dist.get_world_size(group=group)
if comm_size == 1:
@@ -236,7 +215,6 @@ def _split(input_, dim_, group=None):
def _gather(input_, dim_, shapes_, group=None):
"""Gather unevenly split tensors across ranks"""
comm_size = dist.get_world_size(group=group)
@@ -269,7 +247,6 @@ def _gather(input_, dim_, shapes_, group=None):
def _reduce_scatter(input_, dim_, use_fp32=True, group=None):
"""All-reduce the input tensor across model parallel group and scatter it back."""
# Bypass the function if we are using only 1 GPU.
if dist.get_world_size(group=group) == 1:
......
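The primitives in this file share the single-GPU bypass shown above and an optional fp32 upcast around the collective. A standalone sketch of that pattern follows; the function name is invented and only the structure mirrors _reduce, not its exact code.

import torch
import torch.distributed as dist

def all_reduce_sketch(input_, use_fp32=True, group=None):
    # Bypass the collective entirely if there is only one rank in the group.
    if dist.get_world_size(group=group) == 1:
        return input_

    if use_fp32 and input_.dtype not in (torch.float32, torch.float64):
        # Upcast so the sum is accumulated in full precision, then cast back.
        buf = input_.float()
        dist.all_reduce(buf, group=group)
        return buf.to(dtype=input_.dtype)

    dist.all_reduce(input_, group=group)
    return input_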
@@ -275,19 +275,10 @@ class SphericalLossBase(nn.Module, ABC):
@abstractmethod
def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
"""Abstract method that must be implemented by child classes to compute loss terms.
Args:
prd (torch.Tensor): Prediction tensor
tar (torch.Tensor): Target tensor
Returns:
torch.Tensor: Computed loss term before integration
"""
pass
def _post_integration_hook(self, loss: torch.Tensor) -> torch.Tensor:
"""Post-integration hook. Commonly used for the roots in Lp norms"""
return loss
def forward(self, prd: torch.Tensor, tar: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
@@ -309,21 +300,7 @@ class SquaredL2LossS2(SphericalLossBase):
"""
def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
"""
Compute squared L2 loss term.
Parameters
-----------
prd : torch.Tensor
Prediction tensor
tar : torch.Tensor
Target tensor
Returns
-------
torch.Tensor
Squared difference between prediction and target
"""
return torch.square(prd - tar)
@@ -335,21 +312,7 @@ class L1LossS2(SphericalLossBase):
"""
def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
"""
Compute L1 loss term.
Parameters
-----------
prd : torch.Tensor
Prediction tensor
tar : torch.Tensor
Target tensor
Returns
-------
torch.Tensor
Absolute difference between prediction and target
"""
return torch.abs(prd - tar)
@@ -361,19 +324,7 @@ class L2LossS2(SquaredL2LossS2):
"""
def _post_integration_hook(self, loss: torch.Tensor) -> torch.Tensor:
"""
Apply square root to get L2 norm.
Parameters
-----------
loss : torch.Tensor
Integrated squared loss
Returns
-------
torch.Tensor
Square root of the loss (L2 norm)
"""
return torch.sqrt(loss)
@@ -385,18 +336,7 @@ class W11LossS2(SphericalLossBase):
"""
def __init__(self, nlat: int, nlon: int, grid: str = "equiangular"):
"""
Initialize W11 loss.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
grid : str, optional
Grid type, by default "equiangular"
"""
super().__init__(nlat=nlat, nlon=nlon, grid=grid)
# Set up grid and domain for FFT
l_phi = 2 * torch.pi # domain size
@@ -512,21 +452,7 @@ class NormalLossS2(SphericalLossBase):
return normals
def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
"""
Compute combined L1 and normal consistency loss.
Parameters
-----------
prd : torch.Tensor
Prediction tensor
tar : torch.Tensor
Target tensor
Returns
-------
torch.Tensor
Combined loss term
"""
# Handle dimensions for both prediction and target
# Ensure we have at least a batch dimension
if prd.dim() == 2:
......
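The loss hierarchy above is a template method: subclasses supply the pointwise term via _compute_loss_term, the base class integrates it over the sphere, and _post_integration_hook can take roots, as L2LossS2 does. A hedged sketch of a custom subclass is shown below; the class name and pointwise term are invented for illustration, and the constructor arguments follow the W11LossS2 call above.

import torch
# The base class is the one shown in this diff; adjust the import to the actual module path:
# from <losses module> import SphericalLossBase

class CubedL3LossS2(SphericalLossBase):
    def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
        # Pointwise term, evaluated before quadrature over the sphere.
        return torch.abs(prd - tar) ** 3

    def _post_integration_hook(self, loss: torch.Tensor) -> torch.Tensor:
        # Take the cube root after integration, mirroring the sqrt in L2LossS2.
        return loss ** (1.0 / 3.0)

# Usage sketch; grid sizes are placeholders:
# loss_fn = CubedL3LossS2(nlat=64, nlon=128, grid="equiangular")
# value = loss_fn(prd, tar)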
@@ -113,7 +113,7 @@ def _get_stats_multiclass(
def _predict_classes(logits: torch.Tensor) -> torch.Tensor:
"""
"""
Convert logits to class predictions using softmax and argmax.
Parameters
......
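The summary line above describes a small helper; an equivalent, self-contained version would look like the sketch below. The class-dimension index is an assumption, not taken from the source.

import torch

def predict_classes_sketch(logits: torch.Tensor) -> torch.Tensor:
    # Softmax is monotone, so the argmax of the probabilities equals the argmax of the
    # logits; it is kept here to mirror the docstring. dim=1 assumes (batch, classes, ...).
    probs = torch.softmax(logits, dim=1)
    return probs.argmax(dim=1)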
@@ -41,46 +41,11 @@ from torch_harmonics import InverseRealSHT
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
"""
Initialize tensor with truncated normal distribution without gradients.
This is a helper function for trunc_normal_ that performs the actual initialization
without requiring gradients to be tracked.
Parameters
-----------
tensor : torch.Tensor
Tensor to initialize
mean : float
Mean of the normal distribution
std : float
Standard deviation of the normal distribution
a : float
Lower bound for truncation
b : float
Upper bound for truncation
Returns
-------
torch.Tensor
Initialized tensor
"""
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
"""
Compute standard normal cumulative distribution function.
Parameters
-----------
x : float
Input value
Returns
-------
float
CDF value
"""
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
......
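The removed docstring described the standard inverse-CDF construction that the remaining PyTorch comments reference. A self-contained sketch of that method, not necessarily byte-for-byte what the repository keeps, is:

import math
import torch

def trunc_normal_sketch(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    # Fill `tensor` in place with N(mean, std) samples truncated to [a, b].
    def norm_cdf(x):
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    with torch.no_grad():
        # Map the bounds into CDF space, sample uniformly there, and invert via erfinv.
        lo = norm_cdf((a - mean) / std)
        hi = norm_cdf((b - mean) / std)
        tensor.uniform_(2 * lo - 1, 2 * hi - 1)
        tensor.erfinv_()
        tensor.mul_(std * math.sqrt(2.0)).add_(mean)
        tensor.clamp_(min=a, max=b)
    return tensor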
@@ -80,34 +80,13 @@ class Stanford2D3DSDownloader:
"""
def __init__(self, base_url: str = DEFAULT_BASE_URL, local_dir: str = "data"):
"""
Initialize the Stanford 2D3DS dataset downloader.
Parameters
-----------
base_url : str, optional
Base URL for downloading the dataset, by default DEFAULT_BASE_URL
local_dir : str, optional
Local directory to store downloaded files, by default "data"
"""
self.base_url = base_url
self.local_dir = local_dir
os.makedirs(self.local_dir, exist_ok=True)
def _download_file(self, filename):
"""
Download a single file with progress bar and resume capability.
Parameters
-----------
filename : str
Name of the file to download
Returns
-------
str
Local path to the downloaded file
"""
import requests
from tqdm import tqdm
@@ -141,19 +120,7 @@ class Stanford2D3DSDownloader:
return local_path
def _extract_tar(self, tar_path):
"""
Extract a tar file and return the extracted directory name.
Parameters
-----------
tar_path : str
Path to the tar file to extract
Returns
-------
str
Name of the extracted directory
"""
import tarfile
with tarfile.open(tar_path) as tar:
@@ -194,23 +161,7 @@ class Stanford2D3DSDownloader:
return data_folders, class_labels
def _rgb_to_id(self, img, class_labels_map, class_labels_indices):
"""
Convert RGB image to class ID using color mapping.
Parameters
-----------
img : numpy.ndarray
RGB image array
class_labels_map : list
Mapping from color values to class labels
class_labels_indices : list
List of class label indices
Returns
-------
numpy.ndarray
Class ID array
"""
# Convert to int32 first to avoid overflow
r = img[..., 0].astype(np.int32)
g = img[..., 1].astype(np.int32)
......
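A common way to implement the color-to-ID lookup described above is to pack the three upcast channels into a single integer key per pixel and index a label table with it. Whether the repository uses exactly this packing is an assumption; the int32 upcast shown above is what makes it safe.

import numpy as np

def rgb_to_key_sketch(img: np.ndarray) -> np.ndarray:
    # Upcast first so r * 256 * 256 cannot overflow 8- or 16-bit integer arithmetic.
    r = img[..., 0].astype(np.int32)
    g = img[..., 1].astype(np.int32)
    b = img[..., 2].astype(np.int32)
    # One unique integer key per RGB triple, usable as an index into a lookup table.
    return r * 256 * 256 + g * 256 + b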
@@ -35,35 +35,6 @@ from .sht import InverseRealSHT
class GaussianRandomFieldS2(torch.nn.Module):
def __init__(self, nlat, alpha=2.0, tau=3.0, sigma=None, radius=1.0, grid="equiangular", dtype=torch.float32):
super().__init__()
r"""
A mean-zero Gaussian Random Field on the sphere with Matern covariance:
C = sigma^2 (-Lap + tau^2 I)^(-alpha).
Lap is the Laplacian on the sphere, I the identity operator,
and sigma, tau, alpha are scalar parameters.
Note: C is trace-class on L^2 if and only if alpha > 1.
Parameters
----------
nlat : int
Number of latitudinal modes;
longitudinal modes are 2*nlat.
alpha : float, default is 2
Regularity parameter. Larger means smoother.
tau : float, default is 3
Length-scale parameter. Larger means more scales.
sigma : float, default is None
Scale parameter. Larger means larger amplitude.
If None, sigma = tau**(0.5*(2*alpha - 2.0)).
radius : float, default is 1
Radius of the sphere.
grid : string, default is "equiangular"
Grid type. Currently supports "equiangular" and
"legendre-gauss".
dtype : torch.dtype, default is torch.float32
Numerical type for the calculations.
"""
# Number of latitudinal modes.
self.nlat = nlat
......
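Since the eigenvalues of -Lap on a sphere of radius r are l(l+1)/r^2, the covariance above amounts to a per-degree scaling of independent Gaussian spherical-harmonic coefficients. A hedged sketch of that spectral construction follows; it is not necessarily the class's exact code path, and the equiangular defaults and truncation choices are assumptions.

import torch
from torch_harmonics import InverseRealSHT

def sample_grf_sketch(nlat, alpha=2.0, tau=3.0, sigma=None, radius=1.0):
    # Draw one sample of a mean-zero GRF with covariance sigma^2 (-Lap + tau^2 I)^(-alpha).
    nlon = 2 * nlat
    if sigma is None:
        sigma = tau ** (0.5 * (2 * alpha - 2.0))

    lmax, mmax = nlat, nlat // 2 + 1
    l = torch.arange(lmax).reshape(-1, 1)
    # Square root of the covariance spectrum, evaluated at the Laplacian eigenvalues.
    sqrt_spec = sigma * (l * (l + 1) / radius**2 + tau**2) ** (-alpha / 2.0)

    # Independent complex-Gaussian coefficients, scaled degree by degree.
    coeffs = sqrt_spec * torch.randn(lmax, mmax, dtype=torch.complex64)
    coeffs[0, 0] = 0.0  # enforce a mean-zero field

    isht = InverseRealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid="equiangular")
    return isht(coeffs)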
@@ -137,25 +137,11 @@ class ResampleS2(nn.Module):
def extra_repr(self):
r"""
Pretty print module
"""
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}"
def _upscale_longitudes(self, x: torch.Tensor):
"""
Interpolate the input tensor along the longitude dimension.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (..., nlat, nlon)
Returns
-------
torch.Tensor
Interpolated tensor along longitude dimension
"""
# do the interpolation in precision of x
lwgt = self.lon_weights.to(x.dtype)
if self.mode == "bilinear":
@@ -192,19 +178,7 @@ class ResampleS2(nn.Module):
return x
def _upscale_latitudes(self, x: torch.Tensor):
"""
Interpolate the input tensor along the latitude dimension.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (..., nlat, nlon)
Returns
-------
torch.Tensor
Interpolated tensor along latitude dimension
"""
# do the interpolation in precision of x
lwgt = self.lat_weights.to(x.dtype)
if self.mode == "bilinear":
......
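Both helpers apply precomputed 1-D weights along one dimension at a time, in the precision of the input, as the remaining comments note. The general mechanism can be sketched as periodic linear interpolation along longitude; the index and weight construction here is generic, not the module's exact precomputation.

import torch

def upscale_longitudes_sketch(x: torch.Tensor, nlon_out: int) -> torch.Tensor:
    # Periodic linear interpolation of (..., nlat, nlon_in) onto nlon_out longitudes.
    nlon_in = x.shape[-1]
    # Fractional source position of every output longitude on the periodic grid.
    pos = torch.arange(nlon_out, dtype=torch.float64) * nlon_in / nlon_out
    left = pos.floor().long() % nlon_in
    right = (left + 1) % nlon_in
    frac = (pos - pos.floor()).to(x.dtype)  # interpolate in the precision of x

    # Convex combination of the two neighbouring longitude columns.
    return (1.0 - frac) * x[..., left] + frac * x[..., right]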
@@ -72,31 +72,7 @@ class RealSHT(nn.Module):
"""
def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
r"""
Initializes the SHT Layer, precomputing the necessary quadrature weights
Parameters
-----------
nlat: int
Number of latitude points
nlon: int
Number of longitude points
lmax: int
Maximum spherical harmonic degree
mmax: int
Maximum spherical harmonic order
grid: str
Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
norm: str
Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
csphase: bool
Whether to apply the Condon-Shortley phase factor, by default True
Returns
-------
x: torch.Tensor
Tensor of shape (..., lmax, mmax)
"""
super().__init__()
@@ -314,31 +290,7 @@ class RealVectorSHT(nn.Module):
"""
def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
r"""
Initializes the vector SHT Layer, precomputing the necessary quadrature weights
Parameters
-----------
nlat: int
Number of latitude points
nlon: int
Number of longitude points
lmax: int
Maximum spherical harmonic degree
mmax: int
Maximum spherical harmonic order
grid: str
Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
norm: str
Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
csphase: bool
Whether to apply the Condon-Shortley phase factor, by default True
Returns
-------
x: torch.Tensor
Tensor of shape (..., lmax, mmax)
"""
super().__init__()
......
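A minimal forward/inverse round trip using the constructor arguments documented above; the grid sizes are placeholders, and the inverse transform projects onto the resolvable spectrum rather than reproducing arbitrary input exactly.

import torch
from torch_harmonics import RealSHT, InverseRealSHT

nlat, nlon = 128, 256
sht = RealSHT(nlat, nlon, grid="equiangular", norm="ortho", csphase=True)
isht = InverseRealSHT(nlat, nlon, grid="equiangular", norm="ortho", csphase=True)

x = torch.randn(4, 3, nlat, nlon)   # real signal with trailing (nlat, nlon) dimensions
coeffs = sht(x)                     # complex coefficients of shape (..., lmax, mmax)
x_rec = isht(coeffs)                # back on the (..., nlat, nlon) grid

print(coeffs.shape, x_rec.shape)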